Skip to content

Commit cfe1e2e

Browse files
committed
up
1 parent 1448b03 commit cfe1e2e

8 files changed

+85
-390
lines changed

tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py

Lines changed: 2 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525
floats_tensor,
2626
torch_device,
2727
)
28-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
28+
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin
2929

3030

3131
enable_full_determinism()
3232

3333

34-
class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
34+
class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase):
3535
model_class = AutoencoderKLHunyuanVideo
3636
main_input_name = "sample"
3737
base_precision = 1e-2
@@ -87,68 +87,6 @@ def prepare_init_args_and_inputs_for_common(self):
8787
inputs_dict = self.dummy_input
8888
return init_dict, inputs_dict
8989

90-
def test_enable_disable_tiling(self):
91-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
92-
93-
torch.manual_seed(0)
94-
model = self.model_class(**init_dict).to(torch_device)
95-
96-
inputs_dict.update({"return_dict": False})
97-
98-
torch.manual_seed(0)
99-
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
100-
101-
torch.manual_seed(0)
102-
model.enable_tiling()
103-
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
104-
105-
self.assertLess(
106-
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
107-
0.5,
108-
"VAE tiling should not affect the inference results",
109-
)
110-
111-
torch.manual_seed(0)
112-
model.disable_tiling()
113-
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
114-
115-
self.assertEqual(
116-
output_without_tiling.detach().cpu().numpy().all(),
117-
output_without_tiling_2.detach().cpu().numpy().all(),
118-
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
119-
)
120-
121-
def test_enable_disable_slicing(self):
122-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
123-
124-
torch.manual_seed(0)
125-
model = self.model_class(**init_dict).to(torch_device)
126-
127-
inputs_dict.update({"return_dict": False})
128-
129-
torch.manual_seed(0)
130-
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
131-
132-
torch.manual_seed(0)
133-
model.enable_slicing()
134-
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
135-
136-
self.assertLess(
137-
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
138-
0.5,
139-
"VAE slicing should not affect the inference results",
140-
)
141-
142-
torch.manual_seed(0)
143-
model.disable_slicing()
144-
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
145-
146-
self.assertEqual(
147-
output_without_slicing.detach().cpu().numpy().all(),
148-
output_without_slicing_2.detach().cpu().numpy().all(),
149-
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
150-
)
151-
15290
def test_gradient_checkpointing_is_applied(self):
15391
expected_set = {
15492
"HunyuanVideoDecoder3D",

tests/models/autoencoders/test_models_autoencoder_kl.py

Lines changed: 2 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@
3535
torch_all_close,
3636
torch_device,
3737
)
38-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
38+
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin
3939

4040

4141
enable_full_determinism()
4242

4343

44-
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
44+
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, VAETestMixin, unittest.TestCase):
4545
model_class = AutoencoderKL
4646
main_input_name = "sample"
4747
base_precision = 1e-2
@@ -83,68 +83,6 @@ def prepare_init_args_and_inputs_for_common(self):
8383
inputs_dict = self.dummy_input
8484
return init_dict, inputs_dict
8585

86-
def test_enable_disable_tiling(self):
87-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
88-
89-
torch.manual_seed(0)
90-
model = self.model_class(**init_dict).to(torch_device)
91-
92-
inputs_dict.update({"return_dict": False})
93-
94-
torch.manual_seed(0)
95-
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
96-
97-
torch.manual_seed(0)
98-
model.enable_tiling()
99-
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
100-
101-
self.assertLess(
102-
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
103-
0.5,
104-
"VAE tiling should not affect the inference results",
105-
)
106-
107-
torch.manual_seed(0)
108-
model.disable_tiling()
109-
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
110-
111-
self.assertEqual(
112-
output_without_tiling.detach().cpu().numpy().all(),
113-
output_without_tiling_2.detach().cpu().numpy().all(),
114-
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
115-
)
116-
117-
def test_enable_disable_slicing(self):
118-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
119-
120-
torch.manual_seed(0)
121-
model = self.model_class(**init_dict).to(torch_device)
122-
123-
inputs_dict.update({"return_dict": False})
124-
125-
torch.manual_seed(0)
126-
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
127-
128-
torch.manual_seed(0)
129-
model.enable_slicing()
130-
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
131-
132-
self.assertLess(
133-
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
134-
0.5,
135-
"VAE slicing should not affect the inference results",
136-
)
137-
138-
torch.manual_seed(0)
139-
model.disable_slicing()
140-
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
141-
142-
self.assertEqual(
143-
output_without_slicing.detach().cpu().numpy().all(),
144-
output_without_slicing_2.detach().cpu().numpy().all(),
145-
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
146-
)
147-
14886
def test_gradient_checkpointing_is_applied(self):
14987
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
15088
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py

Lines changed: 2 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
floats_tensor,
2525
torch_device,
2626
)
27-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
27+
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin
2828

2929

3030
enable_full_determinism()
3131

3232

