| 26 | 26 |     enable_full_determinism, | 
| 27 | 27 |     floats_tensor, | 
| 28 | 28 |     require_torch_accelerator, | 
| 29 |  | -    require_torch_accelerator_with_training, | 
| 30 | 29 |     slow, | 
| 31 | 30 |     torch_all_close, | 
| 32 | 31 |     torch_device, | 
| @@ -107,77 +106,22 @@ def test_mid_block_attn_groups(self): | 
| 107 | 106 |         expected_shape = inputs_dict["sample"].shape | 
| 108 | 107 |         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") | 
| 109 | 108 | 
| 110 |  | -    @require_torch_accelerator_with_training | 
| 111 |  | -    def test_gradient_checkpointing(self): | 
| 112 |  | -        # enable deterministic behavior for gradient checkpointing | 
| 113 |  | -        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | 
| 114 |  | -        model = self.model_class(**init_dict) | 
| 115 |  | -        model.to(torch_device) | 
| 116 |  | - | 
| 117 |  | -        assert not model.is_gradient_checkpointing and model.training | 
| 118 |  | - | 
| 119 |  | -        out = model(**inputs_dict).sample | 
| 120 |  | -        # run the backward pass on the model. For simplicity, we don't compute a real loss | 
| 121 |  | -        # here; we just backprop on the mean difference from random targets. | 
| 122 |  | -        model.zero_grad() | 
| 123 |  | - | 
| 124 |  | -        labels = torch.randn_like(out) | 
| 125 |  | -        loss = (out - labels).mean() | 
| 126 |  | -        loss.backward() | 
| 127 |  | - | 
| 128 |  | -        # re-instantiate the model now enabling gradient checkpointing | 
| 129 |  | -        model_2 = self.model_class(**init_dict) | 
| 130 |  | -        # clone model | 
| 131 |  | -        model_2.load_state_dict(model.state_dict()) | 
| 132 |  | -        model_2.to(torch_device) | 
| 133 |  | -        model_2.enable_gradient_checkpointing() | 
| 134 |  | - | 
| 135 |  | -        assert model_2.is_gradient_checkpointing and model_2.training | 
| 136 |  | - | 
| 137 |  | -        out_2 = model_2(**inputs_dict).sample | 
| 138 |  | -        # run the backward pass on the model. For simplicity, we don't compute a real loss | 
| 139 |  | -        # here; we just backprop on the mean difference from the same random targets. | 
| 140 |  | -        model_2.zero_grad() | 
| 141 |  | -        loss_2 = (out_2 - labels).mean() | 
| 142 |  | -        loss_2.backward() | 
| 143 |  | - | 
| 144 |  | -        # compare the output and parameters gradients | 
| 145 |  | -        self.assertTrue((loss - loss_2).abs() < 1e-5) | 
| 146 |  | -        named_params = dict(model.named_parameters()) | 
| 147 |  | -        named_params_2 = dict(model_2.named_parameters()) | 
| 148 |  | -        for name, param in named_params.items(): | 
| 149 |  | -            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) | 
| 150 |  | - | 
| 151 | 109 |     def test_gradient_checkpointing_is_applied(self): | 
| 152 |  | -        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | 
| 153 |  | - | 
| 154 |  | -        # NOTE: UNet2DModel only supports int arguments for `attention_head_dim` currently | 
| 155 |  | -        init_dict["attention_head_dim"] = 8 | 
| 156 |  | - | 
| 157 |  | -        model_class_copy = copy.copy(self.model_class) | 
| 158 |  | - | 
| 159 |  | -        modules_with_gc_enabled = {} | 
| 160 |  | - | 
| 161 |  | -        def _set_gradient_checkpointing_new(self, module, value=False): | 
| 162 |  | -            if hasattr(module, "gradient_checkpointing"): | 
| 163 |  | -                module.gradient_checkpointing = value | 
| 164 |  | -                modules_with_gc_enabled[module.__class__.__name__] = True | 
| 165 |  | - | 
| 166 |  | -        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new | 
| 167 |  | - | 
| 168 |  | -        model = model_class_copy(**init_dict) | 
| 169 |  | -        model.enable_gradient_checkpointing() | 
| 170 |  | - | 
| 171 |  | -        EXPECTED_SET = { | 
|  | 110 | +        expected_set = { | 
| 172 | 111 |             "AttnUpBlock2D", | 
| 173 | 112 |             "AttnDownBlock2D", | 
| 174 | 113 |             "UNetMidBlock2D", | 
| 175 | 114 |             "UpBlock2D", | 
| 176 | 115 |             "DownBlock2D", | 
| 177 | 116 |         } | 
| 178 | 117 | 
| 179 |  | -        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET | 
| 180 |  | -        assert all(modules_with_gc_enabled.values()), "All modules should be enabled" | 
|  | 118 | +        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim` | 
|  | 119 | +        attention_head_dim = 8 | 
|  | 120 | +        block_out_channels = (16, 32) | 
|  | 121 | + | 
|  | 122 | +        super().test_gradient_checkpointing_is_applied( | 
|  | 123 | +            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels | 
|  | 124 | +        ) | 
| 181 | 125 | 
| 182 | 126 | 
| 183 | 127 | class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): | 