Skip to content

Commit 43c41f4

Browse files
committed
fix tests
1 parent f035a0d commit 43c41f4

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,16 +1086,19 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
10861086

10871087
def test_wrong_device_map_raises_error(self):
10881088
with self.assertRaises(ValueError) as err_ctx:
1089-
_ = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy-subfolder", device_map=-1)
1090-
msg_substring = "You can't pass device_map as a negative int"
1091-
assert msg_substring in str(err_ctx.exception)
1089+
_ = self.model_class.from_pretrained(
1090+
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=-1
1091+
)
1092+
1093+
msg_substring = "You can't pass device_map as a negative int"
1094+
assert msg_substring in str(err_ctx.exception)
10921095

10931096
@require_torch_gpu
10941097
@parameterized.expand([0, "cuda", torch.device("cuda"), torch.device("cuda:0")])
10951098
def test_passing_non_dict_device_map_works(self, device_map):
10961099
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
10971100
loaded_model = self.model_class.from_pretrained(
1098-
"hf-internal-testing/unet2d-sharded-dummy-subfolder", device_map=device_map
1101+
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
10991102
)
11001103
output = loaded_model(**inputs_dict)
11011104
assert output.sample.shape == (4, 4, 16, 16)

0 commit comments

Comments
 (0)