Skip to content

Commit af76988

Browse files
authored
[tests] introduce VAETesterMixin to consolidate tests for slicing and tiling (huggingface#12374)
* up * up * up * up * up * up * up * up * up
1 parent 4715c5c commit af76988

17 files changed

+204
-446
lines changed

tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@
3535
torch_all_close,
3636
torch_device,
3737
)
38-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
38+
from ..test_modeling_common import ModelTesterMixin
39+
from .testing_utils import AutoencoderTesterMixin
3940

4041

4142
enable_full_determinism()
4243

4344

44-
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
45+
class AsymmetricAutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
4546
model_class = AsymmetricAutoencoderKL
4647
main_input_name = "sample"
4748
base_precision = 1e-2

tests/models/autoencoders/test_models_autoencoder_cosmos.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
from diffusers import AutoencoderKLCosmos
1818

1919
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
20-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
20+
from ..test_modeling_common import ModelTesterMixin
21+
from .testing_utils import AutoencoderTesterMixin
2122

2223

2324
enable_full_determinism()
2425

2526

26-
class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
27+
class AutoencoderKLCosmosTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
2728
model_class = AutoencoderKLCosmos
2829
main_input_name = "sample"
2930
base_precision = 1e-2
@@ -80,7 +81,3 @@ def test_gradient_checkpointing_is_applied(self):
8081
@unittest.skip("Not sure why this test fails. Investigate later.")
8182
def test_effective_gradient_checkpointing(self):
8283
pass
83-
84-
@unittest.skip("Unsupported test.")
85-
def test_forward_with_norm_groups(self):
86-
pass

tests/models/autoencoders/test_models_autoencoder_dc.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@
2222
floats_tensor,
2323
torch_device,
2424
)
25-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
25+
from ..test_modeling_common import ModelTesterMixin
26+
from .testing_utils import AutoencoderTesterMixin
2627

2728

2829
enable_full_determinism()
2930

3031

31-
class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
32+
class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
3233
model_class = AutoencoderDC
3334
main_input_name = "sample"
3435
base_precision = 1e-2
@@ -81,7 +82,3 @@ def prepare_init_args_and_inputs_for_common(self):
8182
init_dict = self.get_autoencoder_dc_config()
8283
inputs_dict = self.dummy_input
8384
return init_dict, inputs_dict
84-
85-
@unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
86-
def test_forward_with_norm_groups(self):
87-
pass

tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py

Lines changed: 4 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,15 @@
2020
from diffusers import AutoencoderKLHunyuanVideo
2121
from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
2222

23-
from ...testing_utils import (
24-
enable_full_determinism,
25-
floats_tensor,
26-
torch_device,
27-
)
28-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
23+
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
24+
from ..test_modeling_common import ModelTesterMixin
25+
from .testing_utils import AutoencoderTesterMixin
2926

3027

3128
enable_full_determinism()
3229

3330

34-
class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
31+
class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
3532
model_class = AutoencoderKLHunyuanVideo
3633
main_input_name = "sample"
3734
base_precision = 1e-2
@@ -87,68 +84,6 @@ def prepare_init_args_and_inputs_for_common(self):
8784
inputs_dict = self.dummy_input
8885
return init_dict, inputs_dict
8986

90-
def test_enable_disable_tiling(self):
91-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
92-
93-
torch.manual_seed(0)
94-
model = self.model_class(**init_dict).to(torch_device)
95-
96-
inputs_dict.update({"return_dict": False})
97-
98-
torch.manual_seed(0)
99-
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
100-
101-
torch.manual_seed(0)
102-
model.enable_tiling()
103-
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
104-
105-
self.assertLess(
106-
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
107-
0.5,
108-
"VAE tiling should not affect the inference results",
109-
)
110-
111-
torch.manual_seed(0)
112-
model.disable_tiling()
113-
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
114-
115-
self.assertEqual(
116-
output_without_tiling.detach().cpu().numpy().all(),
117-
output_without_tiling_2.detach().cpu().numpy().all(),
118-
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
119-
)
120-
121-
def test_enable_disable_slicing(self):
122-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
123-
124-
torch.manual_seed(0)
125-
model = self.model_class(**init_dict).to(torch_device)
126-
127-
inputs_dict.update({"return_dict": False})
128-
129-
torch.manual_seed(0)
130-
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
131-
132-
torch.manual_seed(0)
133-
model.enable_slicing()
134-
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
135-
136-
self.assertLess(
137-
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
138-
0.5,
139-
"VAE slicing should not affect the inference results",
140-
)
141-
142-
torch.manual_seed(0)
143-
model.disable_slicing()
144-
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
145-
146-
self.assertEqual(
147-
output_without_slicing.detach().cpu().numpy().all(),
148-
output_without_slicing_2.detach().cpu().numpy().all(),
149-
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
150-
)
151-
15287
def test_gradient_checkpointing_is_applied(self):
15388
expected_set = {
15489
"HunyuanVideoDecoder3D",

tests/models/autoencoders/test_models_autoencoder_kl.py

Lines changed: 3 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@
3535
torch_all_close,
3636
torch_device,
3737
)
38-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
38+
from ..test_modeling_common import ModelTesterMixin
39+
from .testing_utils import AutoencoderTesterMixin
3940

