@@ -834,20 +834,22 @@ def load_module(name, value):
834834 return True
835835
836836 init_dict = {k : v for k , v in init_dict .items () if load_module (k , v )}
837- scheduler_types = expected_types ["scheduler" ][0 ]
838- if isinstance (scheduler_types , enum .EnumMeta ):
839- scheduler_types = list (scheduler_types )
840- else :
841- scheduler_types = [str (scheduler_types )]
842- scheduler_types = [str (scheduler ).split ("." )[- 1 ].strip ("'>" ) for scheduler in scheduler_types ]
837+ scheduler_types = None
838+ if "scheduler" in expected_types :
839+ scheduler_types = expected_types ["scheduler" ][0 ]
840+ if isinstance (scheduler_types , enum .EnumMeta ):
841+ scheduler_types = list (scheduler_types )
842+ else :
843+ scheduler_types = [str (scheduler_types )]
844+ scheduler_types = [str (scheduler ).split ("." )[- 1 ].strip ("'>" ) for scheduler in scheduler_types ]
843845
844846 for key , (_ , expected_class_name ) in zip (init_dict .keys (), init_dict .values ()):
845847 if key not in passed_class_obj :
846848 continue
847849 class_name = passed_class_obj [key ].__class__ .__name__
848850 class_name = class_name [4 :] if class_name .startswith ("Flax" ) else class_name
849851 expected_class_name = expected_class_name [4 :] if expected_class_name .startswith ("Flax" ) else expected_class_name
850- if key == "scheduler" and class_name not in scheduler_types :
852+ if key == "scheduler" and scheduler_types is not None and class_name not in scheduler_types :
851853 raise ValueError (f"Expected { scheduler_types } for { key } , got { class_name } ." )
852854 elif key != "scheduler" and class_name != expected_class_name :
853855 raise ValueError (f"Expected { expected_class_name } for { key } , got { class_name } ." )
0 commit comments