Skip to content

Commit a1a3763

Browse files
committed
make existing script loading and new preload code use same code for loading modules
limit extension preload scripts to just one file named preload.py
1 parent e5690d0 commit a1a3763

File tree

4 files changed

+53
-53
lines changed

4 files changed

+53
-53
lines changed

modules/extensions.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import sys
33
import traceback
4-
from importlib.machinery import SourceFileLoader
54

65
import git
76

@@ -85,23 +84,3 @@ def list_extensions():
8584
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
8685
extensions.append(extension)
8786

88-
89-
def preload_extensions(parser):
90-
if not os.path.isdir(extensions_dir):
91-
return
92-
93-
for dirname in sorted(os.listdir(extensions_dir)):
94-
path = os.path.join(extensions_dir, dirname)
95-
if not os.path.isdir(path):
96-
continue
97-
for file in os.listdir(path):
98-
if "preload.py" in file:
99-
full_file = os.path.join(path, file)
100-
print(f"Got preload file: {full_file}")
101-
102-
try:
103-
ext = SourceFileLoader("preload", full_file).load_module()
104-
parser = ext.preload(parser)
105-
except Exception as e:
106-
print(f"Exception preloading script: {e}")
107-
return parser

modules/script_loading.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import os
2+
import sys
3+
import traceback
4+
from types import ModuleType
5+
6+
7+
def load_module(path):
8+
with open(path, "r", encoding="utf8") as file:
9+
text = file.read()
10+
11+
compiled = compile(text, path, 'exec')
12+
module = ModuleType(os.path.basename(path))
13+
exec(compiled, module.__dict__)
14+
15+
return module
16+
17+
18+
def preload_extensions(extensions_dir, parser):
19+
if not os.path.isdir(extensions_dir):
20+
return
21+
22+
for dirname in sorted(os.listdir(extensions_dir)):
23+
preload_script = os.path.join(extensions_dir, dirname, "preload.py")
24+
if not os.path.isfile(preload_script):
25+
continue
26+
27+
try:
28+
module = load_module(preload_script)
29+
if hasattr(module, 'preload'):
30+
module.preload(parser)
31+
32+
except Exception:
33+
print(f"Error running preload() for {preload_script}", file=sys.stderr)
34+
print(traceback.format_exc(), file=sys.stderr)

modules/scripts.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import gradio as gr
77

88
from modules.processing import StableDiffusionProcessing
9-
from modules import shared, paths, script_callbacks, extensions
9+
from modules import shared, paths, script_callbacks, extensions, script_loading
1010

1111
AlwaysVisible = object()
1212

@@ -161,13 +161,7 @@ def load_scripts():
161161
sys.path = [scriptfile.basedir] + sys.path
162162
current_basedir = scriptfile.basedir
163163

164-
with open(scriptfile.path, "r", encoding="utf8") as file:
165-
text = file.read()
166-
167-
from types import ModuleType
168-
compiled = compile(text, scriptfile.path, 'exec')
169-
module = ModuleType(scriptfile.filename)
170-
exec(compiled, module.__dict__)
164+
module = script_loading.load_module(scriptfile.path)
171165

172166
for key, script_class in module.__dict__.items():
173167
if type(script_class) == type and issubclass(script_class, Script):
@@ -328,27 +322,21 @@ def postprocess(self, p, processed):
328322

329323
def reload_sources(self, cache):
330324
for si, script in list(enumerate(self.scripts)):
331-
with open(script.filename, "r", encoding="utf8") as file:
332-
args_from = script.args_from
333-
args_to = script.args_to
334-
filename = script.filename
335-
text = file.read()
336-
337-
from types import ModuleType
338-
339-
module = cache.get(filename, None)
340-
if module is None:
341-
compiled = compile(text, filename, 'exec')
342-
module = ModuleType(script.filename)
343-
exec(compiled, module.__dict__)
344-
cache[filename] = module
345-
346-
for key, script_class in module.__dict__.items():
347-
if type(script_class) == type and issubclass(script_class, Script):
348-
self.scripts[si] = script_class()
349-
self.scripts[si].filename = filename
350-
self.scripts[si].args_from = args_from
351-
self.scripts[si].args_to = args_to
325+
args_from = script.args_from
326+
args_to = script.args_to
327+
filename = script.filename
328+
329+
module = cache.get(filename, None)
330+
if module is None:
331+
module = script_loading.load_module(script.filename)
332+
cache[filename] = module
333+
334+
for key, script_class in module.__dict__.items():
335+
if type(script_class) == type and issubclass(script_class, Script):
336+
self.scripts[si] = script_class()
337+
self.scripts[si].filename = filename
338+
self.scripts[si].args_from = args_from
339+
self.scripts[si].args_to = args_to
352340

353341

354342
scripts_txt2img = ScriptRunner()

modules/shared.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import json
44
import os
55
import sys
6-
from collections import OrderedDict
76
import time
87

98
import gradio as gr
@@ -15,7 +14,7 @@
1514
import modules.sd_models
1615
import modules.styles
1716
import modules.devices as devices
18-
from modules import sd_samplers, sd_models, localization, sd_vae, extensions
17+
from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading
1918
from modules.hypernetworks import hypernetwork
2019
from modules.paths import models_path, script_path, sd_path
2120

@@ -91,7 +90,7 @@
9190
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
9291
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
9392

94-
extensions.preload_extensions(parser)
93+
script_loading.preload_extensions(extensions.extensions_dir, parser)
9594

9695
cmd_opts = parser.parse_args()
9796

0 commit comments

Comments
 (0)