4041

4142
enable_full_determinism()
4243

4344

44-
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
45+
class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
4546
model_class = AutoencoderKL
4647
main_input_name = "sample"
4748
base_precision = 1e-2
@@ -83,68 +84,6 @@ def prepare_init_args_and_inputs_for_common(self):
8384
inputs_dict = self.dummy_input
8485
return init_dict, inputs_dict
8586

86-
def test_enable_disable_tiling(self):
87-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
88-
89-
torch.manual_seed(0)
90-
model = self.model_class(**init_dict).to(torch_device)
91-
92-
inputs_dict.update({"return_dict": False})
93-
94-
torch.manual_seed(0)
95-
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
96-
97-
torch.manual_seed(0)
98-
model.enable_tiling()
99-
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
100-
101-
self.assertLess(
102-
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
103-
0.5,
104-
"VAE tiling should not affect the inference results",
105-
)
106-
107-
torch.manual_seed(0)
108-
model.disable_tiling()
109-
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
110-
111-
self.assertEqual(
112-
output_without_tiling.detach().cpu().numpy().all(),
113-
output_without_tiling_2.detach().cpu().numpy().all(),
114-
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
115-
)
116-
117-
def test_enable_disable_slicing(self):
118-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
119-
120-
torch.manual_seed(0)
121-
model = self.model_class(**init_dict).to(torch_device)
122-
123-
inputs_dict.update({"return_dict": False})
124-
125-
torch.manual_seed(0)
126-
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
127-
128-
torch.manual_seed(0)
129-
model.enable_slicing()
130-
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
131-
132-
self.assertLess(
133-
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
134-
0.5,
135-
"VAE slicing should not affect the inference results",
136-
)
137-
138-
torch.manual_seed(0)
139-
model.disable_slicing()
140-
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
141-
142-
self.assertEqual(
143-
output_without_slicing.detach().cpu().numpy().all(),
144-
output_without_slicing_2.detach().cpu().numpy().all(),
145-
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
146-
)
147-
14887
def test_gradient_checkpointing_is_applied(self):
14988
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
15089
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py

Lines changed: 3 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@
2424
floats_tensor,
2525
torch_device,
2626
)
27-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
27+
from ..test_modeling_common import ModelTesterMixin
28+
from .testing_utils import AutoencoderTesterMixin
2829

2930

3031
enable_full_determinism()
3132

3233

33-
class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
34+
class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
3435
model_class = AutoencoderKLCogVideoX
3536
main_input_name = "sample"
3637
base_precision = 1e-2
@@ -82,68 +83,6 @@ def prepare_init_args_and_inputs_for_common(self):
8283
inputs_dict = self.dummy_input
8384
return init_dict, inputs_dict
8485

85-
def test_enable_disable_tiling(self):
86-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
87-
88-
torch.manual_seed(0)
89-
model = self.model_class(**init_dict).to(torch_device)
90-
91-
inputs_dict.update({"return_dict": False})
92-
93-
torch.manual_seed(0)
94-
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
95-
96-
torch.manual_seed(0)
97-
model.enable_tiling()
98-
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
99-
100-
self.assertLess(
101-
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
102-
0.5,
103-
"VAE tiling should not affect the inference results",
104-
)
105-
106-
torch.manual_seed(0)
107-
model.disable_tiling()
108-
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
109-
110-
self.assertEqual(
111-
output_without_tiling.detach().cpu().numpy().all(),
112-
output_without_tiling_2.detach().cpu().numpy().all(),
113-
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
114-
)
115-
116-
def test_enable_disable_slicing(self):
117-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
118-
119-
torch.manual_seed(0)
120-
model = self.model_class(**init_dict).to(torch_device)
121-
122-
inputs_dict.update({"return_dict": False})
123-
124-
torch.manual_seed(0)
125-
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
126-
127-
torch.manual_seed(0)
128-
model.enable_slicing()
129-
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
130-
131-
self.assertLess(
132-
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
133-
0.5,
134-
"VAE slicing should not affect the inference results",
135-
)
136-
137-
torch.manual_seed(0)
138-
model.disable_slicing()
139-
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
140-
141-
self.assertEqual(
142-
output_without_slicing.detach().cpu().numpy().all(),
143-
output_without_slicing_2.detach().cpu().numpy().all(),
144-
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
145-
)
146-
14786
def test_gradient_checkpointing_is_applied(self):
14887
expected_set = {
14988
"CogVideoXDownBlock3D",

tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@
2222
floats_tensor,
2323
torch_device,
2424
)
25-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
25+
from ..test_modeling_common import ModelTesterMixin
26+
from .testing_utils import AutoencoderTesterMixin
2627

2728

2829
enable_full_determinism()
2930

3031

31-
class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
32+
class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
3233
model_class = AutoencoderKLTemporalDecoder
3334
main_input_name = "sample"
3435
base_precision = 1e-2
@@ -67,7 +68,3 @@ def prepare_init_args_and_inputs_for_common(self):
6768
def test_gradient_checkpointing_is_applied(self):
6869
expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
6970
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
70-
71-
@unittest.skip("Test unsupported.")
72-
def test_forward_with_norm_groups(self):
73-
pass

0 commit comments

Comments (0)