Skip to content

Commit 135df5b

Browse files
authored
[tests] Add inference test slices for SD3 and remove unnecessary tests (#12106)
* update
* nuke LoC for inference slices
1 parent 4a9dbd5 commit 135df5b

File tree

3 files changed

+46
-217
lines changed

3 files changed

+46
-217
lines changed

tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py

Lines changed: 15 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -124,37 +124,22 @@ def get_dummy_inputs(self, device, seed=0):
124124
}
125125
return inputs
126126

127-
def test_stable_diffusion_3_different_prompts(self):
128-
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
129-
130-
inputs = self.get_dummy_inputs(torch_device)
131-
output_same_prompt = pipe(**inputs).images[0]
132-
133-
inputs = self.get_dummy_inputs(torch_device)
134-
inputs["prompt_2"] = "a different prompt"
135-
inputs["prompt_3"] = "another different prompt"
136-
output_different_prompts = pipe(**inputs).images[0]
137-
138-
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
139-
140-
# Outputs should be different here
141-
assert max_diff > 1e-2
142-
143-
def test_stable_diffusion_3_different_negative_prompts(self):
144-
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
145-
146-
inputs = self.get_dummy_inputs(torch_device)
147-
output_same_prompt = pipe(**inputs).images[0]
127+
def test_inference(self):
128+
components = self.get_dummy_components()
129+
pipe = self.pipeline_class(**components)
148130

149131
inputs = self.get_dummy_inputs(torch_device)
150-
inputs["negative_prompt_2"] = "deformed"
151-
inputs["negative_prompt_3"] = "blurry"
152-
output_different_prompts = pipe(**inputs).images[0]
132+
image = pipe(**inputs).images[0]
133+
generated_slice = image.flatten()
134+
generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
153135

154-
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
136+
# fmt: off
137+
expected_slice = np.array([0.5112, 0.5228, 0.5235, 0.5524, 0.3188, 0.5017, 0.5574, 0.4899, 0.6812, 0.5991, 0.3908, 0.5213, 0.5582, 0.4457, 0.4204, 0.5616])
138+
# fmt: on
155139

156-
# Outputs should be different here
157-
assert max_diff > 1e-2
140+
self.assertTrue(
141+
np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
142+
)
158143

159144
def test_fused_qkv_projections(self):
160145
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -268,40 +253,9 @@ def test_sd3_inference(self):
268253

269254
image = pipe(**inputs).images[0]
270255
image_slice = image[0, :10, :10]
271-
expected_slice = np.array(
272-
[
273-
0.4648,
274-
0.4404,
275-
0.4177,
276-
0.5063,
277-
0.4800,
278-
0.4287,
279-
0.5425,
280-
0.5190,
281-
0.4717,
282-
0.5430,
283-
0.5195,
284-
0.4766,
285-
0.5361,
286-
0.5122,
287-
0.4612,
288-
0.4871,
289-
0.4749,
290-
0.4058,
291-
0.4756,
292-
0.4678,
293-
0.3804,
294-
0.4832,
295-
0.4822,
296-
0.3799,
297-
0.5103,
298-
0.5034,
299-
0.3953,
300-
0.5073,
301-
0.4839,
302-
0.3884,
303-
]
304-
)
256+
# fmt: off
257+
expected_slice = np.array([0.4648, 0.4404, 0.4177, 0.5063, 0.4800, 0.4287, 0.5425, 0.5190, 0.4717, 0.5430, 0.5195, 0.4766, 0.5361, 0.5122, 0.4612, 0.4871, 0.4749, 0.4058, 0.4756, 0.4678, 0.3804, 0.4832, 0.4822, 0.3799, 0.5103, 0.5034, 0.3953, 0.5073, 0.4839, 0.3884])
258+
# fmt: on
305259

306260
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
307261

tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py

Lines changed: 18 additions & 129 deletions
Original file line number | Diff line number | Diff line change
@@ -128,37 +128,22 @@ def get_dummy_inputs(self, device, seed=0):
128128
}
129129
return inputs
130130