33-
class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
33+
class AutoencoderKLCogVideoXTests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase):
3434
model_class = AutoencoderKLCogVideoX
3535
main_input_name = "sample"
3636
base_precision = 1e-2
@@ -82,68 +82,6 @@ def prepare_init_args_and_inputs_for_common(self):
8282
inputs_dict = self.dummy_input
8383
return init_dict, inputs_dict
8484

85-
def test_enable_disable_tiling(self):
86-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
87-
88-
torch.manual_seed(0)
89-
model = self.model_class(**init_dict).to(torch_device)
90-
91-
inputs_dict.update({"return_dict": False})
92-
93-
torch.manual_seed(0)
94-
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
95-
96-
torch.manual_seed(0)
97-
model.enable_tiling()
98-
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
99-
100-
self.assertLess(
101-
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
102-
0.5,
103-
"VAE tiling should not affect the inference results",
104-
)
105-
106-
torch.manual_seed(0)
107-
model.disable_tiling()
108-
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
109-
110-
self.assertEqual(
111-
output_without_tiling.detach().cpu().numpy().all(),
112-
output_without_tiling_2.detach().cpu().numpy().all(),
113-
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
114-
)
115-
116-
def test_enable_disable_slicing(self):
117-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
118-
119-
torch.manual_seed(0)
120-
model = self.model_class(**init_dict).to(torch_device)
121-
122-
inputs_dict.update({"return_dict": False})
123-
124-
torch.manual_seed(0)
125-
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
126-
127-
torch.manual_seed(0)
128-
model.enable_slicing()
129-
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
130-
131-
self.assertLess(
132-
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
133-
0.5,
134-
"VAE slicing should not affect the inference results",
135-
)
136-
137-
torch.manual_seed(0)
138-
model.disable_slicing()
139-
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
140-
141-
self.assertEqual(
142-
output_without_slicing.detach().cpu().numpy().all(),
143-
output_without_slicing_2.detach().cpu().numpy().all(),
144-
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
145-
)
146-
14785
def test_gradient_checkpointing_is_applied(self):
14886
expected_set = {
14987
"CogVideoXDownBlock3D",

tests/models/autoencoders/test_models_autoencoder_ltx_video.py

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
floats_tensor,
2525
torch_device,
2626
)
27-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
27+
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin
2828

2929

3030
enable_full_determinism()
3131

3232

33-
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
33+
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase):
3434
model_class = AutoencoderKLLTXVideo
3535
main_input_name = "sample"
3636
base_precision = 1e-2
@@ -167,34 +167,3 @@ def test_outputs_equivalence(self):
167167
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
168168
def test_forward_with_norm_groups(self):
169169
pass
170-
171-
def test_enable_disable_tiling(self):
172-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
173-
174-
torch.manual_seed(0)
175-
model = self.model_class(**init_dict).to(torch_device)
176-
177-
inputs_dict.update({"return_dict": False})
178-
179-
torch.manual_seed(0)
180-
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
181-
182-
torch.manual_seed(0)
183-
model.enable_tiling()
184-
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
185-
186-
self.assertLess(
187-
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
188-
0.5,
189-
"VAE tiling should not affect the inference results",
190-
)
191-
192-
torch.manual_seed(0)
193-
model.disable_tiling()
194-
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
195-
196-
self.assertEqual(
197-
output_without_tiling.detach().cpu().numpy().all(),
198-
output_without_tiling_2.detach().cpu().numpy().all(),
199-
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
200-
)

tests/models/autoencoders/test_models_autoencoder_tiny.py

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@
3131
torch_all_close,
3232
torch_device,
3333
)
34-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
34+
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin
3535

3636

3737
enable_full_determinism()
3838

3939

40-
class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
40+
class AutoencoderTinyTests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase):
4141
model_class = AutoencoderTiny
4242
main_input_name = "sample"
4343
base_precision = 1e-2
@@ -81,37 +81,6 @@ def prepare_init_args_and_inputs_for_common(self):
8181
def test_enable_disable_tiling(self):
8282
pass
8383

84-
def test_enable_disable_slicing(self):
85-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
86-
87-
torch.manual_seed(0)
88-
model = self.model_class(**init_dict).to(torch_device)
89-
90-
inputs_dict.update({"return_dict": False})
91-
92-
torch.manual_seed(0)
93-
output_without_slicing = model(**inputs_dict)[0]
94-
95-
torch.manual_seed(0)
96-
model.enable_slicing()
97-
output_with_slicing = model(**inputs_dict)[0]
98-
99-
self.assertLess(
100-
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
101-
0.5,
102-
"VAE slicing should not affect the inference results",
103-
)
104-
105-
torch.manual_seed(0)
106-
model.disable_slicing()
107-
output_without_slicing_2 = model(**inputs_dict)[0]
108-
109-
self.assertEqual(
110-
output_without_slicing.detach().cpu().numpy().all(),
111-
output_without_slicing_2.detach().cpu().numpy().all(),
112-
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
113-
)
114-
11584
@unittest.skip("Test not supported.")
11685
def test_outputs_equivalence(self):
11786
pass

0 commit comments

Comments
 (0)