Skip to content

Commit 1cf90c4

Browse files
committed
test: add ChromaPipeline attention mask dtype tests
1 parent 7bb6e4a commit 1cf90c4

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

tests/pipelines/chroma/test_pipeline_chroma.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,24 @@ def test_chroma_image_output_shape(self):
158158
image = pipe(**inputs).images[0]
159159
output_height, output_width, _ = image.shape
160160
assert (output_height, output_width) == (expected_height, expected_width)
161+
162+
163+
class ChromaPipelineAttentionMaskTests(unittest.TestCase):
164+
def setUp(self):
165+
self.pipe = ChromaPipeline.from_pretrained(
166+
"lodestones/Chroma1-Base",
167+
torch_dtype=torch.float16,
168+
)
169+
170+
def test_attention_mask_dtype_is_bool_short_prompt(self):
171+
prompt_embeds, attn_mask = self.pipe._get_t5_prompt_embeds("man")
172+
self.assertEqual(attn_mask.dtype, torch.bool, f"Expected bool, got {attn_mask.dtype}")
173+
self.assertGreater(prompt_embeds.shape[0], 0)
174+
self.assertGreater(prompt_embeds.shape[1], 0)
175+
176+
def test_attention_mask_dtype_is_bool_long_prompt(self):
177+
long_prompt = "a detailed portrait of a man standing in a garden with flowers and trees"
178+
prompt_embeds, attn_mask = self.pipe._get_t5_prompt_embeds(long_prompt)
179+
self.assertEqual(attn_mask.dtype, torch.bool, f"Expected bool, got {attn_mask.dtype}")
180+
self.assertGreater(prompt_embeds.shape[0], 0)
181+
self.assertGreater(prompt_embeds.shape[1], 0)

0 commit comments

Comments
 (0)