@@ -146,6 +146,8 @@ def _split_inputs_into_batches(
         prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size))
         if prompt_embeds_mask is not None:
             prompt_embeds_mask_batches = list(torch.split(prompt_embeds_mask, batch_size))
+        else:
+            prompt_embeds_mask_batches = None
         if negative_prompt_embeds is not None:
             negative_prompt_embeds_batches = list(torch.split(negative_prompt_embeds, batch_size))
         else:
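Without the added `else` branch, `prompt_embeds_mask_batches` is bound only when `prompt_embeds_mask` is not `None`, so the later `torch.stack(prompt_embeds_mask_batches)` would raise an `UnboundLocalError`. A minimal sketch of the failure mode (hypothetical function name, not the pipeline code):

```python
import torch

def split_mask(mask=None, batch_size=2):
    if mask is not None:
        mask_batches = list(torch.split(mask, batch_size))
    # Without an `else: mask_batches = None` branch, the next line raises
    # UnboundLocalError whenever `mask` is None.
    return torch.stack(mask_batches)

split_mask()  # UnboundLocalError: local variable 'mask_batches' referenced before assignment
```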
@@ -182,10 +184,11 @@ def _split_inputs_into_batches(
                 prompt_embeds_mask_batches[-1] = torch.vstack(sequence_to_stack)
 
             # Pad negative_prompt_embeds_batches
-            sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple(
-                torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
-            )
-            negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+            if negative_prompt_embeds_batches is not None:
+                sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple(
+                    torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+                )
+                negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
 
             # Pad negative_prompt_embeds_mask if necessary
             if negative_prompt_embeds_mask is not None:
@@ -206,7 +209,9 @@ def _split_inputs_into_batches(
         # Stack batches in the same tensor
         latents_batches = torch.stack(latents_batches)
         prompt_embeds_batches = torch.stack(prompt_embeds_batches)
-        prompt_embeds_mask_batches = torch.stack(prompt_embeds_mask_batches)
+        prompt_embeds_mask_batches = (
+            torch.stack(prompt_embeds_mask_batches) if prompt_embeds_mask_batches is not None else None
+        )
         if negative_prompt_embeds_batches is not None:
             negative_prompt_embeds_batches = torch.stack(negative_prompt_embeds_batches)
         if negative_prompt_embeds_mask_batches is not None:
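Taken together, the three hunks apply one consistent rule: an optional tensor is split, padded, and stacked only if it exists, and stays `None` otherwise. A minimal self-contained sketch of that pattern (hypothetical helper name, not the pipeline's actual API):

```python
import torch

def split_optional_into_batches(tensor, batch_size):
    """Split an optional tensor into equal batches; propagate None.

    A sketch of the guard pattern this PR applies, not the pipeline's
    `_split_inputs_into_batches` itself.
    """
    if tensor is None:
        return None
    batches = list(torch.split(tensor, batch_size))
    # Pad the last batch with zero "dummy" rows so every batch
    # has exactly `batch_size` samples.
    num_dummy_samples = batch_size - batches[-1].shape[0]
    if num_dummy_samples > 0:
        padding = tuple(torch.zeros_like(batches[-1][0][None, :]) for _ in range(num_dummy_samples))
        batches[-1] = torch.vstack((batches[-1],) + padding)
    return torch.stack(batches)

# Usage: present and absent inputs take the same code path.
embeds = torch.randn(5, 4)
print(split_optional_into_batches(embeds, 2).shape)  # torch.Size([3, 2, 4])
print(split_optional_into_batches(None, 2))          # None
```

Factoring the guard this way keeps every optional input (`prompt_embeds_mask`, `negative_prompt_embeds`, `negative_prompt_embeds_mask`) on a single split-pad-stack path instead of repeating the `None` check at each step.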