@@ -158,3 +158,24 @@ def test_chroma_image_output_shape(self):
158158 image = pipe (** inputs ).images [0 ]
159159 output_height , output_width , _ = image .shape
160160 assert (output_height , output_width ) == (expected_height , expected_width )
161+
162+
163+ class ChromaPipelineAttentionMaskTests (unittest .TestCase ):
164+ def setUp (self ):
165+ self .pipe = ChromaPipeline .from_pretrained (
166+ "lodestones/Chroma1-Base" ,
167+ torch_dtype = torch .float16 ,
168+ )
169+
170+ def test_attention_mask_dtype_is_bool_short_prompt (self ):
171+ prompt_embeds , attn_mask = self .pipe ._get_t5_prompt_embeds ("man" )
172+ self .assertEqual (attn_mask .dtype , torch .bool , f"Expected bool, got { attn_mask .dtype } " )
173+ self .assertGreater (prompt_embeds .shape [0 ], 0 )
174+ self .assertGreater (prompt_embeds .shape [1 ], 0 )
175+
176+ def test_attention_mask_dtype_is_bool_long_prompt (self ):
177+ long_prompt = "a detailed portrait of a man standing in a garden with flowers and trees"
178+ prompt_embeds , attn_mask = self .pipe ._get_t5_prompt_embeds (long_prompt )
179+ self .assertEqual (attn_mask .dtype , torch .bool , f"Expected bool, got { attn_mask .dtype } " )
180+ self .assertGreater (prompt_embeds .shape [0 ], 0 )
181+ self .assertGreater (prompt_embeds .shape [1 ], 0 )
0 commit comments