Skip to content

Commit ec53008

Browse files
committed
add test
1 parent 1dc755c commit ec53008

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
706706
use_onnx = kwargs.pop("use_onnx", None)
707707
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
708708

709-
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
709+
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
710710
torch_dtype = torch.float32
711711
logger.warning(
712712
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."

tests/pipelines/test_pipelines_common.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2283,6 +2283,29 @@ def run_forward(pipe):
22832283
self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
22842284
self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))
22852285

2286+
def test_torch_dtype_dict(self):
2287+
components = self.get_dummy_components()
2288+
if not components:
2289+
self.skipTest("No dummy components defined.")
2290+
2291+
pipe = self.pipeline_class(**components)
2292+
2293+
specified_key = next(iter(components.keys()))
2294+
2295+
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
2296+
pipe.save_pretrained(tmpdirname)
2297+
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
2298+
loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)
2299+
2300+
for name, component in loaded_pipe.components.items():
2301+
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
2302+
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
2303+
self.assertEqual(
2304+
component.dtype,
2305+
expected_dtype,
2306+
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
2307+
)
2308+
22862309

22872310
@is_staging_test
22882311
class PipelinePushToHubTester(unittest.TestCase):

0 commit comments

Comments
 (0)