Skip to content

Commit e0ac0e3

Browse files
authored
fix: resolve 'all' models when asked (#321)
* fix: resolve 'all' models when asked The dynamic I introduced with `AI_HORDE_MODEL_META_LARGE_MODELS` was confusing. Unfortunately, this is undoubtably going to create problems when workers start offering flux when they didn't intend to, but that's the cost of progress. * style: fix
1 parent a01b4a1 commit e0ac0e3

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

horde_sdk/ai_horde_worker/model_meta.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def resolve_meta_instructions(
5757
for possible_instruction in possible_meta_instructions:
5858
# If the instruction is to load all models, return all model names
5959
if ImageModelLoadResolver.meta_instruction_regex_match(MetaInstruction.ALL_REGEX, possible_instruction):
60-
return self.resolve_all_model_names()
60+
return self.remove_large_models(set(self.resolve_all_model_names()))
6161

6262
# If the instruction is to load the top N models, add the top N model names
6363
top_n_matches = ImageModelLoadResolver.meta_instruction_regex_match(
@@ -161,9 +161,16 @@ def remove_large_models(self, models: set[str]) -> set[str]:
161161
models = models - cascade_models - flux_models
162162
return models
163163

164-
def resolve_all_model_names(self) -> set[str]:
164+
def resolve_all_model_names(
165+
self,
166+
ignore_large_models_env_var: bool = True,
167+
) -> set[str]:
165168
"""Get the names of all models defined in the model reference.
166169
170+
Args:
171+
ignore_large_models_env_var (bool): A boolean representing whether to ignore the environment variable for
172+
large models, effectively returning all models regardless of `AI_HORDE_MODEL_META_LARGE_MODELS`.
173+
167174
Returns:
168175
A set of strings representing the names of all models.
169176
"""
@@ -173,7 +180,8 @@ def resolve_all_model_names(self) -> set[str]:
173180

174181
all_models = set(sd_model_references.root.keys()) if sd_model_references is not None else set()
175182

176-
all_models = self.remove_large_models(all_models)
183+
if not ignore_large_models_env_var:
184+
all_models = self.remove_large_models(all_models)
177185

178186
if not all_models:
179187
logger.error("No stable diffusion models found in model reference.")

tests/ai_horde_worker/test_model_meta_api_calls.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,33 @@ def test_image_model_load_resolver_all(image_model_load_resolver: ImageModelLoad
3131

3232
import os
3333

34-
os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"] = "true"
35-
36-
all_model_names_with_large = image_model_load_resolver.resolve_all_model_names()
34+
os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"] = "1"
35+
all_model_names_with_large = image_model_load_resolver.resolve_all_model_names(ignore_large_models_env_var=False)
3736

3837
del os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"]
3938

40-
assert len(all_model_names_with_large) > len(all_model_names)
39+
assert len(all_model_names_with_large) == len(all_model_names)
40+
41+
42+
def test_image_model_load_resolver_all_without_large(image_model_load_resolver: ImageModelLoadResolver) -> None:
43+
import os
44+
45+
all_model_names = image_model_load_resolver.resolve_all_model_names()
46+
47+
assert len(all_model_names) > 0
48+
49+
stored_value = os.environ.get("AI_HORDE_MODEL_META_LARGE_MODELS")
50+
51+
if "AI_HORDE_MODEL_META_LARGE_MODELS" in os.environ:
52+
del os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"]
53+
54+
all_model_names_without_large = image_model_load_resolver.resolve_all_model_names(
55+
ignore_large_models_env_var=False,
56+
)
57+
58+
if stored_value is not None:
59+
os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"] = stored_value
60+
assert len(all_model_names) > len(all_model_names_without_large)
4161

4262

4363
def test_image_model_load_resolver_top_n(

0 commit comments

Comments
 (0)