Commit 6d6ce7c

Fix null handling when padding optional negative prompt batches (#2357)

Authored by Grzegorz Pluto-Prondzinski
1 parent 6afa2ca; commit 6d6ce7c

File tree

1 file changed (+10, -5 lines)

optimum/habana/diffusers/pipelines/qwenimage/pipeline_qwenimage.py

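When the negative prompt is omitted, negative_prompt_embeds is None, so _split_inputs_into_batches leaves negative_prompt_embeds_batches as None (note the existing "is not None" guard in the stacking hunk below). The old padding code indexed it unconditionally anyway; likewise, prompt_embeds_mask_batches was never assigned when prompt_embeds_mask was None, so the unconditional torch.stack(prompt_embeds_mask_batches) could not succeed. A minimal, hypothetical repro of the first failure (not pipeline code):

    negative_prompt_embeds_batches = None  # negative prompt omitted by the caller
    negative_prompt_embeds_batches[-1]     # TypeError: 'NoneType' object is not subscriptable
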
@@ -146,6 +146,8 @@ def _split_inputs_into_batches(
         prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size))
         if prompt_embeds_mask is not None:
             prompt_embeds_mask_batches = list(torch.split(prompt_embeds_mask, batch_size))
+        else:
+            prompt_embeds_mask_batches = None
         if negative_prompt_embeds is not None:
             negative_prompt_embeds_batches = list(torch.split(negative_prompt_embeds, batch_size))
         else:
@@ -182,10 +184,11 @@ def _split_inputs_into_batches(
                 prompt_embeds_mask_batches[-1] = torch.vstack(sequence_to_stack)
 
             # Pad negative_prompt_embeds_batches
-            sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple(
-                torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
-            )
-            negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+            if negative_prompt_embeds_batches is not None:
+                sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple(
+                    torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+                )
+                negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
 
             # Pad negative_prompt_embeds_mask if necessary
             if negative_prompt_embeds_mask is not None:
@@ -206,7 +209,9 @@ def _split_inputs_into_batches(
         # Stack batches in the same tensor
         latents_batches = torch.stack(latents_batches)
         prompt_embeds_batches = torch.stack(prompt_embeds_batches)
-        prompt_embeds_mask_batches = torch.stack(prompt_embeds_mask_batches)
+        prompt_embeds_mask_batches = (
+            torch.stack(prompt_embeds_mask_batches) if prompt_embeds_mask_batches is not None else None
+        )
         if negative_prompt_embeds_batches is not None:
             negative_prompt_embeds_batches = torch.stack(negative_prompt_embeds_batches)
         if negative_prompt_embeds_mask_batches is not None:
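The guarded split -> pad -> stack pattern the fix applies can be reduced to a standalone sketch. split_and_pad below is a hypothetical helper, not pipeline code: it splits an optional embedding tensor into fixed-size batches, zero-pads the last batch the same way the pipeline does (torch.zeros_like on a single row, then torch.vstack), and propagates None when the input is omitted instead of crashing.

    import torch

    def split_and_pad(embeds, batch_size):
        # Optional input: propagate None instead of splitting/indexing/stacking it.
        if embeds is None:
            return None
        batches = list(torch.split(embeds, batch_size))
        num_dummy_samples = batch_size - batches[-1].shape[0]
        if num_dummy_samples > 0:
            # Pad the last batch with zero rows so every batch has the same shape.
            padding = tuple(
                torch.zeros_like(batches[-1][0][None, :]) for _ in range(num_dummy_samples)
            )
            batches[-1] = torch.vstack((batches[-1],) + padding)
        return torch.stack(batches)

    prompt_embeds = torch.randn(5, 7, 16)  # 5 prompts -> two batches of 3, last padded
    print(split_and_pad(prompt_embeds, batch_size=3).shape)  # torch.Size([2, 3, 7, 16])
    print(split_and_pad(None, batch_size=3))  # None: negative prompt omitted

The fix applies the same None check at all three stages: before splitting, before padding the last batch, and before the final torch.stack.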
