
Commit b36b6b5

Update UNet2DModel gradient checkpointing tests to follow the current UNet2DConditionModel gradient checkpointing tests
1 parent 658a8de commit b36b6b5

1 file changed: +8 -64 lines changed

tests/models/unets/test_models_unet_2d.py

Lines changed: 8 additions & 64 deletions
@@ -26,7 +26,6 @@
     enable_full_determinism,
     floats_tensor,
     require_torch_accelerator,
-    require_torch_accelerator_with_training,
     slow,
     torch_all_close,
     torch_device,
@@ -107,77 +106,22 @@ def test_mid_block_attn_groups(self):
         expected_shape = inputs_dict["sample"].shape
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
 
-    @require_torch_accelerator_with_training
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
-
     def test_gradient_checkpointing_is_applied(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        # NOTE: UNet2DModel only supports int arguments for `attention_head_dim` currently
-        init_dict["attention_head_dim"] = 8
-
-        model_class_copy = copy.copy(self.model_class)
-
-        modules_with_gc_enabled = {}
-
-        def _set_gradient_checkpointing_new(self, module, value=False):
-            if hasattr(module, "gradient_checkpointing"):
-                module.gradient_checkpointing = value
-                modules_with_gc_enabled[module.__class__.__name__] = True
-
-        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
-        model = model_class_copy(**init_dict)
-        model.enable_gradient_checkpointing()
-
-        EXPECTED_SET = {
+        expected_set = {
             "AttnUpBlock2D",
             "AttnDownBlock2D",
             "UNetMidBlock2D",
             "UpBlock2D",
             "DownBlock2D",
         }
 
-        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
-        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
+        attention_head_dim = 8
+        block_out_channels = (16, 32)
+
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+        )
 
 
 class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
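
The per-model gradient checkpointing logic removed above is now expected to live in the shared tester mixin, which is why the UNet2DModel test shrinks to supplying an `expected_set` plus a couple of config overrides. Below is a minimal sketch of what such a shared `test_gradient_checkpointing_is_applied` helper might do, assuming it follows the same monkeypatching pattern as the removed per-model test; the standalone function name and signature here are illustrative, not the actual diffusers test API.

```python
# Minimal sketch, assuming the shared mixin helper mirrors the pattern of the
# per-model test removed in this commit. `check_gradient_checkpointing_is_applied`
# is an illustrative name, not the real diffusers helper.
import copy


def check_gradient_checkpointing_is_applied(model_class, init_dict, expected_set, **config_overrides):
    # Per-model overrides (e.g. attention_head_dim=8, block_out_channels=(16, 32))
    # are layered on top of the common init kwargs.
    init_dict = {**init_dict, **config_overrides}

    # Patch a copy of the class so the instrumentation does not leak into other tests.
    model_class_copy = copy.copy(model_class)
    modules_with_gc_enabled = {}

    def _set_gradient_checkpointing_new(self, module, value=False):
        # Record every submodule whose gradient checkpointing flag gets toggled.
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value
            modules_with_gc_enabled[module.__class__.__name__] = True

    model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new

    model = model_class_copy(**init_dict)
    model.enable_gradient_checkpointing()

    # Exactly the expected block types, and only those, should end up with
    # gradient checkpointing enabled.
    assert set(modules_with_gc_enabled.keys()) == expected_set
    assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
```

With a helper of this shape on the mixin, each model test only has to declare which block classes checkpointing should reach and any architecture-specific config quirks, which is exactly what the new eight-line UNet2DModel override does.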
