@@ -238,15 +238,15 @@ def _get_t5_prompt_embeds(
         # Chroma requires the attention mask to include one padding token
         seq_lengths = attention_mask.sum(dim=1)
         mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()

         prompt_embeds = self.text_encoder(
             text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
         )[0]

         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-        attention_mask = attention_mask.to(dtype=dtype, device=device)
+        attention_mask = attention_mask.to(device=device)

         _, seq_len, _ = prompt_embeds.shape

@@ -605,10 +605,9 @@ def _prepare_attention_mask(

         # Extend the prompt attention mask to account for image tokens in the final sequence
         attention_mask = torch.cat(
-            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
             dim=1,
         )
-        attention_mask = attention_mask.to(dtype)

         return attention_mask

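A small sketch of the second change in isolation, with illustrative sizes (`batch_size`, `prompt_len`, `sequence_length` are assumptions for the example): the image-token portion of the mask is now created directly as `torch.bool`, so the trailing dtype cast is no longer needed.

```python
import torch

batch_size, prompt_len, sequence_length = 1, 4, 3

# Boolean prompt mask (stand-in for the mask returned by _get_t5_prompt_embeds).
attention_mask = torch.ones(batch_size, prompt_len, dtype=torch.bool)

# Append all-True entries for the image tokens, already in bool dtype.
attention_mask = torch.cat(
    [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
    dim=1,
)
# attention_mask.shape == (1, 7), attention_mask.dtype == torch.bool
```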