33
44from horde_model_reference .meta_consts import MODEL_REFERENCE_CATEGORY , STABLE_DIFFUSION_BASELINE_CATEGORY
55from horde_model_reference .model_reference_manager import ModelReferenceManager
6- from horde_model_reference .model_reference_records import StableDiffusion_ModelRecord
6+ from horde_model_reference .model_reference_records import ImageGenerationModelRecord
77from loguru import logger
88
99from horde_sdk .ai_horde_api .ai_horde_clients import AIHordeAPIManualClient
@@ -176,9 +176,9 @@ def resolve_all_model_names(
176176 """
177177 all_model_references = self ._model_reference_manager .get_all_model_references ()
178178
179- sd_model_references = all_model_references [MODEL_REFERENCE_CATEGORY .stable_diffusion ]
179+ sd_model_references = all_model_references [MODEL_REFERENCE_CATEGORY .image_generation ]
180180
181- all_models = set (sd_model_references .root . keys ()) if sd_model_references is not None else set ()
181+ all_models = set (sd_model_references .keys ()) if sd_model_references is not None else set ()
182182
183183 if not ignore_large_models_env_var :
184184 all_models = self .remove_large_models (all_models )
@@ -198,21 +198,21 @@ def _resolve_sfw_nsfw_model_names(self, nsfw: bool) -> set[str]:
198198 """
199199 all_model_references = self ._model_reference_manager .get_all_model_references ()
200200
201- sd_model_references = all_model_references [MODEL_REFERENCE_CATEGORY .stable_diffusion ]
201+ sd_model_references = all_model_references [MODEL_REFERENCE_CATEGORY .image_generation ]
202202
203203 found_models : set [str ] = set ()
204204
205205 if sd_model_references is None :
206206 logger .error ("No stable diffusion models found in model reference." )
207207 return found_models
208208
209- for model in sd_model_references .root . values ():
210- if not isinstance (model , StableDiffusion_ModelRecord ):
211- logger .error (f"Model { model } is not a StableDiffusion_ModelRecord " )
209+ for model_name , model in sd_model_references .items ():
210+ if not isinstance (model , ImageGenerationModelRecord ):
211+ logger .error (f"Model { model_name } is not a ImageGenerationModelRecord " )
212212 continue
213213
214214 if model .nsfw == nsfw :
215- found_models .add (model . name )
215+ found_models .add (model_name )
216216
217217 return found_models
218218
@@ -240,21 +240,21 @@ def resolve_all_inpainting_models(self) -> set[str]:
240240 """
241241 all_model_references = self ._model_reference_manager .get_all_model_references ()
242242
243- sd_model_references = all_model_references [MODEL_REFERENCE_CATEGORY .stable_diffusion ]
243+ sd_model_references = all_model_references [MODEL_REFERENCE_CATEGORY .image_generation ]
244244
245245 found_models : set [str ] = set ()
246246
247247 if sd_model_references is None :
248248 logger .error ("No stable diffusion models found in model reference." )
249249 return found_models
250250
251- for model in sd_model_references .root . values ():
252- if not isinstance (model , StableDiffusion_ModelRecord ):
253- logger .error (f"Model { model } is not a StableDiffusion_ModelRecord " )
251+ for model_name , model in sd_model_references .items ():
252+ if not isinstance (model , ImageGenerationModelRecord ):
253+ logger .error (f"Model { model_name } is not a ImageGenerationModelRecord " )
254254 continue
255255
256256 if model .inpainting :
257- found_models .add (model . name )
257+ found_models .add (model_name )
258258
259259 return found_models
260260
@@ -269,21 +269,21 @@ def resolve_all_models_of_baseline(self, baseline: str) -> set[str]:
269269 """
270270 all_model_references = self ._model_reference_manager .get_all_model_references ()
271271
272- sd_model_references = all_model_references [MODEL_REFERENCE_CATEGORY .stable_diffusion ]
272+ sd_model_references = all_model_references [MODEL_REFERENCE_CATEGORY .image_generation ]
273273
274274 found_models : set [str ] = set ()
275275
276276 if sd_model_references is None :
277277 logger .error ("No stable diffusion models found in model reference." )
278278 return found_models
279279
280- for model in sd_model_references .root . values ():
281- if not isinstance (model , StableDiffusion_ModelRecord ):
282- logger .error (f"Model { model } is not a StableDiffusion_ModelRecord " )
280+ for model_name , model in sd_model_references .items ():
281+ if not isinstance (model , ImageGenerationModelRecord ):
282+ logger .error (f"Model { model_name } is not a ImageGenerationModelRecord " )
283283 continue
284284
285285 if model .baseline == baseline :
286- found_models .add (model . name )
286+ found_models .add (model_name )
287287
288288 return found_models
289289
0 commit comments