Commit aa29af8 ("update")
1 parent: bffa3a9

5 files changed: +53 additions, -32 deletions


tests/models/testing_utils/attention.py (1 addition & 2 deletions)

@@ -21,11 +21,10 @@
     AttnProcessor,
 )

-from ...testing_utils import is_attention, require_accelerator, torch_device
+from ...testing_utils import is_attention, torch_device


 @is_attention
-@require_accelerator
 class AttentionTesterMixin:
     """
     Mixin class for testing attention processor and module functionality on models.
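For reference, class-level markers like `is_attention` are ordinarily thin wrappers around a pytest mark; a hypothetical sketch (the real definition lives in the suite's testing_utils and is not shown in this commit):

import pytest

# Hypothetical stand-in for the real is_attention marker.
is_attention = pytest.mark.attention

@is_attention
class AttentionTesterMixin:
    # Applying the mark at class level marks every test method it contributes.
    def test_placeholder(self):
        assert True

Tests carrying the mark can then be selected with `pytest -m attention`.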

tests/models/testing_utils/common.py (6 additions & 5 deletions)

@@ -16,10 +16,10 @@
 import json
 import os
 import tempfile
-from typing import Dict, List, Tuple

 import pytest
 import torch
+import torch.nn as nn
 from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size

 from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant
@@ -30,8 +30,8 @@


 def compute_module_persistent_sizes(
     model: nn.Module,
-    dtype: Optional[Union[str, torch.device]] = None,
-    special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
+    dtype: str | torch.device | None = None,
+    special_dtypes: dict[str, str | torch.device] | None = None,
 ):
     """
     Compute the size of each submodule of a given model (parameters + persistent buffers).
@@ -128,6 +128,7 @@ def get_dummy_inputs(self):
         )

     def test_from_save_pretrained(self, expected_max_diff=5e-5):
+        torch.manual_seed(0)
         model = self.model_class(**self.get_init_dict())
         model.to(torch_device)
         model.eval()
@@ -273,10 +274,10 @@ def set_nan_tensor_to_zero(t):
            return t.to(device)

        def recursive_check(tuple_object, dict_object):
-            if isinstance(tuple_object, (List, Tuple)):
+            if isinstance(tuple_object, (list, tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
-            elif isinstance(tuple_object, Dict):
+            elif isinstance(tuple_object, dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
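The `torch.manual_seed(0)` added to `test_from_save_pretrained` pins the model's random weight initialization, so the save/reload comparison runs against deterministic parameters. A minimal sketch of the effect, with a hypothetical `TinyModel` standing in for `self.model_class(**self.get_init_dict())`:

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # hypothetical stand-in for the model class under test
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

torch.manual_seed(0)   # pin the global RNG before construction
a = TinyModel()
torch.manual_seed(0)   # re-seeding reproduces the identical initialization
b = TinyModel()
assert torch.equal(a.proj.weight, b.proj.weight)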

tests/models/testing_utils/quantization.py (3 additions & 4 deletions)

@@ -25,6 +25,7 @@
     is_gguf_available,
     is_nvidia_modelopt_available,
     is_optimum_quanto_available,
+    is_torchao_available,
 )

 from ...testing_utils import (
@@ -41,6 +42,7 @@
     require_gguf_version_greater_or_equal,
     require_quanto,
     require_torchao_version_greater_or_equal,
+    require_modelopt_version_greater_or_equal,
     torch_device,
 )

@@ -58,7 +60,6 @@
     pass

 if is_torchao_available():
-
     if is_torchao_version(">=", "0.9.0"):
         pass

@@ -644,9 +645,7 @@ def test_torchao_modules_to_not_convert(self):
         if modules_to_exclude is None:
             pytest.skip("modules_to_not_convert_for_test not defined for this model")

-        self._test_quantization_modules_to_not_convert(
-            self.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
-        )
+        self._test_quantization_modules_to_not_convert(self.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude)

     def test_torchao_device_map(self):
         """Test that device_map='auto' works correctly with quantization."""

tests/models/transformers/test_models_transformer_flux_.py (29 additions & 21 deletions)

@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
 import torch

 from diffusers import FluxTransformer2DModel
@@ -46,7 +48,11 @@ class FluxTransformerTesterConfig:
     pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
     pretrained_model_kwargs = {"subfolder": "transformer"}

-    def get_init_dict(self):
+    @property
+    def generator(self):
+        return torch.Generator("cpu").manual_seed(0)
+
+    def get_init_dict(self) -> dict[str, int | list[int]]:
         """Return Flux model initialization arguments."""
         return {
             "patch_size": 1,
@@ -60,30 +66,32 @@ def get_init_dict(self):
             "axes_dims_rope": [4, 4, 8],
         }

-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         batch_size = 1
         height = width = 4
         num_latent_channels = 4
         num_image_channels = 3
-        sequence_length = 24
-        embedding_dim = 8
+        sequence_length = 48
+        embedding_dim = 32

         return {
-            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
-            "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
-            "pooled_projections": randn_tensor((batch_size, embedding_dim)),
-            "img_ids": randn_tensor((height * width, num_image_channels)),
-            "txt_ids": randn_tensor((sequence_length, num_image_channels)),
+            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), generator=self.generator),
+            "encoder_hidden_states": randn_tensor(
+                (batch_size, sequence_length, embedding_dim), generator=self.generator
+            ),
+            "pooled_projections": randn_tensor((batch_size, embedding_dim), generator=self.generator),
+            "img_ids": randn_tensor((height * width, num_image_channels), generator=self.generator),
+            "txt_ids": randn_tensor((sequence_length, num_image_channels), generator=self.generator),
             "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
         }

     @property
-    def input_shape(self):
-        return (16, 4)
+    def input_shape(self) -> tuple[int, int, int]:
+        return (1, 16, 4)

     @property
-    def output_shape(self):
-        return (16, 4)
+    def output_shape(self) -> tuple[int, int, int]:
+        return (1, 16, 4)


 class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
@@ -140,7 +148,7 @@ class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterM
 class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
     """IP Adapter tests for Flux Transformer."""

-    def create_ip_adapter_state_dict(self, model):
+    def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
         from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor

         ip_cross_attn_state_dict = {}
@@ -202,7 +210,7 @@ class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappin

     different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

-    def get_dummy_inputs(self, height=4, width=4):
+    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
         """Override to support dynamic height/width for LoRA hotswap tests."""
         batch_size = 1
         num_latent_channels = 4
@@ -223,7 +231,7 @@ def get_dummy_inputs(self, height=4, width=4):
 class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
     different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

-    def get_dummy_inputs(self, height=4, width=4):
+    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
         """Override to support dynamic height/width for compilation tests."""
         batch_size = 1
         num_latent_channels = 4
@@ -250,7 +258,7 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):


 class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),
@@ -263,7 +271,7 @@ def get_dummy_inputs(self):


 class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),
@@ -276,7 +284,7 @@ def get_dummy_inputs(self):


 class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),
@@ -291,7 +299,7 @@ def get_dummy_inputs(self):
 class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
     gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"

-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),
@@ -304,7 +312,7 @@ def get_dummy_inputs(self):


 class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),

tests/testing_utils.py (14 additions & 0 deletions)

@@ -37,6 +37,7 @@
     is_flax_available,
     is_gguf_available,
     is_kernels_available,
+    is_nvidia_modelopt_available,
     is_note_seq_available,
     is_onnx_available,
     is_opencv_available,
@@ -765,6 +766,19 @@ def decorator(test_case):
     return decorator


+def require_modelopt_version_greater_or_equal(modelopt_version):
+    def decorator(test_case):
+        correct_nvidia_modelopt_version = is_nvidia_modelopt_available() and version.parse(
+            version.parse(importlib.metadata.version("modelopt")).base_version
+        ) >= version.parse(modelopt_version)
+        return pytest.mark.skipif(
+            not correct_nvidia_modelopt_version,
+            reason=f"Test requires modelopt with version greater than or equal to {modelopt_version}.",
+        )(test_case)
+
+    return decorator
+
+
 def deprecate_after_peft_backend(test_case):
     """
     Decorator marking a test that will be skipped after PEFT backend
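A usage sketch of the new decorator; the import path assumes this test layout, and "0.19.0" is an illustrative version, not one pinned by the commit:

from tests.testing_utils import require_modelopt_version_greater_or_equal

@require_modelopt_version_greater_or_equal("0.19.0")
def test_modelopt_quantized_forward():
    # skipped automatically unless nvidia-modelopt >= 0.19.0 is installed
    ...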
