Skip to content

Commit 6aad7a7

Browse files
committed
Fix for scheduler
1 parent 679c18c commit 6aad7a7

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16+
import enum
1617
import fnmatch
1718
import importlib
1819
import inspect
@@ -811,6 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
811812
# in this case they are already instantiated in `kwargs`
812813
# extract them here
813814
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
815+
expected_types = pipeline_class._get_signature_types()
814816
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
815817
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
816818
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
@@ -832,13 +834,21 @@ def load_module(name, value):
832834
return True
833835

834836
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
837+
scheduler_types = expected_types["scheduler"][0]
838+
if isinstance(scheduler_types, enum.EnumType):
839+
scheduler_types = list(scheduler_types)
840+
else:
841+
scheduler_types = [str(scheduler_types)]
842+
scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types]
835843

836844
for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
837-
if key not in passed_class_obj or key == "scheduler":
845+
if key not in passed_class_obj:
838846
continue
839847
class_name = passed_class_obj[key].__class__.__name__
840848
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
841-
if class_name != expected_class_name:
849+
if key == "scheduler" and class_name not in scheduler_types:
850+
raise ValueError(f"Expected {scheduler_types} for {key}, got {class_name}.")
851+
elif key != "scheduler" and class_name != expected_class_name:
842852
raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.")
843853

844854
# Special case: safety_checker must be loaded separately when using `from_flax`

0 commit comments

Comments
 (0)