Skip to content
17 changes: 15 additions & 2 deletions src/diffusers/pipelines/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def load_model(path: Union[str, Path], provider=None, sess_options=None, provide
logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
provider = "CPUExecutionProvider"

if provider_options is None:
provider_options = []
elif not isinstance(provider_options, list):
provider_options = [provider_options]

return ort.InferenceSession(
path, providers=[provider], sess_options=sess_options, provider_options=provider_options
)
Expand Down Expand Up @@ -174,7 +179,10 @@ def _from_pretrained(
# load model from local directory
if os.path.isdir(model_id):
model = OnnxRuntimeModel.load_model(
Path(model_id, model_file_name).as_posix(), provider=provider, sess_options=sess_options
Path(model_id, model_file_name).as_posix(),
provider=provider,
sess_options=sess_options,
provider_options=kwargs.get("provider_options"),
)
kwargs["model_save_dir"] = Path(model_id)
# load model from hub
Expand All @@ -190,7 +198,12 @@ def _from_pretrained(
)
kwargs["model_save_dir"] = Path(model_cache_path).parent
kwargs["latest_model_name"] = Path(model_cache_path).name
model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
model = OnnxRuntimeModel.load_model(
model_cache_path,
provider=provider,
sess_options=sess_options,
provider_options=kwargs.get("provider_options"),
)
return cls(model=model, **kwargs)

@classmethod
Expand Down