|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +""" |
| 17 | +GradSampleControllerFastGradientClipping: Controller-based Fast Gradient and Ghost Clipping. |
| 18 | +
|
| 19 | +This module provides a GradSampleModule-less approach with ghost clipping support, |
| 20 | +combining the benefits of: |
| 21 | +- Controller-based hook management (no model wrapping) |
| 22 | +- Ghost clipping (memory-efficient gradient norm computation) |
| 23 | +""" |
| 24 | + |
| 25 | +import logging |
| 26 | +from typing import List |
| 27 | + |
| 28 | +import torch |
| 29 | +import torch.nn as nn |
| 30 | +from opacus.grad_sample.functorch import ft_compute_per_sample_gradient |
| 31 | +from opacus.grad_sample.grad_sample_controller import GradSampleController |
| 32 | +from opacus.grad_sample.grad_sample_module import ( |
| 33 | + create_or_accumulate_grad_sample, |
| 34 | + promote_current_grad_sample, |
| 35 | +) |
| 36 | +from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN |
| 37 | +from opacus.utils.module_utils import trainable_modules, trainable_parameters |
| 38 | + |
| 39 | + |
| 40 | +logger = logging.getLogger(__name__) |
| 41 | +logger.disabled = True |
| 42 | + |
| 43 | + |
| 44 | +def create_norm_sample( |
| 45 | + *, param: torch.Tensor, grad_sample: torch.Tensor, max_batch_len: int |
| 46 | +) -> None: |
| 47 | + """ |
| 48 | + Creates a ``_norm_sample`` attribute in the given parameter |
| 49 | +
|
| 50 | +
|
| 51 | + Args: |
| 52 | + param: Parameter to which ``_norm_sample`` will be added |
| 53 | + grad_sample: Per-sample gradients tensor. Must be of the same |
| 54 | + shape as ``param`` with extra batch dimension |
| 55 | + """ |
| 56 | + |
| 57 | + if param.requires_grad: |
| 58 | + if ( |
| 59 | + max_batch_len == 0 |
| 60 | + ): # To handle the case of empty batch that may arise from Poisson sampling |
| 61 | + param._norm_sample = torch.tensor( |
| 62 | + [], device=grad_sample.device, dtype=grad_sample.dtype |
| 63 | + ) |
| 64 | + else: |
| 65 | + param._norm_sample = torch.zeros( |
| 66 | + torch.Size([max_batch_len, 1]), |
| 67 | + device=grad_sample.device, |
| 68 | + dtype=grad_sample.dtype, |
| 69 | + ) |
| 70 | + param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm( |
| 71 | + 2, dim=-1 |
| 72 | + ) |
| 73 | + |
| 74 | + |
| 75 | +class GradSampleControllerFastGradientClipping(GradSampleController): |
| 76 | + """ |
| 77 | + Controller for managing privacy hooks with Fast Gradient and Ghost Clipping support |
| 78 | +
|
| 79 | + Extends GradSampleController to add ghost clipping support for memory-efficient |
| 80 | + gradient norm computation. Supports both: |
| 81 | + - Ghost Clipping: Direct norm computation without materializing full gradients |
| 82 | + - Fast Gradient Clipping: Full gradient computation followed by norm computation |
| 83 | +
|
| 84 | + This class attaches hooks directly to model modules and manages their lifecycle, |
| 85 | + providing an alternative to GradSampleModule wrapping that's more compatible |
| 86 | + with transformers and other complex models. |
| 87 | + """ |
| 88 | + |
| 89 | + NORM_SAMPLERS = {} |
| 90 | + |
| 91 | + def __init__( |
| 92 | + self, |
| 93 | + m: nn.Module, |
| 94 | + *, |
| 95 | + batch_first=True, |
| 96 | + loss_reduction="mean", |
| 97 | + strict: bool = True, |
| 98 | + force_functorch=False, |
| 99 | + max_grad_norm=1, |
| 100 | + use_ghost_clipping=True, |
| 101 | + ): |
| 102 | + """ |
| 103 | +
|
| 104 | + Args: |
| 105 | + m: nn.Module to attach hooks to |
| 106 | + batch_first: Flag to indicate if the input tensor to the corresponding module |
| 107 | + has the first dimension representing the batch. If set to True, dimensions on |
| 108 | + input tensor are expected be ``[batch_size, ...]``, otherwise |
| 109 | + ``[K, batch_size, ...]`` |
| 110 | + loss_reduction: Indicates if the loss reduction (for aggregating the gradients) |
| 111 | + is a sum or a mean operation. Can take values "sum" or "mean" |
| 112 | + max_grad_norm: The value at which gradients are to be clipped. |
| 113 | + strict: If set to ``True``, the input module will be validated to make sure that none of its submodules includes buffers, |
| 114 | + which is not currently supported by Opacus. |
| 115 | + If set to ``False``, per sample gradients will |
| 116 | + be computed on "best effort" basis - they will be available where |
| 117 | + possible and set to None otherwise. This is not recommended, because |
| 118 | + some unsupported modules (e.g. BatchNorm) affect other parameters and |
| 119 | + invalidate the concept of per sample gradients for the entire model. |
| 120 | + force_functorch: If set to ``True``, will use functorch to compute |
| 121 | + all per sample gradients. Otherwise, functorch will be used only |
| 122 | + for layers without registered grad sampler methods. |
| 123 | + use_ghost_clipping: If set to ``True``, Ghost Clipping |
| 124 | + will be used for clipping gradients of supported layers. If ``False``, Fast |
| 125 | + Gradient Clipping will be used for all layers. |
| 126 | +
|
| 127 | + Raises: |
| 128 | + NotImplementedError |
| 129 | + If ``strict`` is set to ``True`` and module ``m`` (or any of its |
| 130 | + submodules) includes a buffer. |
| 131 | + """ |
| 132 | + # Call parent constructor |
| 133 | + super().__init__( |
| 134 | + m, |
| 135 | + batch_first=batch_first, |
| 136 | + loss_reduction=loss_reduction, |
| 137 | + strict=strict, |
| 138 | + force_functorch=force_functorch, |
| 139 | + ) |
| 140 | + |
| 141 | + # Add ghost clipping specific attributes |
| 142 | + self.max_grad_norm = max_grad_norm |
| 143 | + self.use_ghost_clipping = use_ghost_clipping |
| 144 | + self._per_sample_gradient_norms = None |
| 145 | + |
| 146 | + # Initialize _norm_sample attribute for parameters |
| 147 | + for _, p in trainable_parameters(self.module): |
| 148 | + p._norm_sample = None |
| 149 | + |
| 150 | + self.trainable_parameters = [p for _, p in trainable_parameters(self.module)] |
| 151 | + |
| 152 | + if logger.isEnabledFor(logging.INFO): |
| 153 | + self.log_module_gradient_sample_mode( |
| 154 | + module=m, |
| 155 | + force_functorch=force_functorch, |
| 156 | + use_ghost_clipping=use_ghost_clipping, |
| 157 | + ) |
| 158 | + |
| 159 | + def get_clipping_coef(self) -> torch.Tensor: |
| 160 | + """Get per-example gradient scaling factor for clipping.""" |
| 161 | + norm_sample = self.get_norm_sample() |
| 162 | + return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0) |
| 163 | + |
| 164 | + def get_norm_sample(self) -> torch.Tensor: |
| 165 | + """Get per-example gradient norms.""" |
| 166 | + norm_sample = torch.stack( |
| 167 | + [param._norm_sample for param in self.trainable_parameters], dim=0 |
| 168 | + ).norm(2, dim=0) |
| 169 | + self.per_sample_gradient_norms = norm_sample |
| 170 | + return norm_sample |
| 171 | + |
| 172 | + def capture_activations_hook( |
| 173 | + self, |
| 174 | + module: nn.Module, |
| 175 | + forward_input: List[torch.Tensor], |
| 176 | + _forward_output: torch.Tensor, |
| 177 | + ): |
| 178 | + """ |
| 179 | + Override parent method to add parameter tying check for ghost clipping. |
| 180 | + """ |
| 181 | + # Call parent implementation |
| 182 | + super().capture_activations_hook(module, forward_input, _forward_output) |
| 183 | + |
| 184 | + # Add ghost clipping specific check for parameter tying |
| 185 | + if self.hooks_enabled: |
| 186 | + for _, p in trainable_parameters(module): |
| 187 | + if ( |
| 188 | + self.use_ghost_clipping |
| 189 | + and p._forward_counter > 1 |
| 190 | + and type(module) in self.NORM_SAMPLERS |
| 191 | + ): |
| 192 | + raise NotImplementedError( |
| 193 | + "Parameter tying is not supported with Ghost Clipping" |
| 194 | + ) |
| 195 | + |
| 196 | + def capture_backprops_hook( |
| 197 | + self, |
| 198 | + module: nn.Module, |
| 199 | + _forward_input: torch.Tensor, |
| 200 | + forward_output: torch.Tensor, |
| 201 | + loss_reduction: str, |
| 202 | + batch_first: bool, |
| 203 | + ): |
| 204 | + """ |
| 205 | + Computes per sample gradient norms given the current backprops and activations |
| 206 | + stored by the associated forward hook. Computed per sample gradient norms are |
| 207 | + stored in ``_norm_sample`` field in each parameter. |
| 208 | +
|
| 209 | + For non-recurrent layers the process is straightforward: for each |
| 210 | + ``loss.backward()`` call this hook will be called exactly one. For recurrent |
| 211 | + layers, however, this is more complicated and the hook will be called multiple |
| 212 | + times, while still processing the same batch of data. |
| 213 | +
|
| 214 | + For this reason we first accumulate the gradients from *the same batch* in |
| 215 | + ``p._current_grad_sample`` and then, when we detect the end of a full backward |
| 216 | + pass - we store accumulated result on ``p.grad_sample`` (for fast gradient clipping) |
| 217 | + or ``p._norm_sample`` (for ghost clipping). |
| 218 | +
|
| 219 | + Args: |
| 220 | + module: nn.Module, |
| 221 | + _forward_input: torch.Tensor, |
| 222 | + forward_output: torch.Tensor, |
| 223 | + loss_reduction: str, |
| 224 | + batch_first: bool, |
| 225 | + """ |
| 226 | + if not self.hooks_enabled: |
| 227 | + return |
| 228 | + |
| 229 | + backprops = forward_output[0].detach() |
| 230 | + activations, backprops = self.rearrange_grad_samples( |
| 231 | + module=module, |
| 232 | + backprops=backprops, |
| 233 | + loss_reduction=loss_reduction, |
| 234 | + batch_first=batch_first, |
| 235 | + ) |
| 236 | + |
| 237 | + # Handle DTensor if needed |
| 238 | + activations = [ |
| 239 | + temp.to_local() if type(temp) is torch.distributed.tensor.DTensor else temp |
| 240 | + for temp in activations |
| 241 | + ] |
| 242 | + |
| 243 | + if self.use_ghost_clipping and type(module) in self.NORM_SAMPLERS: |
| 244 | + # Ghost clipping: compute norms directly |
| 245 | + norm_sampler_fn = self.NORM_SAMPLERS[type(module)] |
| 246 | + norm_samples = norm_sampler_fn(module, activations, backprops) |
| 247 | + |
| 248 | + for param, ns in norm_samples.items(): |
| 249 | + if param.requires_grad: |
| 250 | + param._norm_sample = ns |
| 251 | + param._forward_counter -= 1 |
| 252 | + |
| 253 | + else: |
| 254 | + # Fast gradient clipping: compute full gradients then norms |
| 255 | + if not self.force_functorch and type(module) in self.GRAD_SAMPLERS: |
| 256 | + grad_sampler_fn = self.GRAD_SAMPLERS[type(module)] |
| 257 | + else: |
| 258 | + grad_sampler_fn = ft_compute_per_sample_gradient |
| 259 | + |
| 260 | + grad_samples = grad_sampler_fn(module, activations, backprops) |
| 261 | + for param, gs in grad_samples.items(): |
| 262 | + create_or_accumulate_grad_sample( |
| 263 | + param=param, grad_sample=gs, max_batch_len=module.max_batch_len |
| 264 | + ) |
| 265 | + del grad_samples |
| 266 | + |
| 267 | + # Detect end of current batch processing and switch accumulation |
| 268 | + # mode from sum to stacking. Used for RNNs and tied parameters |
| 269 | + # (See #417 for details) |
| 270 | + for _, p in trainable_parameters(module): |
| 271 | + p._forward_counter -= 1 |
| 272 | + if p._forward_counter == 0: |
| 273 | + promote_current_grad_sample(p) |
| 274 | + create_norm_sample( |
| 275 | + param=p, |
| 276 | + grad_sample=p.grad_sample, |
| 277 | + max_batch_len=module.max_batch_len, |
| 278 | + ) |
| 279 | + p.grad_sample = None |
| 280 | + |
| 281 | + if len(module.activations) == 0: |
| 282 | + if hasattr(module, "max_batch_len"): |
| 283 | + del module.max_batch_len |
| 284 | + |
| 285 | + def log_module_gradient_sample_mode( |
| 286 | + self, module: nn.Module, *, force_functorch=False, use_ghost_clipping=True |
| 287 | + ): |
| 288 | + """ |
| 289 | + Add logs to track gradient sample mode for each part of the module, including 1) Ghost Clipping, 2) Fast Gradient Clipping (hook mode), and 3) Fast Gradient Clipping (functorch mode). |
| 290 | +
|
| 291 | + Args: |
| 292 | + module: nn.Module to be checked |
| 293 | + force_functorch: If set to ``True``, will use functorch to compute |
| 294 | + all per sample gradients. Otherwise, functorch will be used only |
| 295 | + for layers without registered grad sampler methods. |
| 296 | + use_ghost_clipping: If set to ``True``, Ghost Clipping |
| 297 | + will be used for clipping gradients of supported layers. If ``False``, Fast |
| 298 | + Gradient Clipping will be used for all layers. |
| 299 | + """ |
| 300 | + for m_name, m in trainable_modules(module): |
| 301 | + if type(m) in [DPRNN, DPLSTM, DPGRU]: |
| 302 | + logger.info( |
| 303 | + f"Module name: {m_name}, module type: {type(m)}. No hook or functorch is added." |
| 304 | + ) |
| 305 | + |
| 306 | + elif use_ghost_clipping and type(m) in self.NORM_SAMPLERS: |
| 307 | + logger.info( |
| 308 | + f"Module name: {m_name}, module type: {type(m)}, under Ghost Clipping." |
| 309 | + ) |
| 310 | + |
| 311 | + else: |
| 312 | + if not force_functorch and type(m) in self.GRAD_SAMPLERS: |
| 313 | + # When functorch is not enforced, use FGC (hook mode) if the layer has a registered grad_sampler (supported). Otherwise, use FGC (functorch mode). |
| 314 | + logger.info( |
| 315 | + f"Module name: {m_name}, module type: {type(m)}, under Fast Gradient Clipping (hook mode)." |
| 316 | + ) |
| 317 | + else: |
| 318 | + logger.info( |
| 319 | + f"Module name: {m_name}, module type: {type(m)}, under Fast Gradient Clipping (functorch mode)." |
| 320 | + ) |
| 321 | + |
| 322 | + @property |
| 323 | + def per_sample_gradient_norms(self) -> torch.Tensor: |
| 324 | + """Returns per sample gradient norms. Note that these are not privatized and should only be used for debugging purposes or in non-private settings""" |
| 325 | + if self._per_sample_gradient_norms is not None: |
| 326 | + return self._per_sample_gradient_norms |
| 327 | + else: |
| 328 | + raise AttributeError( |
| 329 | + "per_sample_gradient_norms is not set. Please call forward and backward on the model before accessing this property." |
| 330 | + ) |
| 331 | + |
| 332 | + @per_sample_gradient_norms.setter |
| 333 | + def per_sample_gradient_norms(self, value): |
| 334 | + self._per_sample_gradient_norms = value |
0 commit comments