Skip to content

Commit 7c16aa8

Browse files
committed
Remove dedicated dtype tests as per feedback
1 parent 26c33ef commit 7c16aa8

File tree

1 file changed

+0
-24
lines changed

1 file changed

+0
-24
lines changed

tests/pipelines/chroma/test_pipeline_chroma.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from transformers import AutoTokenizer, T5EncoderModel
66

77
from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
8-
from diffusers.utils.testing_utils import slow
98

109
from ...testing_utils import torch_device
1110
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
@@ -159,26 +158,3 @@ def test_chroma_image_output_shape(self):
159158
image = pipe(**inputs).images[0]
160159
output_height, output_width, _ = image.shape
161160
assert (output_height, output_width) == (expected_height, expected_width)
162-
163-
164-
class ChromaPipelineAttentionMaskTests(unittest.TestCase):
165-
def setUp(self):
166-
self.pipe = ChromaPipeline.from_pretrained(
167-
"lodestones/Chroma1-Base",
168-
torch_dtype=torch.float16,
169-
)
170-
171-
@slow
172-
def test_attention_mask_dtype_is_bool_short_prompt(self):
173-
prompt_embeds, attn_mask = self.pipe._get_t5_prompt_embeds("man")
174-
self.assertEqual(attn_mask.dtype, torch.bool, f"Expected bool, got {attn_mask.dtype}")
175-
self.assertGreater(prompt_embeds.shape[0], 0)
176-
self.assertGreater(prompt_embeds.shape[1], 0)
177-
178-
@slow
179-
def test_attention_mask_dtype_is_bool_long_prompt(self):
180-
long_prompt = "a detailed portrait of a man standing in a garden with flowers and trees"
181-
prompt_embeds, attn_mask = self.pipe._get_t5_prompt_embeds(long_prompt)
182-
self.assertEqual(attn_mask.dtype, torch.bool, f"Expected bool, got {attn_mask.dtype}")
183-
self.assertGreater(prompt_embeds.shape[0], 0)
184-
self.assertGreater(prompt_embeds.shape[1], 0)

0 commit comments

Comments
 (0)