Skip to content

Commit 513f1fb

Browse files
Allow passing non-default modules to pipeline (#188)
* Allow passing non-default modules to pipeline. Override modules are recognized and replaced in the pipeline. However, no check is performed about mismatched classes yet. This is because the override module is already instantiated and we have no library or class name to compare against. * up * add test Co-authored-by: Patrick von Platen <[email protected]>
1 parent d7b6920 commit 513f1fb

File tree

2 files changed

+71
-13
lines changed

2 files changed

+71
-13
lines changed

src/diffusers/pipeline_utils.py

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
import importlib
18+
import inspect
1819
import os
1920
from typing import Optional, Union
2021

@@ -148,6 +149,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
148149
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
149150
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
150151

152+
# some modules can be passed directly to the init
153+
# in this case they are already instantiated in `kwargs`
154+
# extract them here
155+
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
156+
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
157+
151158
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
152159

153160
init_kwargs = {}
@@ -158,8 +165,36 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
158165
# 3. Load each module in the pipeline
159166
for name, (library_name, class_name) in init_dict.items():
160167
is_pipeline_module = hasattr(pipelines, library_name)
168+
loaded_sub_model = None
169+
161170
# if the model is in a pipeline module, then we load it from the pipeline
162-
if is_pipeline_module:
171+
if name in passed_class_obj:
172+
# 1. check that passed_class_obj has correct parent class
173+
if not is_pipeline_module:
174+
library = importlib.import_module(library_name)
175+
class_obj = getattr(library, class_name)
176+
importable_classes = LOADABLE_CLASSES[library_name]
177+
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
178+
179+
expected_class_obj = None
180+
for class_name, class_candidate in class_candidates.items():
181+
if issubclass(class_obj, class_candidate):
182+
expected_class_obj = class_candidate
183+
184+
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
185+
raise ValueError(
186+
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
187+
f" {expected_class_obj}"
188+
)
189+
else:
190+
logger.warn(
191+
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
192+
" has the correct type"
193+
)
194+
195+
# set passed class object
196+
loaded_sub_model = passed_class_obj[name]
197+
elif is_pipeline_module:
163198
pipeline_module = getattr(pipelines, library_name)
164199
class_obj = getattr(pipeline_module, class_name)
165200
importable_classes = ALL_IMPORTABLE_CLASSES
@@ -171,23 +206,24 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
171206
importable_classes = LOADABLE_CLASSES[library_name]
172207
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
173208

174-
load_method_name = None
175-
for class_name, class_candidate in class_candidates.items():
176-
if issubclass(class_obj, class_candidate):
177-
load_method_name = importable_classes[class_name][1]
209+
if loaded_sub_model is None:
210+
load_method_name = None
211+
for class_name, class_candidate in class_candidates.items():
212+
if issubclass(class_obj, class_candidate):
213+
load_method_name = importable_classes[class_name][1]
178214

179-
load_method = getattr(class_obj, load_method_name)
215+
load_method = getattr(class_obj, load_method_name)
180216

181-
# check if the module is in a subdirectory
182-
if os.path.isdir(os.path.join(cached_folder, name)):
183-
loaded_sub_model = load_method(os.path.join(cached_folder, name))
184-
else:
185-
# else load from the root directory
186-
loaded_sub_model = load_method(cached_folder)
217+
# check if the module is in a subdirectory
218+
if os.path.isdir(os.path.join(cached_folder, name)):
219+
loaded_sub_model = load_method(os.path.join(cached_folder, name))
220+
else:
221+
# else load from the root directory
222+
loaded_sub_model = load_method(cached_folder)
187223

188224
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
189225

190-
# 5. Instantiate the pipeline
226+
# 4. Instantiate the pipeline
191227
model = pipeline_class(**init_kwargs)
192228
return model
193229

tests/test_modeling_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,28 @@ def test_from_pretrained_hub(self):
718718

719719
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
720720

721+
@slow
722+
def test_from_pretrained_hub_pass_model(self):
723+
model_path = "google/ddpm-cifar10-32"
724+
725+
# pass unet into DiffusionPipeline
726+
unet = UNet2DModel.from_pretrained(model_path)
727+
ddpm_from_hub_custom_model = DDPMPipeline.from_pretrained(model_path, unet=unet)
728+
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet)
729+
730+
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
731+
732+
ddpm_from_hub_custom_model.scheduler.num_timesteps = 10
733+
ddpm_from_hub.scheduler.num_timesteps = 10
734+
735+
generator = torch.manual_seed(0)
736+
737+
image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy")["sample"]
738+
generator = generator.manual_seed(0)
739+
new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
740+
741+
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
742+
721743
@slow
722744
def test_output_format(self):
723745
model_path = "google/ddpm-cifar10-32"

0 commit comments

Comments
 (0)