Skip to content

Commit 625cc8e

Browse files
committed
update
1 parent a2a9e4e commit 625cc8e

File tree

2 files changed

+359
-0
lines changed

2 files changed

+359
-0
lines changed

tests/modular/__init__.py

Whitespace-only changes.
Lines changed: 359 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,359 @@
1+
import gc
2+
import unittest
3+
from typing import Callable, Union
4+
5+
import numpy as np
6+
import torch
7+
8+
import diffusers
9+
from diffusers import (
10+
DiffusionPipeline,
11+
)
12+
from diffusers.utils import logging
13+
from diffusers.utils.testing_utils import (
14+
backend_empty_cache,
15+
numpy_cosine_similarity_distance,
16+
require_accelerator,
17+
require_torch,
18+
torch_device,
19+
)
20+
21+
22+
def to_np(tensor):
23+
if isinstance(tensor, torch.Tensor):
24+
tensor = tensor.detach().cpu().numpy()
25+
26+
return tensor
27+
28+
29+
@require_torch
30+
class ModularPipelineTesterMixin:
31+
"""
32+
This mixin is designed to be used with unittest.TestCase classes.
33+
It provides a set of common tests for each modular pipeline,
34+
including:
35+
- test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
36+
- test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
37+
- test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input
38+
- test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs
39+
- test_to_device: check if the pipeline's __call__ method can handle different devices
40+
"""
41+
42+
# Canonical parameters that are passed to `__call__` regardless
43+
# of the type of pipeline. They are always optional and have common
44+
# sense default values.
45+
required_optional_params = frozenset(
46+
[
47+
"num_inference_steps",
48+
"num_images_per_prompt",
49+
"latents",
50+
"output_type",
51+
]
52+
)
53+
# this is modular specific: generator needs to be a intermediate input because it's mutable
54+
required_intermediate_params = frozenset(
55+
[
56+
"generator",
57+
]
58+
)
59+
60+
def get_generator(self, seed):
61+
device = torch_device if torch_device != "mps" else "cpu"
62+
generator = torch.Generator(device).manual_seed(seed)
63+
return generator
64+
65+
@property
66+
def pipeline_class(self) -> Union[Callable, DiffusionPipeline]:
67+
raise NotImplementedError(
68+
"You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
69+
"See existing pipeline tests for reference."
70+
)
71+
72+
@property
73+
def repo(self) -> str:
74+
raise NotImplementedError(
75+
"You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
76+
)
77+
78+
@property
79+
def pipeline_blocks_class(self) -> Union[Callable, DiffusionPipeline]:
80+
raise NotImplementedError(
81+
"You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
82+
"See existing pipeline tests for reference."
83+
)
84+
85+
def get_pipeline(self):
86+
raise NotImplementedError(
87+
"You need to implement `get_pipeline(self)` in the child test class. "
88+
"See existing pipeline tests for reference."
89+
)
90+
91+
def get_dummy_inputs(self, device, seed=0):
92+
raise NotImplementedError(
93+
"You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
94+
"See existing pipeline tests for reference."
95+
)
96+
97+
@property
98+
def params(self) -> frozenset:
99+
raise NotImplementedError(
100+
"You need to set the attribute `params` in the child test class. "
101+
"`params` are checked for if all values are present in `__call__`'s signature."
102+
" You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
103+
" e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
104+
"image pipelines, including prompts and prompt embedding overrides."
105+
"If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
106+
"do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
107+
"with non-configurable height and width arguments should set the attribute as "
108+
"`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
109+
"See existing pipeline tests for reference."
110+
)
111+
112+
@property
113+
def batch_params(self) -> frozenset:
114+
raise NotImplementedError(
115+
"You need to set the attribute `batch_params` in the child test class. "
116+
"`batch_params` are the parameters required to be batched when passed to the pipeline's "
117+
"`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
118+
"`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
119+
"set of batch arguments has minor changes from one of the common sets of batch arguments, "
120+
"do not make modifications to the existing common sets of batch arguments. I.e. a text to "
121+
"image pipeline `negative_prompt` is not batched should set the attribute as "
122+
"`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
123+
"See existing pipeline tests for reference."
124+
)
125+
126+
def setUp(self):
127+
# clean up the VRAM before each test
128+
super().setUp()
129+
torch.compiler.reset()
130+
gc.collect()
131+
backend_empty_cache(torch_device)
132+
133+
def tearDown(self):
134+
# clean up the VRAM after each test in case of CUDA runtime errors
135+
super().tearDown()
136+
torch.compiler.reset()
137+
gc.collect()
138+
backend_empty_cache(torch_device)
139+
140+
def test_pipeline_call_signature(self):
141+
pipe = self.get_pipeline()
142+
parameters = pipe.blocks.input_names
143+
optional_parameters = pipe.default_call_parameters
144+
intermediate_parameters = pipe.blocks.intermediate_input_names
145+
146+
remaining_required_parameters = set()
147+
148+
for param in self.params:
149+
if param not in parameters:
150+
remaining_required_parameters.add(param)
151+
152+
self.assertTrue(
153+
len(remaining_required_parameters) == 0,
154+
f"Required parameters not present: {remaining_required_parameters}",
155+
)
156+
157+
remaining_required_intermediate_parameters = set()
158+
159+
for param in self.required_intermediate_params:
160+
if param not in intermediate_parameters:
161+
remaining_required_intermediate_parameters.add(param)
162+
163+
self.assertTrue(
164+
len(remaining_required_intermediate_parameters) == 0,
165+
f"Required intermediate parameters not present: {remaining_required_intermediate_parameters}",
166+
)
167+
168+
remaining_required_optional_parameters = set()
169+
170+
for param in self.required_optional_params:
171+
if param not in optional_parameters:
172+
remaining_required_optional_parameters.add(param)
173+
174+
self.assertTrue(
175+
len(remaining_required_optional_parameters) == 0,
176+
f"Required optional parameters not present: {remaining_required_optional_parameters}",
177+
)
178+
179+
def test_inference_batch_consistent(self, batch_sizes=[2]):
180+
self._test_inference_batch_consistent(batch_sizes=batch_sizes)
181+
182+
def _test_inference_batch_consistent(
183+
self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"], batch_generator=True
184+
):
185+
pipe = self.get_pipeline()
186+
pipe.to(torch_device)
187+
pipe.set_progress_bar_config(disable=None)
188+
189+
inputs = self.get_dummy_inputs(torch_device)
190+
inputs["generator"] = self.get_generator(0)
191+
192+
logger = logging.get_logger(pipe.__module__)
193+
logger.setLevel(level=diffusers.logging.FATAL)
194+
195+
# prepare batched inputs
196+
batched_inputs = []
197+
for batch_size in batch_sizes:
198+
batched_input = {}
199+
batched_input.update(inputs)
200+
201+
for name in self.batch_params:
202+
if name not in inputs:
203+
continue
204+
205+
value = inputs[name]
206+
if name == "prompt":
207+
len_prompt = len(value)
208+
# make unequal batch sizes
209+
batched_input[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
210+
211+
# make last batch super long
212+
batched_input[name][-1] = 100 * "very long"
213+
214+
else:
215+
batched_input[name] = batch_size * [value]
216+
217+
if batch_generator and "generator" in inputs:
218+
batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
219+
220+
if "batch_size" in inputs:
221+
batched_input["batch_size"] = batch_size
222+
223+
batched_inputs.append(batched_input)
224+
225+
logger.setLevel(level=diffusers.logging.WARNING)
226+
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
227+
output = pipe(**batched_input, output="images")
228+
assert len(output) == batch_size
229+
230+
def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=1e-4):
231+
self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff)
232+
233+
def _test_inference_batch_single_identical(
234+
self,
235+
batch_size=2,
236+
expected_max_diff=1e-4,
237+
additional_params_copy_to_batched_inputs=["num_inference_steps"],
238+
):
239+
pipe = self.get_pipeline()
240+
pipe.to(torch_device)
241+
pipe.set_progress_bar_config(disable=None)
242+
inputs = self.get_dummy_inputs(torch_device)
243+
# Reset generator in case it is has been used in self.get_dummy_inputs
244+
inputs["generator"] = self.get_generator(0)
245+
246+
logger = logging.get_logger(pipe.__module__)
247+
logger.setLevel(level=diffusers.logging.FATAL)
248+
249+
# batchify inputs
250+
batched_inputs = {}
251+
batched_inputs.update(inputs)
252+
253+
for name in self.batch_params:
254+
if name not in inputs:
255+
continue
256+
257+
value = inputs[name]
258+
if name == "prompt":
259+
len_prompt = len(value)
260+
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
261+
batched_inputs[name][-1] = 100 * "very long"
262+
263+
else:
264+
batched_inputs[name] = batch_size * [value]
265+
266+
if "generator" in inputs:
267+
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
268+
269+
if "batch_size" in inputs:
270+
batched_inputs["batch_size"] = batch_size
271+
272+
for arg in additional_params_copy_to_batched_inputs:
273+
batched_inputs[arg] = inputs[arg]
274+
275+
output = pipe(**inputs, output="images")
276+
output_batch = pipe(**batched_inputs, output="images")
277+
278+
assert output_batch.shape[0] == batch_size
279+
280+
max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
281+
assert max_diff < expected_max_diff
282+
283+
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
284+
@require_accelerator
285+
def test_float16_inference(self, expected_max_diff=5e-2):
286+
pipe = self.get_pipeline(torch_dtype=torch.float32)
287+
288+
pipe.to(torch_device)
289+
pipe.set_progress_bar_config(disable=None)
290+
291+
pipe_fp16 = self.get_pipeline(torch_dtype=torch.float16)
292+
pipe_fp16.to(torch_device, torch.float16)
293+
pipe_fp16.set_progress_bar_config(disable=None)
294+
295+
inputs = self.get_dummy_inputs(torch_device)
296+
# Reset generator in case it is used inside dummy inputs
297+
if "generator" in inputs:
298+
inputs["generator"] = self.get_generator(0)
299+
output = pipe(**inputs, output="images")
300+
301+
fp16_inputs = self.get_dummy_inputs(torch_device)
302+
# Reset generator in case it is used inside dummy inputs
303+
if "generator" in fp16_inputs:
304+
fp16_inputs["generator"] = self.get_generator(0)
305+
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
306+
307+
if isinstance(output, torch.Tensor):
308+
output = output.cpu()
309+
output_fp16 = output_fp16.cpu()
310+
311+
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
312+
assert max_diff < expected_max_diff
313+
314+
@require_accelerator
315+
def test_to_device(self):
316+
pipe = self.get_pipeline()
317+
pipe.set_progress_bar_config(disable=None)
318+
319+
pipe.to("cpu")
320+
model_devices = [
321+
component.device.type for component in pipe.components.values() if hasattr(component, "device")
322+
]
323+
self.assertTrue(all(device == "cpu" for device in model_devices))
324+
325+
output_cpu = pipe(**self.get_dummy_inputs("cpu"), output="images")
326+
self.assertTrue(np.isnan(output_cpu).sum() == 0)
327+
328+
pipe.to(torch_device)
329+
model_devices = [
330+
component.device.type for component in pipe.components.values() if hasattr(component, "device")
331+
]
332+
self.assertTrue(all(device == torch_device for device in model_devices))
333+
334+
output_device = pipe(**self.get_dummy_inputs(torch_device), output="images")
335+
self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
336+
337+
def test_num_images_per_prompt(self):
338+
pipe = self.get_pipeline()
339+
340+
if "num_images_per_prompt" not in pipe.blocks.input_names:
341+
return
342+
343+
pipe = pipe.to(torch_device)
344+
pipe.set_progress_bar_config(disable=None)
345+
346+
batch_sizes = [1, 2]
347+
num_images_per_prompts = [1, 2]
348+
349+
for batch_size in batch_sizes:
350+
for num_images_per_prompt in num_images_per_prompts:
351+
inputs = self.get_dummy_inputs(torch_device)
352+
353+
for key in inputs.keys():
354+
if key in self.batch_params:
355+
inputs[key] = batch_size * [inputs[key]]
356+
357+
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
358+
359+
assert images.shape[0] == batch_size * num_images_per_prompt

0 commit comments

Comments
 (0)