Skip to content

Commit 351b2f1

Browse files
committed
update
1 parent 08c2902 commit 351b2f1

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ def from_pretrained(
310310
"cache_dir",
311311
"force_download",
312312
"local_files_only",
313+
"local_dir",
313314
"proxies",
314315
"resume_download",
315316
"revision",
@@ -336,7 +337,6 @@ def from_pretrained(
336337
module_file=module_file,
337338
class_name=class_name,
338339
**hub_kwargs,
339-
**kwargs,
340340
)
341341
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
342342
block_kwargs = {

src/diffusers/utils/dynamic_modules_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def get_cached_module_file(
254254
token: Optional[Union[bool, str]] = None,
255255
revision: Optional[str] = None,
256256
local_files_only: bool = False,
257+
local_dir: Optional[str] = None,
257258
):
258259
"""
259260
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -336,6 +337,7 @@ def get_cached_module_file(
336337
force_download=force_download,
337338
proxies=proxies,
338339
local_files_only=local_files_only,
340+
local_dir=local_dir,
339341
)
340342
submodule = "git"
341343
module_file = pretrained_model_name_or_path + ".py"
@@ -359,6 +361,7 @@ def get_cached_module_file(
359361
force_download=force_download,
360362
proxies=proxies,
361363
local_files_only=local_files_only,
364+
local_dir=local_dir,
362365
token=token,
363366
)
364367
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
@@ -419,6 +422,7 @@ def get_cached_module_file(
419422
token=token,
420423
revision=revision,
421424
local_files_only=local_files_only,
425+
local_dir=local_dir,
422426
)
423427
return os.path.join(full_submodule, module_file)
424428

@@ -435,7 +439,7 @@ def get_class_from_dynamic_module(
435439
token: Optional[Union[bool, str]] = None,
436440
revision: Optional[str] = None,
437441
local_files_only: bool = False,
438-
**kwargs,
442+
local_dir: Optional[str] = None,
439443
):
440444
"""
441445
Extracts a class from a module file, present in the local folder or repository of a model.
@@ -508,5 +512,6 @@ def get_class_from_dynamic_module(
508512
token=token,
509513
revision=revision,
510514
local_files_only=local_files_only,
515+
local_dir=local_dir,
511516
)
512517
return get_class_in_module(class_name, final_module)

0 commit comments

Comments
 (0)