Skip to content

Commit b353787

Browse files
committed
Fixed the CICD for Diffusion
Signed-off-by: jingyu <[email protected]> Signed-off-by: Jingyu Xin <[email protected]>
1 parent d6d2e75 commit b353787

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

examples/diffusers/quantization/diffusion_trt.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,11 @@ def main():
105105

106106
image_name = args.save_image_as if args.save_image_as else f"{args.model}.png"
107107

108-
pipe = PipelineManager.create_pipeline_from(MODEL_ID[args.model], dtype_map[args.model_dtype])
108+
pipe = PipelineManager.create_pipeline_from(
109+
MODEL_ID[args.model],
110+
dtype_map[args.model_dtype],
111+
override_model_path=args.override_model_path,
112+
)
109113

110114
# Save the backbone of the pipeline and move it to the GPU
111115
add_embedding = None

examples/diffusers/quantization/quantize.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,9 @@ def __init__(self, config: ModelConfig, logger: logging.Logger):
309309

310310
@staticmethod
311311
def create_pipeline_from(
312-
model_type: ModelType, torch_dtype: torch.dtype = torch.bfloat16
312+
model_type: ModelType,
313+
torch_dtype: torch.dtype = torch.bfloat16,
314+
override_model_path: str | None = None,
313315
) -> DiffusionPipeline:
314316
"""
315317
Create and return an appropriate pipeline based on configuration.
@@ -321,7 +323,9 @@ def create_pipeline_from(
321323
ValueError: If model type is unsupported
322324
"""
323325
try:
324-
model_id = MODEL_REGISTRY[model_type]
326+
model_id = (
327+
MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path
328+
)
325329
if model_type == ModelType.SD3_MEDIUM:
326330
pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
327331
elif model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]:

examples/diffusers/quantization/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
cuda-python
2+
diffusers==0.34.0
23
nvtx
34
onnx_graphsurgeon
45
opencv-python>=4.8.1.78,<4.12.0.88

0 commit comments

Comments
 (0)