Skip to content

Commit bac8438

Browse files
committed
Update testcase with new list command and sd engine update.
1 parent bfef661 commit bac8438

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

testscases/StableDiffusion/stable_diffusion_engine_tc.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,16 @@ def initialize_engine(model_name, model_path, device_list):
153153
return controlnet_openpose.ControlNetOpenPose(model=model_path, device=device_list)
154154
if model_name == "controlnet_referenceonly":
155155
return stable_diffusion_engine.StableDiffusionEngineReferenceOnly(model=model_path, device=device_list)
156-
return stable_diffusion_engine.StableDiffusionEngine(model=model_path, device=device_list)
156+
return stable_diffusion_engine.StableDiffusionEngine(model=model_path, device=device_list, model_name=model_name)
157157

158158
def parse_args() -> argparse.Namespace:
159159
"""Parse and return command line arguments."""
160160
parser = argparse.ArgumentParser(add_help=False, formatter_class=argparse.RawTextHelpFormatter)
161161
args = parser.add_argument_group('Options')
162162
args.add_argument('-h', '--help', action = 'help',
163163
help='Show this help message and exit.')
164+
args.add_argument('-l', '--list', action = 'store_true',
165+
help='Show list of models currently installed.')
164166
# base path to models
165167
args.add_argument('-bp','--model_base_path',type = str, default = None, required = False,
166168
help='Optional. Specify the absolute base path to model weights. \nUsage example: -bp \\stable-diffusion\\model-weights\\')
@@ -194,8 +196,6 @@ def parse_args() -> argparse.Namespace:
194196
# guidance scale
195197
args.add_argument('-g','--guidance_scale',type = float, default = 7.5, required = False,
196198
help='Optional. Affects how closely the image prompt is followed.')
197-
198-
199199
# power mode
200200
args.add_argument('-pm','--power_mode',type = str, default = "best performance", required = False,
201201
help='Optional. Specify the power mode. Default is best performance')
@@ -209,6 +209,32 @@ def parse_args() -> argparse.Namespace:
209209

210210
return parser.parse_args()
211211

212+
def validate_model_paths(base_path: str, model_paths: dict) -> dict:
213+
"""
214+
Check if model directories exist based on base_path and model_paths structure.
215+
216+
Args:
217+
base_path (str): Root directory where models are stored.
218+
model_paths (dict): Dictionary with model keys and relative path parts.
219+
220+
Returns:
221+
dict: Dictionary with model names and a boolean indicating existence.
222+
"""
223+
results = {}
224+
for model_name, relative_parts in model_paths.items():
225+
full_path = os.path.join(base_path, *relative_parts)
226+
if os.path.isdir(full_path):
227+
if "int8a16" in model_name:
228+
if os.path.isfile(os.path.join(full_path, "unet_int8a16.xml")):
229+
results[model_name] = full_path
230+
elif "fp8" in model_name:
231+
if os.path.isfile(os.path.join(full_path, "unet_fp8.xml")):
232+
results[model_name] = full_path
233+
else:
234+
results[model_name] = full_path
235+
return results
236+
237+
212238
def main():
213239
args = parse_args()
214240
results = []
@@ -253,6 +279,12 @@ def main():
253279
"controlnet_scribble_int8": ["stable-diffusion-ov", "controlnet-scribble-int8"],
254280
}
255281

282+
if args.list:
283+
print(f"\nInstalled models: ")
284+
for key in validate_model_paths(weight_path, model_paths).keys():
285+
print(f"{key}")
286+
exit()
287+
256288
model_name = args.model_name
257289
model_path = os.path.join(weight_path, *model_paths.get(model_name))
258290
model_config_file_name = os.path.join(model_path, "config.json")

0 commit comments

Comments
 (0)