@@ -238,15 +238,15 @@ def _get_t5_prompt_embeds(
         # Chroma requires the attention mask to include one padding token
         seq_lengths = attention_mask.sum(dim=1)
         mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()

         prompt_embeds = self.text_encoder(
             text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
         )[0]

         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-        attention_mask = attention_mask.to(dtype=dtype, device=device)
+        attention_mask = attention_mask.to(device=device)

         _, seq_len, _ = prompt_embeds.shape

@@ -605,10 +605,9 @@ def _prepare_attention_mask(

         # Extend the prompt attention mask to account for image tokens in the final sequence
         attention_mask = torch.cat(
-            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
             dim=1,
         )
-        attention_mask = attention_mask.to(dtype)

         return attention_mask

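Both hunks hinge on the prompt attention mask being built and kept as a boolean tensor: the `<=` comparison deliberately leaves one padding position unmasked, and creating the image-token extension with `dtype=torch.bool` means `torch.cat` never promotes the mask back to an integer or float dtype, so the trailing `.to(dtype)` cast can be dropped. A minimal sketch of that behaviour, using made-up batch and sequence sizes rather than values from this PR:

```python
import torch

# Hypothetical sizes, for illustration only (not taken from the PR).
batch_size, text_len, image_len = 2, 8, 4

# Pretend tokenizer output: 1 for real tokens, 0 for padding.
attention_mask = torch.tensor(
    [[1, 1, 1, 0, 0, 0, 0, 0],
     [1, 1, 1, 1, 1, 0, 0, 0]]
)

seq_lengths = attention_mask.sum(dim=1)  # tensor([3, 5])
mask_indices = torch.arange(text_len).unsqueeze(0).expand(batch_size, -1)

# "<=" (not "<") keeps one extra position unmasked: the first padding token.
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
print(attention_mask.dtype, attention_mask.sum(dim=1))  # torch.bool tensor([4, 6])

# Extend for image tokens: building the ones as bool keeps the whole mask boolean,
# so no dtype cast is needed afterwards.
attention_mask = torch.cat(
    [attention_mask, torch.ones(batch_size, image_len, dtype=torch.bool)],
    dim=1,
)
print(attention_mask.dtype, attention_mask.shape)  # torch.bool torch.Size([2, 12])
```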