@@ -26,7 +26,6 @@
     enable_full_determinism,
     floats_tensor,
     require_torch_accelerator,
-    require_torch_accelerator_with_training,
     slow,
     torch_all_close,
    torch_device,
@@ -107,77 +106,22 @@ def test_mid_block_attn_groups(self):
         expected_shape = inputs_dict["sample"].shape
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
 
-    @require_torch_accelerator_with_training
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
-
     def test_gradient_checkpointing_is_applied(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        # NOTE: UNet2DModel only supports int arguments for `attention_head_dim` currently
-        init_dict["attention_head_dim"] = 8
-
-        model_class_copy = copy.copy(self.model_class)
-
-        modules_with_gc_enabled = {}
-
-        def _set_gradient_checkpointing_new(self, module, value=False):
-            if hasattr(module, "gradient_checkpointing"):
-                module.gradient_checkpointing = value
-                modules_with_gc_enabled[module.__class__.__name__] = True
-
-        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
-        model = model_class_copy(**init_dict)
-        model.enable_gradient_checkpointing()
-
-        EXPECTED_SET = {
+        expected_set = {
             "AttnUpBlock2D",
             "AttnDownBlock2D",
             "UNetMidBlock2D",
             "UpBlock2D",
             "DownBlock2D",
         }
 
-        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
-        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
+        attention_head_dim = 8
+        block_out_channels = (16, 32)
+
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+        )
 
 
 class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|