Commit a9de100

Commit message: more
1 parent 6178385 commit a9de100

File tree: 3 files changed (+269, -8 lines)

src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py

Lines changed: 0 additions & 8 deletions
@@ -228,14 +228,6 @@ def __init__(
 
         self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
 
-        sample_size = (
-            self.config.sample_size[0]
-            if isinstance(self.config.sample_size, (list, tuple))
-            else self.config.sample_size
-        )
-        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
-        self.tile_overlap_factor = 0.25
-
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (Encoder, TemporalDecoder)):
             module.gradient_checkpointing = value
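
For context, the deleted lines derived two tiling attributes from config.sample_size and config.block_out_channels. A minimal sketch of that arithmetic, using hypothetical values that are not taken from this diff:

sample_size = 768                           # hypothetical single-int sample_size
block_out_channels = [128, 256, 512, 512]   # hypothetical 4-level config
# the removed code divided the sample size by the total downsampling factor
tile_latent_min_size = int(sample_size / (2 ** (len(block_out_channels) - 1)))
print(tile_latent_min_size)                 # 768 / 2**3 = 96
tile_overlap_factor = 0.25                  # fixed overlap the removed code also set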
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import AutoencoderKLTemporalDecoder
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    torch_all_close,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin


enable_full_determinism()


class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = AutoencoderKLTemporalDecoder
    main_input_name = "sample"
    base_precision = 1e-2

    @property
    def dummy_input(self):
        batch_size = 3
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        num_frames = 3

        return {"sample": image, "num_frames": num_frames}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": [32, 64],
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "latent_channels": 4,
            "layers_per_block": 2,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skipIf(torch_device == "mps", "Test not supported for MPS.")
    def test_gradient_checkpointing(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict).sample
        # Run the backward pass. For simplicity, we don't compute a real training loss
        # and instead backprop on the mean difference to random targets.
        model.zero_grad()

        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        # re-instantiate the model, now enabling gradient checkpointing
        model_2 = self.model_class(**init_dict)
        # clone model
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict).sample
        # Same backward pass for the checkpointed copy, against the same random targets.
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the outputs and the parameter gradients
        self.assertTrue((loss - loss_2).abs() < 1e-5)
        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())
        for name, param in named_params.items():
            if "post_quant_conv" in name:
                continue

            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

    @unittest.skip("Test unsupported.")
    def test_forward_with_norm_groups(self):
        pass
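
To make the fixture above concrete, here is a minimal sketch (not part of the commit) that instantiates AutoencoderKLTemporalDecoder with the same init_dict and runs one forward pass on the dummy-input shape; the keyword arguments and the .sample output field are assumed from how the test above calls the model:

import torch

from diffusers import AutoencoderKLTemporalDecoder

model = AutoencoderKLTemporalDecoder(
    block_out_channels=[32, 64],
    in_channels=3,
    out_channels=3,
    down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
    latent_channels=4,
    layers_per_block=2,
)
sample = torch.randn(3, 3, 32, 32)        # matches dummy_input: (batch, channels, 32, 32)
out = model(sample, num_frames=3).sample  # .sample mirrors how the test reads the output
print(out.shape)                          # expected to match output_shape per the test: (3, 3, 32, 32)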
Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import AutoencoderOobleck
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    require_torch_accelerator_with_training,
    torch_all_close,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin


enable_full_determinism()


class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = AutoencoderOobleck
    main_input_name = "sample"
    base_precision = 1e-2

    def get_autoencoder_oobleck_config(self, block_out_channels=None):
        init_dict = {
            "encoder_hidden_size": 12,
            "decoder_channels": 12,
            "decoder_input_channels": 6,
            "audio_channels": 2,
            "downsampling_ratios": [2, 4],
            "channel_multiples": [1, 2],
        }
        return init_dict

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 2
        seq_len = 24

        waveform = floats_tensor((batch_size, num_channels, seq_len)).to(torch_device)

        return {"sample": waveform, "sample_posterior": False}

    @property
    def input_shape(self):
        return (2, 24)

    @property
    def output_shape(self):
        return (2, 24)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_oobleck_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_enable_disable_slicing(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        inputs_dict.update({"return_dict": False})

        torch.manual_seed(0)
        output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        torch.manual_seed(0)
        model.enable_slicing()
        output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        self.assertLess(
            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
            0.5,
            "VAE slicing should not affect the inference results",
        )

        torch.manual_seed(0)
        model.disable_slicing()
        output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        self.assertEqual(
            output_without_slicing.detach().cpu().numpy().all(),
            output_without_slicing_2.detach().cpu().numpy().all(),
            "Without slicing, outputs should match the outputs when slicing is manually disabled.",
        )

    @require_torch_accelerator_with_training
    def test_gradient_checkpointing(self):
        # enable deterministic behavior for gradient checkpointing
        # (TODO: sayakpaul): should be grouped in https://github.com/huggingface/diffusers/pull/9494
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict).sample
        # Run the backward pass. For simplicity, we don't compute a real training loss
        # and instead backprop on the mean difference to random targets.
        model.zero_grad()

        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        # re-instantiate the model, now enabling gradient checkpointing
        model_2 = self.model_class(**init_dict)
        # clone model
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict).sample
        # Same backward pass for the checkpointed copy, against the same random targets.
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the outputs and the parameter gradients
        self.assertTrue((loss - loss_2).abs() < 1e-5)
        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())
        for name, param in named_params.items():
            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

    @unittest.skip("Test unsupported.")
    def test_forward_with_norm_groups(self):
        pass

    @unittest.skip("No attention module used in this model")
    def test_set_attn_processor_for_determinism(self):
        return
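
Similarly, a minimal sketch (not part of the commit) that builds AutoencoderOobleck from the config returned by get_autoencoder_oobleck_config and runs it on the dummy waveform; the keyword arguments and the .sample output field are assumed from how the tests above call the model:

import torch

from diffusers import AutoencoderOobleck

model = AutoencoderOobleck(
    encoder_hidden_size=12,
    decoder_channels=12,
    decoder_input_channels=6,
    audio_channels=2,
    downsampling_ratios=[2, 4],
    channel_multiples=[1, 2],
)
waveform = torch.randn(4, 2, 24)  # matches dummy_input: (batch, audio_channels, seq_len)
out = model(waveform, sample_posterior=False).sample
print(out.shape)                  # expected to mirror output_shape per the test: (4, 2, 24)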
