Skip to content

Commit fc5c14a

Browse files
committed
fix: update model reference import and adjust model resolution logic
1 parent d73a39f commit fc5c14a

File tree

2 files changed

+20
-20
lines changed

2 files changed

+20
-20
lines changed

horde_sdk/ai_horde_worker/model_meta.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY, STABLE_DIFFUSION_BASELINE_CATEGORY
55
from horde_model_reference.model_reference_manager import ModelReferenceManager
6-
from horde_model_reference.model_reference_records import StableDiffusion_ModelRecord
6+
from horde_model_reference.model_reference_records import ImageGenerationModelRecord
77
from loguru import logger
88

99
from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIManualClient
@@ -176,9 +176,9 @@ def resolve_all_model_names(
176176
"""
177177
all_model_references = self._model_reference_manager.get_all_model_references()
178178

179-
sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion]
179+
sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.image_generation]
180180

181-
all_models = set(sd_model_references.root.keys()) if sd_model_references is not None else set()
181+
all_models = set(sd_model_references.keys()) if sd_model_references is not None else set()
182182

183183
if not ignore_large_models_env_var:
184184
all_models = self.remove_large_models(all_models)
@@ -198,21 +198,21 @@ def _resolve_sfw_nsfw_model_names(self, nsfw: bool) -> set[str]:
198198
"""
199199
all_model_references = self._model_reference_manager.get_all_model_references()
200200

201-
sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion]
201+
sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.image_generation]
202202

203203
found_models: set[str] = set()
204204

205205
if sd_model_references is None:
206206
logger.error("No stable diffusion models found in model reference.")
207207
return found_models
208208

209-
for model in sd_model_references.root.values():
210-
if not isinstance(model, StableDiffusion_ModelRecord):
211-
logger.error(f"Model {model} is not a StableDiffusion_ModelRecord")
209+
for model_name, model in sd_model_references.items():
210+
if not isinstance(model, ImageGenerationModelRecord):
211+
logger.error(f"Model {model_name} is not a ImageGenerationModelRecord")
212212
continue
213213

214214
if model.nsfw == nsfw:
215-
found_models.add(model.name)
215+
found_models.add(model_name)
216216

217217
return found_models
218218

@@ -240,21 +240,21 @@ def resolve_all_inpainting_models(self) -> set[str]:
240240
"""
241241
all_model_references = self._model_reference_manager.get_all_model_references()
242242

243-
sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion]
243+
sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.image_generation]
244244

245245
found_models: set[str] = set()
246246

247247
if sd_model_references is None:
248248
logger.error("No stable diffusion models found in model reference.")
249249
return found_models
250250

251-
for model in sd_model_references.root.values():
252-
if not isinstance(model, StableDiffusion_ModelRecord):
253-
logger.error(f"Model {model} is not a StableDiffusion_ModelRecord")
251+
for model_name, model in sd_model_references.items():
252+
if not isinstance(model, ImageGenerationModelRecord):
253+
logger.error(f"Model {model_name} is not a ImageGenerationModelRecord")
254254
continue
255255

256256
if model.inpainting:
257-
found_models.add(model.name)
257+
found_models.add(model_name)
258258

259259
return found_models
260260

@@ -269,21 +269,21 @@ def resolve_all_models_of_baseline(self, baseline: str) -> set[str]:
269269
"""
270270
all_model_references = self._model_reference_manager.get_all_model_references()
271271

272-
sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion]
272+
sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.image_generation]
273273

274274
found_models: set[str] = set()
275275

276276
if sd_model_references is None:
277277
logger.error("No stable diffusion models found in model reference.")
278278
return found_models
279279

280-
for model in sd_model_references.root.values():
281-
if not isinstance(model, StableDiffusion_ModelRecord):
282-
logger.error(f"Model {model} is not a StableDiffusion_ModelRecord")
280+
for model_name, model in sd_model_references.items():
281+
if not isinstance(model, ImageGenerationModelRecord):
282+
logger.error(f"Model {model_name} is not a ImageGenerationModelRecord")
283283
continue
284284

285285
if model.baseline == baseline:
286-
found_models.add(model.name)
286+
found_models.add(model_name)
287287

288288
return found_models
289289

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
horde_model_reference~=0.9.0
1+
horde_model_reference>=2.0.0
22

3-
pydantic==2.9.2
3+
pydantic>=2.9.2
44
requests
55
StrEnum
66
loguru

0 commit comments

Comments
 (0)