1313# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414# See the License for the specific language governing permissions and
1515# limitations under the License.
16+ import enum
1617import fnmatch
1718import importlib
1819import inspect
@@ -811,6 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
811812 # in this case they are already instantiated in `kwargs`
812813 # extract them here
813814 expected_modules , optional_kwargs = cls ._get_signature_keys (pipeline_class )
815+ expected_types = pipeline_class ._get_signature_types ()
814816 passed_class_obj = {k : kwargs .pop (k ) for k in expected_modules if k in kwargs }
815817 passed_pipe_kwargs = {k : kwargs .pop (k ) for k in optional_kwargs if k in kwargs }
816818 init_dict , unused_kwargs , _ = pipeline_class .extract_init_dict (config_dict , ** kwargs )
@@ -832,13 +834,21 @@ def load_module(name, value):
832834 return True
833835
834836 init_dict = {k : v for k , v in init_dict .items () if load_module (k , v )}
837+ scheduler_types = expected_types ["scheduler" ][0 ]
838+ if isinstance (scheduler_types , enum .EnumType ):
839+ scheduler_types = list (scheduler_types )
840+ else :
841+ scheduler_types = [str (scheduler_types )]
842+ scheduler_types = [str (scheduler ).split ("." )[- 1 ].strip ("'>" ) for scheduler in scheduler_types ]
835843
836844 for key , (_ , expected_class_name ) in zip (init_dict .keys (), init_dict .values ()):
837- if key not in passed_class_obj or key == "scheduler" :
845+ if key not in passed_class_obj :
838846 continue
839847 class_name = passed_class_obj [key ].__class__ .__name__
840848 class_name = class_name [4 :] if class_name .startswith ("Flax" ) else class_name
841- if class_name != expected_class_name :
849+ if key == "scheduler" and class_name not in scheduler_types :
850+ raise ValueError (f"Expected { scheduler_types } for { key } , got { class_name } ." )
851+ elif key != "scheduler" and class_name != expected_class_name :
842852 raise ValueError (f"Expected { expected_class_name } for { key } , got { class_name } ." )
843853
844854 # Special case: safety_checker must be loaded separately when using `from_flax`
0 commit comments