Skip to content

Commit cfd5b34

Browse files
authored
fix chroma pipeline fast tests
1 parent c8d6aef commit cfd5b34

File tree

1 file changed

+6
-34
lines changed

1 file changed

+6
-34
lines changed

tests/pipelines/chroma/chroma.py

Lines changed: 6 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -67,31 +67,13 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
6767
attention_head_dim=16,
6868
num_attention_heads=2,
6969
joint_attention_dim=32,
70-
pooled_projection_dim=32,
7170
axes_dims_rope=[4, 4, 8],
7271
)
73-
clip_text_encoder_config = CLIPTextConfig(
74-
bos_token_id=0,
75-
eos_token_id=2,
76-
hidden_size=32,
77-
intermediate_size=37,
78-
layer_norm_eps=1e-05,
79-
num_attention_heads=4,
80-
num_hidden_layers=5,
81-
pad_token_id=1,
82-
vocab_size=1000,
83-
hidden_act="gelu",
84-
projection_dim=32,
85-
)
86-
87-
torch.manual_seed(0)
88-
text_encoder = CLIPTextModel(clip_text_encoder_config)
8972

9073
torch.manual_seed(0)
91-
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
74+
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
9275

93-
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
94-
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
76+
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
9577

9678
torch.manual_seed(0)
9779
vae = AutoencoderKL(
@@ -113,7 +95,6 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
11395
return {
11496
"scheduler": scheduler,
11597
"text_encoder": text_encoder,
116-
"text_encoder_2": text_encoder_2,
11798
"tokenizer": tokenizer,
11899
"tokenizer_2": tokenizer_2,
119100
"transformer": transformer,
@@ -130,6 +111,7 @@ def get_dummy_inputs(self, device, seed=0):
130111

131112
inputs = {
132113
"prompt": "A painting of a squirrel eating a burger",
114+
"negative_prompt": "bad, ugly",
133115
"generator": generator,
134116
"num_inference_steps": 2,
135117
"guidance_scale": 5.0,
@@ -140,14 +122,14 @@ def get_dummy_inputs(self, device, seed=0):
140122
}
141123
return inputs
142124

143-
def test_flux_different_prompts(self):
125+
def test_chroma_different_prompts(self):
144126
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
145127

146128
inputs = self.get_dummy_inputs(torch_device)
147129
output_same_prompt = pipe(**inputs).images[0]
148130

149131
inputs = self.get_dummy_inputs(torch_device)
150-
inputs["prompt_2"] = "a different prompt"
132+
inputs["prompt"] = "a different prompt"
151133
output_different_prompts = pipe(**inputs).images[0]
152134

153135
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
@@ -196,7 +178,7 @@ def test_fused_qkv_projections(self):
196178
"Original outputs should match when fused QKV projections are disabled."
197179
)
198180

199-
def test_flux_image_output_shape(self):
181+
def test_chroma_image_output_shape(self):
200182
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
201183
inputs = self.get_dummy_inputs(torch_device)
202184

@@ -210,13 +192,3 @@ def test_flux_image_output_shape(self):
210192
output_height, output_width, _ = image.shape
211193
assert (output_height, output_width) == (expected_height, expected_width)
212194

213-
def test_flux_true_cfg(self):
214-
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
215-
inputs = self.get_dummy_inputs(torch_device)
216-
inputs.pop("generator")
217-
218-
no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
219-
inputs["negative_prompt"] = "bad quality"
220-
inputs["true_cfg_scale"] = 2.0
221-
true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
222-
assert not np.allclose(no_true_cfg_out, true_cfg_out)

0 commit comments

Comments
 (0)