Skip to content

Commit 8082e85

Browse files
authored
Refactor image padding logic to prevent zero tensor in transformer_z_image.py
1 parent b010a8c commit 8082e85

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -515,16 +515,19 @@ def patchify_and_embed(
515515
start=(cap_ori_len + cap_padding_len + 1, 0, 0),
516516
device=device,
517517
).flatten(0, 2)
518-
image_padding_pos_ids = (
519-
self.create_coordinate_grid(
520-
size=(1, 1, 1),
521-
start=(0, 0, 0),
522-
device=device,
518+
if image_padding_len > 0:
519+
image_padding_pos_ids = (
520+
self.create_coordinate_grid(
521+
size=(1, 1, 1),
522+
start=(0, 0, 0),
523+
device=device,
524+
)
525+
.flatten(0, 2)
526+
.repeat(image_padding_len, 1)
523527
)
524-
.flatten(0, 2)
525-
.repeat(image_padding_len, 1)
526-
)
527-
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
528+
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
529+
else :
530+
image_padded_pos_ids = image_ori_pos_ids
528531
all_image_pos_ids.append(image_padded_pos_ids)
529532
# pad mask
530533
all_image_pad_mask.append(
@@ -534,10 +537,10 @@ def patchify_and_embed(
534537
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
535538
],
536539
dim=0,
537-
)
540+
) if image_padding_len > 0 else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
538541
)
539542
# padded feature
540-
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
543+
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) if image_padding_len > 0 else image
541544
all_image_out.append(image_padded_feat)
542545

543546
return (

0 commit comments

Comments
 (0)