Skip to content

Commit d35bf64

Browse files
committed
make launch.py run installers for extensions that have ones
add some more classes to safety module for an extension
1 parent f126986 commit d35bf64

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

launch.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import platform
88

99
dir_repos = "repositories"
10+
dir_extensions = "extensions"
1011
python = sys.executable
1112
git = os.environ.get('GIT', "git")
1213
index_url = os.environ.get('INDEX_URL', "")
@@ -101,9 +102,24 @@ def version_check(commit):
101102
else:
102103
print("Not a git clone, can't perform version check.")
103104
except Exception as e:
104-
print("versipm check failed",e)
105+
print("version check failed", e)
106+
107+
108+
def run_extensions_installers():
109+
if not os.path.isdir(dir_extensions):
110+
return
111+
112+
for dirname_extension in os.listdir(dir_extensions):
113+
path_installer = os.path.join(dir_extensions, dirname_extension, "install.py")
114+
if not os.path.isfile(path_installer):
115+
continue
116+
117+
try:
118+
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {dirname_extension}"))
119+
except Exception as e:
120+
print(e, file=sys.stderr)
121+
105122

106-
107123
def prepare_enviroment():
108124
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
109125
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
@@ -189,6 +205,8 @@ def prepare_enviroment():
189205

190206
run_pip(f"install -r {requirements_file}", "requirements for Web UI")
191207

208+
run_extensions_installers()
209+
192210
if update_check:
193211
version_check(commit)
194212

modules/safe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def find_class(self, module, name):
3232
return getattr(collections, name)
3333
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
3434
return getattr(torch._utils, name)
35-
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
35+
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
3636
return getattr(torch, name)
3737
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
3838
return getattr(torch.nn.modules.container, name)

0 commit comments

Comments
 (0)