Skip to content

Commit 1ba8a88

Browse files
committed
fix comments
Signed-off-by: Matrix Yao <[email protected]>
1 parent bda0afd commit 1ba8a88

File tree

3 files changed

+5
-7
lines changed

3 files changed

+5
-7
lines changed

src/diffusers/models/unets/unet_2d_blocks.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2557,8 +2557,6 @@ def forward(
25572557
b1=self.b1,
25582558
b2=self.b2,
25592559
)
2560-
if hidden_states.device != res_hidden_states.device:
2561-
res_hidden_states = res_hidden_states.to(hidden_states.device)
25622560
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
25632561

25642562
if torch.is_grad_enabled() and self.gradient_checkpointing:

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ class conditioning with `class_embed_type` equal to `None`.
165165
"""
166166

167167
_supports_gradient_checkpointing = True
168-
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
168+
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"]
169169
_skip_layerwise_casting_patterns = ["norm"]
170170
_repeated_blocks = ["BasicTransformerBlock"]
171171

tests/models/test_modeling_common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,8 +1828,8 @@ def test_wrong_device_map_raises_error(self, device_map, msg_substring):
18281828

18291829
assert msg_substring in str(err_ctx.exception)
18301830

1831-
@parameterized.expand([0, "cuda", torch.device("cuda")])
1832-
@require_torch_gpu
1831+
@parameterized.expand([0, torch_device, torch.device(torch_device)])
1832+
@require_torch_accelerator
18331833
def test_passing_non_dict_device_map_works(self, device_map):
18341834
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
18351835
model = self.model_class(**init_dict).eval()
@@ -1838,8 +1838,8 @@ def test_passing_non_dict_device_map_works(self, device_map):
18381838
loaded_model = self.model_class.from_pretrained(tmpdir, device_map=device_map)
18391839
_ = loaded_model(**inputs_dict)
18401840

1841-
@parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
1842-
@require_torch_gpu
1841+
@parameterized.expand([("", torch_device), ("", torch.device(torch_device))])
1842+
@require_torch_accelerator
18431843
def test_passing_dict_device_map_works(self, name, device):
18441844
# There are other valid dict-based `device_map` values too. It's best to refer to
18451845
# the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.

0 commit comments

Comments
 (0)