15
15
# limitations under the License.
16
16
17
17
import importlib
18
+ import inspect
18
19
import os
19
20
from typing import Optional , Union
20
21
@@ -148,6 +149,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
148
149
diffusers_module = importlib .import_module (cls .__module__ .split ("." )[0 ])
149
150
pipeline_class = getattr (diffusers_module , config_dict ["_class_name" ])
150
151
152
+ # some modules can be passed directly to the init
153
+ # in this case they are already instantiated in `kwargs`
154
+ # extract them here
155
+ expected_modules = set (inspect .signature (pipeline_class .__init__ ).parameters .keys ())
156
+ passed_class_obj = {k : kwargs .pop (k ) for k in expected_modules if k in kwargs }
157
+
151
158
init_dict , _ = pipeline_class .extract_init_dict (config_dict , ** kwargs )
152
159
153
160
init_kwargs = {}
@@ -158,8 +165,36 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
158
165
# 3. Load each module in the pipeline
159
166
for name , (library_name , class_name ) in init_dict .items ():
160
167
is_pipeline_module = hasattr (pipelines , library_name )
168
+ loaded_sub_model = None
169
+
161
170
# if the model is in a pipeline module, then we load it from the pipeline
162
- if is_pipeline_module :
171
+ if name in passed_class_obj :
172
+ # 1. check that passed_class_obj has correct parent class
173
+ if not is_pipeline_module :
174
+ library = importlib .import_module (library_name )
175
+ class_obj = getattr (library , class_name )
176
+ importable_classes = LOADABLE_CLASSES [library_name ]
177
+ class_candidates = {c : getattr (library , c ) for c in importable_classes .keys ()}
178
+
179
+ expected_class_obj = None
180
+ for class_name , class_candidate in class_candidates .items ():
181
+ if issubclass (class_obj , class_candidate ):
182
+ expected_class_obj = class_candidate
183
+
184
+ if not issubclass (passed_class_obj [name ].__class__ , expected_class_obj ):
185
+ raise ValueError (
186
+ f"{ passed_class_obj [name ]} is of type: { type (passed_class_obj [name ])} , but should be"
187
+ f" { expected_class_obj } "
188
+ )
189
+ else :
190
+ logger .warn (
191
+ f"You have passed a non-standard module { passed_class_obj [name ]} . We cannot verify whether it"
192
+ " has the correct type"
193
+ )
194
+
195
+ # set passed class object
196
+ loaded_sub_model = passed_class_obj [name ]
197
+ elif is_pipeline_module :
163
198
pipeline_module = getattr (pipelines , library_name )
164
199
class_obj = getattr (pipeline_module , class_name )
165
200
importable_classes = ALL_IMPORTABLE_CLASSES
@@ -171,23 +206,24 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
171
206
importable_classes = LOADABLE_CLASSES [library_name ]
172
207
class_candidates = {c : getattr (library , c ) for c in importable_classes .keys ()}
173
208
174
- load_method_name = None
175
- for class_name , class_candidate in class_candidates .items ():
176
- if issubclass (class_obj , class_candidate ):
177
- load_method_name = importable_classes [class_name ][1 ]
209
+ if loaded_sub_model is None :
210
+ load_method_name = None
211
+ for class_name , class_candidate in class_candidates .items ():
212
+ if issubclass (class_obj , class_candidate ):
213
+ load_method_name = importable_classes [class_name ][1 ]
178
214
179
- load_method = getattr (class_obj , load_method_name )
215
+ load_method = getattr (class_obj , load_method_name )
180
216
181
- # check if the module is in a subdirectory
182
- if os .path .isdir (os .path .join (cached_folder , name )):
183
- loaded_sub_model = load_method (os .path .join (cached_folder , name ))
184
- else :
185
- # else load from the root directory
186
- loaded_sub_model = load_method (cached_folder )
217
+ # check if the module is in a subdirectory
218
+ if os .path .isdir (os .path .join (cached_folder , name )):
219
+ loaded_sub_model = load_method (os .path .join (cached_folder , name ))
220
+ else :
221
+ # else load from the root directory
222
+ loaded_sub_model = load_method (cached_folder )
187
223
188
224
init_kwargs [name ] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
189
225
190
- # 5 . Instantiate the pipeline
226
+ # 4 . Instantiate the pipeline
191
227
model = pipeline_class (** init_kwargs )
192
228
return model
193
229
0 commit comments