131-
def test_stable_diffusion_3_img2img_different_prompts(self):
132-
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
131+
def test_inference(self):
132+
components = self.get_dummy_components()
133+
pipe = self.pipeline_class(**components)
133134

134135
inputs = self.get_dummy_inputs(torch_device)
135-
output_same_prompt = pipe(**inputs).images[0]
136-
137-
inputs = self.get_dummy_inputs(torch_device)
138-
inputs["prompt_2"] = "a different prompt"
139-
inputs["prompt_3"] = "another different prompt"
140-
output_different_prompts = pipe(**inputs).images[0]
141-
142-
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
143-
144-
# Outputs should be different here
145-
assert max_diff > 1e-2
146-
147-
def test_stable_diffusion_3_img2img_different_negative_prompts(self):
148-
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
149-
150-
inputs = self.get_dummy_inputs(torch_device)
151-
output_same_prompt = pipe(**inputs).images[0]
152-
153-
inputs = self.get_dummy_inputs(torch_device)
154-
inputs["negative_prompt_2"] = "deformed"
155-
inputs["negative_prompt_3"] = "blurry"
156-
output_different_prompts = pipe(**inputs).images[0]
136+
image = pipe(**inputs).images[0]
137+
generated_slice = image.flatten()
138+
generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
157139

158-
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
140+
# fmt: off
141+
expected_slice = np.array([0.4564, 0.5486, 0.4868, 0.5923, 0.3775, 0.5543, 0.4807, 0.4177, 0.3778, 0.5957, 0.5726, 0.4333, 0.6312, 0.5062, 0.4838, 0.5984])
142+
# fmt: on
159143

160-
# Outputs should be different here
161-
assert max_diff > 1e-2
144+
self.assertTrue(
145+
np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
146+
)
162147

163148
@unittest.skip("Skip for now.")
164149
def test_multi_vae(self):
@@ -207,112 +192,16 @@ def test_sd3_img2img_inference(self):
207192
inputs = self.get_inputs(torch_device)
208193
image = pipe(**inputs).images[0]
209194
image_slice = image[0, :10, :10]
195+
196+
# fmt: off
210197
expected_slices = Expectations(
211198
{
212-
("xpu", 3): np.array(
213-
[
214-
0.5117,
215-
0.4421,
216-
0.3852,
217-
0.5044,
218-
0.4219,
219-
0.3262,
220-
0.5024,
221-
0.4329,
222-
0.3276,
223-
0.4978,
224-
0.4412,
225-
0.3355,
226-
0.4983,
227-
0.4338,
228-
0.3279,
229-
0.4893,
230-
0.4241,
231-
0.3129,
232-
0.4875,
233-
0.4253,
234-
0.3030,
235-
0.4961,
236-
0.4267,
237-
0.2988,
238-
0.5029,
239-
0.4255,
240-
0.3054,
241-
0.5132,
242-
0.4248,
243-
0.3222,
244-
]
245-
),
246-
("cuda", 7): np.array(
247-
[
248-
0.5435,
249-
0.4673,
250-
0.5732,
251-
0.4438,
252-
0.3557,
253-
0.4912,
254-
0.4331,
255-
0.3491,
256-
0.4915,
257-
0.4287,
258-
0.347,
259-
0.4849,
260-
0.4355,
261-
0.3469,
262-
0.4871,
263-
0.4431,
264-
0.3538,
265-
0.4912,
266-
0.4521,
267-
0.3643,
268-
0.5059,
269-
0.4587,
270-
0.373,
271-
0.5166,
272-
0.4685,
273-
0.3845,
274-
0.5264,
275-
0.4746,
276-
0.3914,
277-
0.5342,
278-
]
279-
),
280-
("cuda", 8): np.array(
281-
[
282-
0.5146,
283-
0.4385,
284-
0.3826,
285-
0.5098,
286-
0.4150,
287-
0.3218,
288-
0.5142,
289-
0.4312,
290-
0.3298,
291-
0.5127,
292-
0.4431,
293-
0.3411,
294-
0.5171,
295-
0.4424,
296-
0.3374,
297-
0.5088,
298-
0.4348,
299-
0.3242,
300-
0.5073,
301-
0.4380,
302-
0.3174,
303-
0.5132,
304-
0.4397,
305-
0.3115,
306-
0.5132,
307-
0.4343,
308-
0.3118,
309-
0.5219,
310-
0.4328,
311-
0.3256,
312-
]
313-
),
199+
("xpu", 3): np.array([0.5117, 0.4421, 0.3852, 0.5044, 0.4219, 0.3262, 0.5024, 0.4329, 0.3276, 0.4978, 0.4412, 0.3355, 0.4983, 0.4338, 0.3279, 0.4893, 0.4241, 0.3129, 0.4875, 0.4253, 0.3030, 0.4961, 0.4267, 0.2988, 0.5029, 0.4255, 0.3054, 0.5132, 0.4248, 0.3222]),
200+
("cuda", 7): np.array([0.5435, 0.4673, 0.5732, 0.4438, 0.3557, 0.4912, 0.4331, 0.3491, 0.4915, 0.4287, 0.347, 0.4849, 0.4355, 0.3469, 0.4871, 0.4431, 0.3538, 0.4912, 0.4521, 0.3643, 0.5059, 0.4587, 0.373, 0.5166, 0.4685, 0.3845, 0.5264, 0.4746, 0.3914, 0.5342]),
201+
("cuda", 8): np.array([0.5146, 0.4385, 0.3826, 0.5098, 0.4150, 0.3218, 0.5142, 0.4312, 0.3298, 0.5127, 0.4431, 0.3411, 0.5171, 0.4424, 0.3374, 0.5088, 0.4348, 0.3242, 0.5073, 0.4380, 0.3174, 0.5132, 0.4397, 0.3115, 0.5132, 0.4343, 0.3118, 0.5219, 0.4328, 0.3256]),
314202
}
315203
)
204+
# fmt: on
316205

317206
expected_slice = expected_slices.get_expectation()
318207

tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py

Lines changed: 13 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -132,37 +132,23 @@ def get_dummy_inputs(self, device, seed=0):
132132
}
133133
return inputs
134134

135-
def test_stable_diffusion_3_inpaint_different_prompts(self):
136-
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
135+
def test_inference(self):
136+
components = self.get_dummy_components()
137+
pipe = self.pipeline_class(**components)
137138

138139
inputs = self.get_dummy_inputs(torch_device)
139-
output_same_prompt = pipe(**inputs).images[0]
140+
image = pipe(**inputs).images[0]
141+
generated_slice = image.flatten()
142+
generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
140143

141-
inputs = self.get_dummy_inputs(torch_device)
142-
inputs["prompt_2"] = "a different prompt"
143-
inputs["prompt_3"] = "another different prompt"
144-
output_different_prompts = pipe(**inputs).images[0]
145-
146-
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
147-
148-
# Outputs should be different here
149-
assert max_diff > 1e-2
150-
151-
def test_stable_diffusion_3_inpaint_different_negative_prompts(self):
152-
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
153-
154-
inputs = self.get_dummy_inputs(torch_device)
155-
output_same_prompt = pipe(**inputs).images[0]
156-
157-
inputs = self.get_dummy_inputs(torch_device)
158-
inputs["negative_prompt_2"] = "deformed"
159-
inputs["negative_prompt_3"] = "blurry"
160-
output_different_prompts = pipe(**inputs).images[0]
144+
# fmt: off
145+
expected_slice = np.array([0.5035, 0.6661, 0.5859, 0.413, 0.4224, 0.4234, 0.7181, 0.5062, 0.5183, 0.6877, 0.5074, 0.585, 0.6111, 0.5422, 0.5306, 0.5891])
146+
# fmt: on
161147

162-
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
163-
164-
# Outputs should be different here
165-
assert max_diff > 1e-2
148+
self.assertTrue(
149+
np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
150+
)
166151

152+
@unittest.skip("Skip for now.")
167153
def test_multi_vae(self):
168154
pass

0 commit comments

Comments
 (0)