diff --git a/opacus/__init__.py b/opacus/__init__.py index b2009227..94adc2b9 100644 --- a/opacus/__init__.py +++ b/opacus/__init__.py @@ -14,13 +14,18 @@ # limitations under the License. from . import utils -from .grad_sample import GradSampleModule, GradSampleModuleFastGradientClipping +from .grad_sample import ( + GradSampleController, + GradSampleModule, + GradSampleModuleFastGradientClipping, +) from .privacy_engine import PrivacyEngine from .version import __version__ __all__ = [ "PrivacyEngine", + "GradSampleController", "GradSampleModule", "GradSampleModuleFastGradientClipping", "utils", diff --git a/opacus/grad_sample/README.md b/opacus/grad_sample/README.md index 7827680f..f7320491 100644 --- a/opacus/grad_sample/README.md +++ b/opacus/grad_sample/README.md @@ -3,25 +3,54 @@ Computing per sample gradients is an integral part of Opacus framework. We strive to provide out-of-the-box support for wide range of models, while keeping computations efficient. -We currently provide two independent approaches for computing per sample gradients: hooks-based ``GradSampleModule`` -(stable implementation, exists since the very first version of Opacus) and ``GradSampleModuleExpandedWeights`` -(based on a beta functionality available in PyTorch 1.12). - -Each of the two implementations comes with it's own set of limitations, and we leave the choice up to the client -which one to use. - -``GradSampleModuleExpandedWeights`` is currently in early beta and can produce unexpected errors, but potentially -improves upon ``GradSampleModule`` on performance and functionality. - -**TL;DR:** If you want stable implementation, use ``GradSampleModule`` (`grad_sample_mode="hooks"`). -If you want to experiment with the new functionality, you have two options. Try -``GradSampleModuleExpandedWeights``(`grad_sample_mode="ew"`) for better performance and `grad_sample_mode=functorch` -if your model is not supported by ``GradSampleModule``. - -Please switch back to ``GradSampleModule``(`grad_sample_mode="hooks"`) if you encounter strange errors or unexpexted behaviour. -We'd also appreciate it if you report these to us - -## Hooks-based approach +We currently provide three independent approaches for computing per sample gradients: + +1. **Hooks-based `GradSampleModule`** (stable, wraps the model) +2. **`GradSampleController`** (stable, no model wrapping - recommended for transformers) +3. **`GradSampleModuleExpandedWeights`** (beta, based on PyTorch 1.12+ functionality) + +Each implementation comes with its own set of limitations and benefits. + +**TL;DR:** +- Use `GradSampleModule` (`grad_sample_mode="hooks"`) for stable implementation with standard models (default) +- Use controller mode (`return_controller=True`) for transformer models and when you need direct model access without wrapping +- Use `GradSampleModuleExpandedWeights` (`grad_sample_mode="ew"`) if you want to experiment with better performance +- Use `grad_sample_mode="functorch"` if your model has unsupported layers + +Please report any strange errors or unexpected behaviour to us! + +## Controller-Based Approach (No Model Wrapping) +- Usage: Set `return_controller=True` in `PrivacyEngine.make_private()` +- Controller class: ``opacus.grad_sample.GradSampleController`` + +**Recommended for transformer models and when model wrapping causes issues.** + +Computes per-sample gradients by attaching hooks directly to model parameters without wrapping the model in a +`GradSampleModule`. This approach: + +- ✅ Preserves model type (e.g., `isinstance(model, BertModel)` remains `True`) +- ✅ No `_module.` prefix in state_dict +- ✅ Direct access to model attributes (no attribute forwarding needed) +- ✅ Better compatibility with HuggingFace transformers and models with custom `__getattr__` +- ✅ Same grad sampler methods as `GradSampleModule` + +**Example:** +```python +from opacus import PrivacyEngine + +privacy_engine = PrivacyEngine() +model, optimizer, dataloader = privacy_engine.make_private( + module=model, + optimizer=optimizer, + data_loader=dataloader, + noise_multiplier=1.0, + max_grad_norm=1.0, + return_controller=True, # ← Enable controller mode +) +# model is now unwrapped with hooks attached directly +``` + +## Hooks-based approach (Model Wrapping) - Model wrapping class: ``opacus.grad_sample.grad_sample_module.GradSampleModule`` - Keyword argument for ``PrivacyEngine.make_private()``: `grad_sample_mode="hooks"` @@ -62,23 +91,27 @@ is roughly the same. Please note that these are known limitations and we plan to improve Expanded Weights and bridge the gap in feature completeness -| xxx | Hooks | Expanded Weights | Functorch | -|:----------------------------:|:-------------------------------:|:----------------:|:------------:| -| Required PyTorch version | 1.8+ | 1.13+ | 1.12 (to be updated) | -| Development status | Underlying mechanism deprecated | Beta | Beta | -| Runtime Performance† | baseline | ✅ ~25% faster | 🟨 0-50% slower | -| Any DP-allowed†† layers | Not supported | Not supported | ✅ Supported | -| Most popular nn.* layers | ✅ Supported | ✅ Supported | ✅ Supported | -| torchscripted models | Not supported | ✅ Supported | Not supported | -| Client-provided grad sampler | ✅ Supported | Not supported | ✅ Not needed | -| `batch_first=False` | ✅ Supported | Not supported | ✅ Supported | -| Recurrent networks | ✅ Supported | Not supported | ✅ Supported | -| Padding `same` in Conv | ✅ Supported | Not supported | ✅ Supported | -| Empty poisson batches | ✅ Supported | Not supported | Not supported | - -† Note, that performance differences are unstable and can vary a lot depending on the exact model and batch size. -Numbers above are averaged over benchmarks with small models consisting of convolutional and linear layers. -Note, that performance differences are only observed on GPU training, CPU performance seem to be almost identical +| xxx | GradSampleModule (Hooks) | GradSampleController | Expanded Weights | Functorch | +|:----------------------------:|:------------------------:|:-------------------:|:----------------:|:------------:| +| Required PyTorch version | 1.8+ | 1.8+ | 1.13+ | 1.12 (to be updated) | +| Development status | Deprecated mechanism | ✅ Beta | Beta | Beta | +| Model wrapping | ✅ Wraps model | ✅ No wrapping | ✅ Wraps model | ✅ Wraps model | +| Runtime Performance† | baseline | baseline | ✅ ~25% faster | 🟨 0-50% slower | +| Transformer compatibility | 🟨 May have issues | ✅ Excellent | 🟨 May have issues | 🟨 May have issues | +| State dict compatibility | 🟨 `_module.` prefix | ✅ Clean keys | 🟨 `_module.` prefix | 🟨 `_module.` prefix | +| Type preservation | ❌ Model wrapped | ✅ Model unchanged | ❌ Model wrapped | ❌ Model wrapped | +| Any DP-allowed†† layers | Not supported | Not supported | Not supported | ✅ Supported | +| Most popular nn.* layers | ✅ Supported | ✅ Supported | ✅ Supported | ✅ Supported | +| torchscripted models | Not supported | Not supported | ✅ Supported | Not supported | +| Client-provided grad sampler | ✅ Supported | ✅ Supported | Not supported | ✅ Not needed | +| `batch_first=False` | ✅ Supported | ✅ Supported | Not supported | ✅ Supported | +| Recurrent networks | ✅ Supported | ✅ Supported | Not supported | ✅ Supported | +| Padding `same` in Conv | ✅ Supported | ✅ Supported | Not supported | ✅ Supported | +| Empty poisson batches | ✅ Supported | ✅ Supported | Not supported | Not supported | + +† Note, that performance differences are unstable and can vary a lot depending on the exact model and batch size. +Numbers above are averaged over benchmarks with small models consisting of convolutional and linear layers. +Note, that performance differences are only observed on GPU training, CPU performance seem to be almost identical for all approaches. †† Layers that produce joint computations on batch samples (e.g. BatchNorm) are not allowed under any approach diff --git a/opacus/grad_sample/__init__.py b/opacus/grad_sample/__init__.py index 17c67bbd..5783bd94 100644 --- a/opacus/grad_sample/__init__.py +++ b/opacus/grad_sample/__init__.py @@ -18,6 +18,10 @@ from .dp_rnn import compute_rnn_linear_grad_sample # noqa from .embedding import compute_embedding_grad_sample # noqa from .embedding_norm_sample import compute_embedding_norm_sample # noqa +from .grad_sample_controller import GradSampleController # noqa +from .grad_sample_controller_fast_gradient_clipping import ( # noqa + GradSampleControllerFastGradientClipping, +) from .grad_sample_module import GradSampleModule, create_or_accumulate_grad_sample from .grad_sample_module_fast_gradient_clipping import ( # noqa GradSampleModuleFastGradientClipping, @@ -45,6 +49,8 @@ __all__ = [ + "GradSampleController", + "GradSampleControllerFastGradientClipping", "GradSampleModule", "GradSampleModuleFastGradientClipping", "GradSampleModuleFastGradientClippingFSDP", diff --git a/opacus/grad_sample/grad_sample_controller.py b/opacus/grad_sample/grad_sample_controller.py new file mode 100644 index 00000000..4a565c86 --- /dev/null +++ b/opacus/grad_sample/grad_sample_controller.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GradSampleController: Manages privacy hooks on models without wrapping them. + +This module provides a GradSampleModule-less approach to attaching hooks +directly to model parameters for computing per-sample gradients. +""" + +import logging +from functools import partial +from typing import Iterable, List, Tuple + +import torch +import torch.nn as nn +from opacus.grad_sample.functorch import ft_compute_per_sample_gradient, prepare_layer +from opacus.grad_sample.grad_sample_hooks_mixin import GradSampleHooksMixin +from opacus.grad_sample.grad_sample_module import ( + _get_batch_size, + create_or_accumulate_grad_sample, + promote_current_grad_sample, +) +from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN, RNNLinear +from opacus.utils.module_utils import ( + has_trainable_params, + requires_grad, + trainable_modules, + trainable_parameters, +) +from opacus.validators.errors import UnsupportedModuleError +from torch.utils.hooks import RemovableHandle + + +logger = logging.getLogger(__name__) +logger.disabled = True + + +OPACUS_PARAM_MONKEYPATCH_ATTRS = [ + "grad_sample", + "_forward_counter", + "_current_grad_sample", + "_norm_sample", +] + + +# GradSampleHooksMixin is now imported from grad_sample_hooks_mixin.py to avoid circular imports + + +class GradSampleController(GradSampleHooksMixin): + """ + Controller for managing privacy hooks on models without wrapping them + + Computes per-sample gradients using custom-written methods for each layer. + See README.md for more details + + This class attaches hooks directly to model modules and manages their lifecycle, + providing an alternative to GradSampleModule wrapping that's more compatible + with transformers and other complex models. + """ + + def __init__( + self, + m: nn.Module, + *, + batch_first=True, + loss_reduction="mean", + strict: bool = True, + force_functorch=False, + ): + """ + + Args: + m: nn.Module to attach hooks to + batch_first: Flag to indicate if the input tensor to the corresponding module + has the first dimension representing the batch. If set to True, dimensions on + input tensor are expected be ``[batch_size, ...]``, otherwise + ``[K, batch_size, ...]`` + loss_reduction: Indicates if the loss reduction (for aggregating the gradients) + is a sum or a mean operation. Can take values "sum" or "mean" + strict: If set to ``True``, the input module will be validated to make sure that none of its submodules includes buffers, + which is not currently supported by Opacus. + If set to ``False``, per sample gradients will + be computed on "best effort" basis - they will be available where + possible and set to None otherwise. This is not recommended, because + some unsupported modules (e.g. BatchNorm) affect other parameters and + invalidate the concept of per sample gradients for the entire model. + force_functorch: If set to ``True``, will use functorch to compute + all per sample gradients. Otherwise, functorch will be used only + for layers without registered grad sampler methods. + + Raises: + NotImplementedError + If ``strict`` is set to ``True`` and module ``m`` (or any of its + submodules) includes a buffer. + """ + errors = self.validate(module=m, strict=strict) + if errors and not strict: + logger.info( + f"GradSampleController found the following errors: {errors}." + "Using non-strict mode, continuing" + ) + + self.module = m + self.hooks_enabled = False + self.grad_accumulation_allowed = True + self.batch_first = batch_first + self.loss_reduction = loss_reduction + self.force_functorch = force_functorch + + self.autograd_grad_sample_hooks: List[RemovableHandle] = [] + + # Initialize parameters with required attributes + for _, p in trainable_parameters(self.module): + p.grad_sample = None + p._forward_counter = 0 + + # Add the hooks + self.add_hooks() + + def _get_target_module(self) -> nn.Module: + """Return the module to attach hooks to.""" + return self.module + + def add_hooks(self) -> None: + """ + Adds hooks to model to save activations and backprop values. + The hooks will + 1. save activations into param.activations during forward pass + 2. compute per-sample gradients in params.grad_sample during backward pass. + Call ``remove_hooks(model)`` to disable this. + """ + self._add_hooks_impl( + target_module=self.module, + hooks_list=self.autograd_grad_sample_hooks, + batch_first=self.batch_first, + loss_reduction=self.loss_reduction, + force_functorch=self.force_functorch, + ) + + def remove_hooks(self) -> None: + """ + Removes hooks added by ``add_hooks()`` + """ + self.disable_hooks() + + while self.autograd_grad_sample_hooks: + handle = self.autograd_grad_sample_hooks.pop() + handle.remove() + + # Remove functorch hooks + for _module_name, module in trainable_modules(self.module): + if hasattr(module, "ft_compute_sample_grad"): + delattr(module, "ft_compute_sample_grad") + if hasattr(module, "activations"): + delattr(module, "activations") + + def cleanup(self): + """ + Clean up all hooks and attributes added to the model. + """ + self.remove_hooks() + + # Clean up parameter attributes + for attr in OPACUS_PARAM_MONKEYPATCH_ATTRS: + for p in self.module.parameters(): + if hasattr(p, attr): + delattr(p, attr) diff --git a/opacus/grad_sample/grad_sample_controller_fast_gradient_clipping.py b/opacus/grad_sample/grad_sample_controller_fast_gradient_clipping.py new file mode 100644 index 00000000..165d328e --- /dev/null +++ b/opacus/grad_sample/grad_sample_controller_fast_gradient_clipping.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GradSampleControllerFastGradientClipping: Controller-based Fast Gradient and Ghost Clipping. + +This module provides a GradSampleModule-less approach with ghost clipping support, +combining the benefits of: +- Controller-based hook management (no model wrapping) +- Ghost clipping (memory-efficient gradient norm computation) +""" + +import logging +from typing import List + +import torch +import torch.nn as nn +from opacus.grad_sample.grad_sample_controller import GradSampleController +from opacus.grad_sample.grad_sample_hooks_mixin import ( + GradSampleFastGradientClippingMixin, + create_norm_sample, +) +from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN +from opacus.utils.module_utils import trainable_modules, trainable_parameters + + +logger = logging.getLogger(__name__) +logger.disabled = True + + +class GradSampleControllerFastGradientClipping( + GradSampleController, GradSampleFastGradientClippingMixin +): + """ + Controller for managing privacy hooks with Fast Gradient and Ghost Clipping support + + Extends GradSampleController to add ghost clipping support for memory-efficient + gradient norm computation. Supports both: + - Ghost Clipping: Direct norm computation without materializing full gradients + - Fast Gradient Clipping: Full gradient computation followed by norm computation + + This class attaches hooks directly to model modules and manages their lifecycle, + providing an alternative to GradSampleModule wrapping that's more compatible + with transformers and other complex models. + """ + + NORM_SAMPLERS = {} + + def __init__( + self, + m: nn.Module, + *, + batch_first=True, + loss_reduction="mean", + strict: bool = True, + force_functorch=False, + max_grad_norm=1, + use_ghost_clipping=True, + ): + """ + + Args: + m: nn.Module to attach hooks to + batch_first: Flag to indicate if the input tensor to the corresponding module + has the first dimension representing the batch. If set to True, dimensions on + input tensor are expected be ``[batch_size, ...]``, otherwise + ``[K, batch_size, ...]`` + loss_reduction: Indicates if the loss reduction (for aggregating the gradients) + is a sum or a mean operation. Can take values "sum" or "mean" + max_grad_norm: The value at which gradients are to be clipped. + strict: If set to ``True``, the input module will be validated to make sure that none of its submodules includes buffers, + which is not currently supported by Opacus. + If set to ``False``, per sample gradients will + be computed on "best effort" basis - they will be available where + possible and set to None otherwise. This is not recommended, because + some unsupported modules (e.g. BatchNorm) affect other parameters and + invalidate the concept of per sample gradients for the entire model. + force_functorch: If set to ``True``, will use functorch to compute + all per sample gradients. Otherwise, functorch will be used only + for layers without registered grad sampler methods. + use_ghost_clipping: If set to ``True``, Ghost Clipping + will be used for clipping gradients of supported layers. If ``False``, Fast + Gradient Clipping will be used for all layers. + + Raises: + NotImplementedError + If ``strict`` is set to ``True`` and module ``m`` (or any of its + submodules) includes a buffer. + """ + # Call parent constructor + super().__init__( + m, + batch_first=batch_first, + loss_reduction=loss_reduction, + strict=strict, + force_functorch=force_functorch, + ) + + # Add ghost clipping specific attributes + self.max_grad_norm = max_grad_norm + self.use_ghost_clipping = use_ghost_clipping + self._per_sample_gradient_norms = None + + # Initialize _norm_sample attribute for parameters + for _, p in trainable_parameters(self.module): + p._norm_sample = None + + self.trainable_parameters = [p for _, p in trainable_parameters(self.module)] + + if logger.isEnabledFor(logging.INFO): + self.log_module_gradient_sample_mode( + module=m, + force_functorch=force_functorch, + use_ghost_clipping=use_ghost_clipping, + ) + + def get_clipping_coef(self) -> torch.Tensor: + """Get per-example gradient scaling factor for clipping.""" + norm_sample = self.get_norm_sample() + return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0) + + def get_norm_sample(self) -> torch.Tensor: + """Get per-example gradient norms.""" + norm_sample = torch.stack( + [param._norm_sample for param in self.trainable_parameters], dim=0 + ).norm(2, dim=0) + self.per_sample_gradient_norms = norm_sample + return norm_sample + + def capture_activations_hook( + self, + module: nn.Module, + forward_input: List[torch.Tensor], + _forward_output: torch.Tensor, + ): + """ + Override parent method to add parameter tying check for ghost clipping. + """ + # Call parent implementation + super().capture_activations_hook(module, forward_input, _forward_output) + + # Add ghost clipping specific check for parameter tying + if self.hooks_enabled: + for _, p in trainable_parameters(module): + if ( + self.use_ghost_clipping + and p._forward_counter > 1 + and type(module) in self.NORM_SAMPLERS + ): + raise NotImplementedError( + "Parameter tying is not supported with Ghost Clipping" + ) + + # Note: capture_backprops_hook is inherited from GradSampleFastGradientClippingMixin + + def log_module_gradient_sample_mode( + self, module: nn.Module, *, force_functorch=False, use_ghost_clipping=True + ): + """ + Add logs to track gradient sample mode for each part of the module, including 1) Ghost Clipping, 2) Fast Gradient Clipping (hook mode), and 3) Fast Gradient Clipping (functorch mode). + + Args: + module: nn.Module to be checked + force_functorch: If set to ``True``, will use functorch to compute + all per sample gradients. Otherwise, functorch will be used only + for layers without registered grad sampler methods. + use_ghost_clipping: If set to ``True``, Ghost Clipping + will be used for clipping gradients of supported layers. If ``False``, Fast + Gradient Clipping will be used for all layers. + """ + for m_name, m in trainable_modules(module): + if type(m) in [DPRNN, DPLSTM, DPGRU]: + logger.info( + f"Module name: {m_name}, module type: {type(m)}. No hook or functorch is added." + ) + + elif use_ghost_clipping and type(m) in self.NORM_SAMPLERS: + logger.info( + f"Module name: {m_name}, module type: {type(m)}, under Ghost Clipping." + ) + + else: + if not force_functorch and type(m) in self.GRAD_SAMPLERS: + # When functorch is not enforced, use FGC (hook mode) if the layer has a registered grad_sampler (supported). Otherwise, use FGC (functorch mode). + logger.info( + f"Module name: {m_name}, module type: {type(m)}, under Fast Gradient Clipping (hook mode)." + ) + else: + logger.info( + f"Module name: {m_name}, module type: {type(m)}, under Fast Gradient Clipping (functorch mode)." + ) + + @property + def per_sample_gradient_norms(self) -> torch.Tensor: + """Returns per sample gradient norms. Note that these are not privatized and should only be used for debugging purposes or in non-private settings""" + if self._per_sample_gradient_norms is not None: + return self._per_sample_gradient_norms + else: + raise AttributeError( + "per_sample_gradient_norms is not set. Please call forward and backward on the model before accessing this property." + ) + + @per_sample_gradient_norms.setter + def per_sample_gradient_norms(self, value): + self._per_sample_gradient_norms = value diff --git a/opacus/grad_sample/grad_sample_hooks_mixin.py b/opacus/grad_sample/grad_sample_hooks_mixin.py new file mode 100644 index 00000000..a3e90ee8 --- /dev/null +++ b/opacus/grad_sample/grad_sample_hooks_mixin.py @@ -0,0 +1,404 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Shared mixin for GradSampleModule and GradSampleController. + +This module contains the common hook logic that is shared between the module-wrapping +approach (GradSampleModule) and the controller-based approach (GradSampleController). +""" + +import logging +from functools import partial +from typing import Iterable, List, Tuple + +import torch +import torch.nn as nn +from opacus.grad_sample.functorch import ft_compute_per_sample_gradient, prepare_layer +from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN, RNNLinear +from opacus.utils.module_utils import ( + has_trainable_params, + requires_grad, + trainable_modules, + trainable_parameters, +) +from opacus.validators.errors import UnsupportedModuleError +from torch.utils.hooks import RemovableHandle + + +logger = logging.getLogger(__name__) +logger.disabled = True + + +class GradSampleHooksMixin: + """ + Mixin class containing common hook logic shared between GradSampleModule and GradSampleController. + + This class provides the core functionality for: + - Adding/removing hooks to compute per-sample gradients + - Capturing activations during forward pass + - Computing gradients during backward pass + - Managing hook lifecycle + """ + + GRAD_SAMPLERS = {} + + def _get_target_module(self) -> nn.Module: + """Return the module to attach hooks to. Override in subclasses.""" + raise NotImplementedError("Subclasses must implement _get_target_module") + + def iterate_submodules(self, module: nn.Module) -> Iterable[nn.Module]: + """Iterate over submodules that need hooks attached.""" + if has_trainable_params(module): + yield module + + # Don't recurse if module is handled by functorch + if ( + has_trainable_params(module) + and type(module) not in self.GRAD_SAMPLERS + and type(module) not in [DPRNN, DPLSTM, DPGRU] + ): + return + + for m in module.children(): + yield from self.iterate_submodules(m) + + def _add_hooks_impl( + self, + target_module: nn.Module, + hooks_list: List[RemovableHandle], + *, + batch_first: bool, + loss_reduction: str, + force_functorch: bool, + ) -> None: + """Internal implementation of hook addition.""" + for module in self.iterate_submodules(target_module): + # Do not add hooks to DPRNN, DPLSTM or DPGRU + if type(module) in [DPRNN, DPLSTM, DPGRU]: + continue + + module_type = type(module) + if force_functorch or not (module_type in self.GRAD_SAMPLERS): + prepare_layer(module, batch_first=batch_first) + + hooks_list.append( + module.register_forward_hook(self.capture_activations_hook) + ) + + hooks_list.append( + module.register_full_backward_hook( + partial( + self.capture_backprops_hook, + loss_reduction=loss_reduction, + batch_first=batch_first, + ) + ) + ) + + self.enable_hooks() + + def disable_hooks(self) -> None: + """Globally disable all hooks installed by this library.""" + self.hooks_enabled = False + + def enable_hooks(self) -> None: + """Enable hooks (opposite of disable_hooks).""" + self.hooks_enabled = True + + def capture_activations_hook( + self, + module: nn.Module, + forward_input: List[torch.Tensor], + _forward_output: torch.Tensor, + ): + """Hook to capture activations during forward pass.""" + if ( + not requires_grad(module) + or not module.training + or not torch.is_grad_enabled() + ): + return + + if not self.hooks_enabled: + return + + if not hasattr(module, "activations"): + module.activations = [] + module.activations.append([t.detach() for t in forward_input]) + + for _, p in trainable_parameters(module): + p._forward_counter += 1 + + def capture_backprops_hook( + self, + module: nn.Module, + _forward_input: torch.Tensor, + forward_output: torch.Tensor, + loss_reduction: str, + batch_first: bool, + ): + """ + Computes per sample gradients given the current backprops and activations + stored by the associated forward hook. + """ + # Import here to avoid circular dependency + from opacus.grad_sample.grad_sample_module import ( + _get_batch_size, + create_or_accumulate_grad_sample, + promote_current_grad_sample, + ) + + if not self.hooks_enabled: + return + + backprops = forward_output[0].detach() + activations, backprops = self.rearrange_grad_samples( + module=module, + backprops=backprops, + loss_reduction=loss_reduction, + batch_first=batch_first, + ) + + if not self.force_functorch and type(module) in self.GRAD_SAMPLERS: + grad_sampler_fn = self.GRAD_SAMPLERS[type(module)] + else: + grad_sampler_fn = ft_compute_per_sample_gradient + + grad_samples = grad_sampler_fn(module, activations, backprops) + for param, gs in grad_samples.items(): + create_or_accumulate_grad_sample( + param=param, grad_sample=gs, max_batch_len=module.max_batch_len + ) + + # Detect end of current batch processing + for _, p in trainable_parameters(module): + p._forward_counter -= 1 + if p._forward_counter == 0: + promote_current_grad_sample(p) + + if not self.grad_accumulation_allowed: + if isinstance(p.grad_sample, list) and len(p.grad_sample) > 1: + raise ValueError( + "Poisson sampling is not compatible with grad accumulation. " + "You need to call optimizer.step() after every forward/backward pass " + "or consider using BatchMemoryManager" + ) + + if len(module.activations) == 0: + if hasattr(module, "max_batch_len"): + del module.max_batch_len + + def rearrange_grad_samples( + self, + *, + module: nn.Module, + backprops: torch.Tensor, + loss_reduction: str, + batch_first: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Rearrange activations and grad_samples based on loss reduction and batch dim.""" + # Import here to avoid circular dependency + from opacus.grad_sample.grad_sample_module import _get_batch_size + + if not hasattr(module, "activations"): + raise ValueError( + f"No activations detected for {type(module)}," + " run forward after add_hooks(model)" + ) + + batch_dim = 0 if batch_first or type(module) is RNNLinear else 1 + + if not hasattr(module, "max_batch_len"): + module.max_batch_len = _get_batch_size( + module=module, + batch_dim=batch_dim, + ) + activations = module.activations.pop() + + n = module.max_batch_len + if loss_reduction == "mean": + backprops = backprops * n + elif loss_reduction == "sum": + backprops = backprops + else: + raise ValueError( + f"loss_reduction = {loss_reduction}. Only 'sum' and 'mean' losses are supported" + ) + + # No matter where the batch dimension was, .grad_samples will *always* put it in the first dim + if batch_dim != 0: + activations = [ + t.permute([batch_dim] + [x for x in range(t.dim()) if x != batch_dim]) + for t in activations + ] + backprops = backprops.permute( + [batch_dim] + [x for x in range(backprops.dim()) if x != batch_dim] + ) + + return activations, backprops + + def forbid_grad_accumulation(self): + """Forbid gradient accumulation (for Poisson sampling).""" + self.grad_accumulation_allowed = False + + def allow_grad_accumulation(self): + """Allow gradient accumulation.""" + self.grad_accumulation_allowed = True + + @classmethod + def validate( + cls, module: nn.Module, *, strict: bool = False + ) -> List[UnsupportedModuleError]: + """Check if per sample gradients can be fully computed for a given model.""" + errors = [] + errors.extend( + [ + UnsupportedModuleError( + f"Model contains a trainable layer with buffers " + f"that Opacus doesn't currently support ({m_name}:{m}). " + ) + for m_name, m in trainable_modules(module) + if len(list(m.buffers())) > 0 + ] + ) + if strict and len(errors) > 0: + raise UnsupportedModuleError(errors) + else: + return errors + + +def create_norm_sample( + *, param: torch.Tensor, grad_sample: torch.Tensor, max_batch_len: int +) -> None: + """ + Creates a ``_norm_sample`` attribute in the given parameter + + Args: + param: Parameter to which ``_norm_sample`` will be added + grad_sample: Per-sample gradients tensor. Must be of the same + shape as ``param`` with extra batch dimension + max_batch_len: Maximum batch length for handling empty batches + """ + if param.requires_grad: + if ( + max_batch_len == 0 + ): # To handle the case of empty batch that may arise from Poisson sampling + param._norm_sample = torch.tensor( + [], device=grad_sample.device, dtype=grad_sample.dtype + ) + else: + param._norm_sample = torch.zeros( + torch.Size([max_batch_len, 1]), + device=grad_sample.device, + dtype=grad_sample.dtype, + ) + param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm( + 2, dim=-1 + ) + + +class GradSampleFastGradientClippingMixin(GradSampleHooksMixin): + """ + Mixin for Fast Gradient and Ghost Clipping support. + + Extends GradSampleHooksMixin to add ghost clipping capabilities for + memory-efficient gradient norm computation. + """ + + NORM_SAMPLERS = {} + + def capture_backprops_hook( + self, + module: nn.Module, + _forward_input: torch.Tensor, + forward_output: torch.Tensor, + loss_reduction: str, + batch_first: bool, + ): + """ + Computes per sample gradient norms given the current backprops and activations. + + Supports both: + - Ghost Clipping: Direct norm computation without materializing full gradients + - Fast Gradient Clipping: Full gradient computation followed by norm computation + """ + # Import here to avoid circular dependency + from opacus.grad_sample.grad_sample_module import ( + create_or_accumulate_grad_sample, + promote_current_grad_sample, + ) + + if not self.hooks_enabled: + return + + backprops = forward_output[0].detach() + activations, backprops = self.rearrange_grad_samples( + module=module, + backprops=backprops, + loss_reduction=loss_reduction, + batch_first=batch_first, + ) + + # Handle DTensor if needed + activations = [ + temp.to_local() if type(temp) is torch.distributed.tensor.DTensor else temp + for temp in activations + ] + + if self.use_ghost_clipping and type(module) in self.NORM_SAMPLERS: + # Ghost clipping: compute norms directly + norm_sampler_fn = self.NORM_SAMPLERS[type(module)] + norm_samples = norm_sampler_fn(module, activations, backprops) + + for param, ns in norm_samples.items(): + if param.requires_grad: + param._norm_sample = ns + param._forward_counter -= 1 + + else: + # Fast gradient clipping: materialize gradients then compute norms + if not self.force_functorch and type(module) in self.GRAD_SAMPLERS: + grad_sampler_fn = self.GRAD_SAMPLERS[type(module)] + else: + grad_sampler_fn = ft_compute_per_sample_gradient + + grad_samples = grad_sampler_fn(module, activations, backprops) + for param, gs in grad_samples.items(): + create_or_accumulate_grad_sample( + param=param, grad_sample=gs, max_batch_len=module.max_batch_len + ) + # Also create norm sample for fast gradient clipping + create_norm_sample( + param=param, grad_sample=gs, max_batch_len=module.max_batch_len + ) + + # Detect end of current batch processing + for _, p in trainable_parameters(module): + p._forward_counter -= 1 + if p._forward_counter == 0: + promote_current_grad_sample(p) + + if not self.grad_accumulation_allowed: + if isinstance(p.grad_sample, list) and len(p.grad_sample) > 1: + raise ValueError( + "Poisson sampling is not compatible with grad accumulation. " + "You need to call optimizer.step() after every forward/backward pass " + "or consider using BatchMemoryManager" + ) + + if len(module.activations) == 0: + if hasattr(module, "max_batch_len"): + del module.max_batch_len diff --git a/opacus/grad_sample/grad_sample_module.py b/opacus/grad_sample/grad_sample_module.py index 77240012..b4f585eb 100644 --- a/opacus/grad_sample/grad_sample_module.py +++ b/opacus/grad_sample/grad_sample_module.py @@ -23,6 +23,7 @@ import torch import torch.nn as nn from opacus.grad_sample.functorch import ft_compute_per_sample_gradient, prepare_layer +from opacus.grad_sample.grad_sample_hooks_mixin import GradSampleHooksMixin from opacus.grad_sample.gsm_base import AbstractGradSampleModule from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN, RNNLinear from opacus.utils.module_utils import ( @@ -76,7 +77,7 @@ def promote_current_grad_sample(p: nn.Parameter) -> None: del p._current_grad_sample -class GradSampleModule(AbstractGradSampleModule): +class GradSampleModule(AbstractGradSampleModule, GradSampleHooksMixin): """ Hooks-based implementation of AbstractGradSampleModule @@ -84,8 +85,6 @@ class GradSampleModule(AbstractGradSampleModule): See README.md for more details """ - GRAD_SAMPLERS = {} - def __init__( self, m: nn.Module, @@ -148,23 +147,9 @@ def __init__( def forward(self, *args, **kwargs): return self._module(*args, **kwargs) - def iterate_submodules(self, module: nn.Module) -> Iterable[nn.Module]: - if has_trainable_params(module): - yield module - - # Don't recurse if module is handled by functorch - if ( - has_trainable_params(module) - and type(module) not in self.GRAD_SAMPLERS - and type(module) not in [DPRNN, DPLSTM, DPGRU] - ): - return - - for m in module.children(): - yield from self.iterate_submodules(m) - - def _get_module_type(self, module: nn.Module) -> str: - return type(module) + def _get_target_module(self) -> nn.Module: + """Return the module to attach hooks to.""" + return self._module def add_hooks( self, @@ -181,7 +166,6 @@ def add_hooks( Call ``remove_hooks(model)`` to disable this. Args: - model: the model to which hooks are added batch_first: Flag to indicate if the input tensor to the corresponding module has the first dimension representing the batch. If set to True, dimensions on input tensor are expected be ``[batch_size, ...]``, otherwise @@ -197,30 +181,14 @@ def add_hooks( self._module.autograd_grad_sample_hooks = [] self.autograd_grad_sample_hooks = self._module.autograd_grad_sample_hooks - for module in self.iterate_submodules(self._module): - # Do not add hooks to DPRNN, DPLSTM or DPGRU as the hooks are handled by the `RNNLinear` - if type(module) in [DPRNN, DPLSTM, DPGRU]: - continue - - module_type = self._get_module_type(module) - if force_functorch or not (module_type in self.GRAD_SAMPLERS): - prepare_layer(module, batch_first=batch_first) - - self.autograd_grad_sample_hooks.append( - module.register_forward_hook(self.capture_activations_hook) - ) - - self.autograd_grad_sample_hooks.append( - module.register_full_backward_hook( - partial( - self.capture_backprops_hook, - loss_reduction=loss_reduction, - batch_first=batch_first, - ) - ) - ) - - self.enable_hooks() + # Delegate to mixin implementation + self._add_hooks_impl( + target_module=self._module, + hooks_list=self.autograd_grad_sample_hooks, + batch_first=batch_first, + loss_reduction=loss_reduction, + force_functorch=force_functorch, + ) def remove_hooks(self) -> None: """ @@ -269,230 +237,20 @@ def _close(self): super()._close() self.remove_hooks() - def capture_activations_hook( - self, - module: nn.Module, - forward_input: List[torch.Tensor], - _forward_output: torch.Tensor, - ): - if ( - not requires_grad(module) - or not module.training - or not torch.is_grad_enabled() - ): - return - - if not self.hooks_enabled: - return - - if not hasattr(module, "activations"): - module.activations = [] - module.activations.append([t.detach() for t in forward_input]) # pyre-ignore - - for _, p in trainable_parameters(module): - p._forward_counter += 1 - - def capture_backprops_hook( - self, - module: nn.Module, - _forward_input: torch.Tensor, - forward_output: torch.Tensor, - loss_reduction: str, - batch_first: bool, - ): - """ - Computes per sample gradients given the current backprops and activations - stored by the associated forward hook. Computed per sample gradients are - stored in ``grad_sample`` field in each parameter. - - For non-recurrent layers the process is straightforward: for each - ``loss.backward()`` call this hook will be called exactly one. For recurrent - layers, however, this is more complicated and the hook will be called multiple - times, while still processing the same batch of data. - - For this reason we first accumulate the gradients from *the same batch* in - ``p._current_grad_sample`` and then, when we detect the end of a full backward - pass - we store accumulated result on ``p.grad_sample``. - - From there, ``p.grad_sample`` could be either a Tensor or a list of Tensors, - if accumulated over multiple batches - - Args: - module: nn.Module, - _forward_input: torch.Tensor, - forward_output: torch.Tensor, - loss_reduction: str, - batch_first: bool, - """ - if not self.hooks_enabled: - return - - backprops = forward_output[0].detach() - activations, backprops = self.rearrange_grad_samples( - module=module, - backprops=backprops, - loss_reduction=loss_reduction, - batch_first=batch_first, - ) - if ( - not self.force_functorch - and self._get_module_type(module) in self.GRAD_SAMPLERS - ): - grad_sampler_fn = self.GRAD_SAMPLERS[self._get_module_type(module)] - else: - grad_sampler_fn = ft_compute_per_sample_gradient - - grad_samples = grad_sampler_fn(module, activations, backprops) - for param, gs in grad_samples.items(): - create_or_accumulate_grad_sample( - param=param, grad_sample=gs, max_batch_len=module.max_batch_len - ) - - # Detect end of current batch processing and switch accumulation - # mode from sum to stacking. Used for RNNs and tied parameters - # (See #417 for details) - for _, p in trainable_parameters(module): - p._forward_counter -= 1 - if p._forward_counter == 0: - promote_current_grad_sample(p) - - if not self.grad_accumulation_allowed: - if isinstance(p.grad_sample, list) and len(p.grad_sample) > 1: - raise ValueError( - "Poisson sampling is not compatible with grad accumulation. " - "You need to call optimizer.step() after every forward/backward pass " - "or consider using BatchMemoryManager" - ) - - if len(module.activations) == 0: - if hasattr(module, "max_batch_len"): - del module.max_batch_len - - def rearrange_grad_samples( - self, - *, - module: nn.Module, - backprops: torch.Tensor, - loss_reduction: str, - batch_first: bool, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Rearrange activations and grad_samples based on loss reduction and batch dim - - Args: - module: the module for which per-sample gradients are computed - backprops: the captured backprops - loss_reduction: either "mean" or "sum" depending on whether backpropped - loss was averaged or summed over batch - batch_first: True is batch dimension is first - """ - if not hasattr(module, "activations"): - raise ValueError( - f"No activations detected for {type(module)}," - " run forward after add_hooks(model)" - ) - - batch_dim = 0 if batch_first or type(module) is RNNLinear else 1 - - if not hasattr(module, "max_batch_len"): - # For packed sequences, max_batch_len is set in the forward of the model (e.g. the LSTM) - # Otherwise we infer it here - module.max_batch_len = _get_batch_size( - module=module, - batch_dim=batch_dim, - ) - activations = module.activations.pop() - - n = module.max_batch_len - if loss_reduction == "mean": - backprops = backprops * n - elif loss_reduction == "sum": - backprops = backprops - else: - raise ValueError( - f"loss_reduction = {loss_reduction}. Only 'sum' and 'mean' losses are supported" - ) - - # No matter where the batch dimension was, .grad_samples will *always* put it in the first dim - if batch_dim != 0: - activations = [ - t.permute([batch_dim] + [x for x in range(t.dim()) if x != batch_dim]) - for t in activations - ] - backprops = backprops.permute( - [batch_dim] + [x for x in range(backprops.dim()) if x != batch_dim] - ) - - return activations, backprops - - @classmethod - def is_supported(cls, module: nn.Module) -> bool: - """ - Checks if this individual model is supported (i.e. has a registered - grad sampler function) - - Notes: - Note that this method does not check submodules - - Args: - module: nn.Module to be checked - - Returns: - ``True`` if grad sampler is found, ``False`` otherwise - """ - warnings.warn( - "GradSampleModule.is_supported is deprecated, as all layers can now be used with functorch.", - DeprecationWarning, - ) - - return True - - @classmethod - def validate( - cls, module: nn.Module, *, strict: bool = False - ) -> List[NotImplementedError]: - """ - Check if per sample gradients can be fully computed for a given model - - Args: - module: nn.Module to be checked - raise_if_error: Behaviour in case of a negative check result. Will - return the list of exceptions if set to ``False``, and throw otherwise - - Returns: - Empty list of validation is successful. - List of validation errors if ``raise_if_error=False`` and - unsupported modules are found - - Raises: - NotImplementedError - If ``raise_if_error=True`` and unsupported modules are found - """ - errors = [] - errors.extend( - [ - NotImplementedError( - f"Model contains a trainable layer with buffers" - f"that Opacus doesn't currently support({m_name}:{m}). " - ) - for m_name, m in trainable_modules(module) - # With functorch, all modules are trainable - # We still want to avoid module that have buffers (e.g. BatchNorm) - # as the buffers are not private - if len(list(m.buffers())) > 0 - ] - ) - # raise or return errors as needed - if strict and len(errors) > 0: - raise NotImplementedError(errors) - else: - return errors - + # Override base class no-op methods to use mixin implementations def forbid_grad_accumulation(self): - self.grad_accumulation_allowed = False + """Forbid gradient accumulation (for Poisson sampling).""" + GradSampleHooksMixin.forbid_grad_accumulation(self) def allow_grad_accumulation(self): - self.grad_accumulation_allowed = True + """Allow gradient accumulation.""" + GradSampleHooksMixin.allow_grad_accumulation(self) + + # Note: The following methods are inherited from GradSampleHooksMixin: + # - capture_activations_hook + # - capture_backprops_hook + # - rearrange_grad_samples + # - validate def _get_batch_size(*, module: nn.Module, batch_dim: int) -> int: diff --git a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py index 89dad983..f7246180 100644 --- a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py +++ b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py @@ -20,13 +20,12 @@ import torch import torch.nn as nn -from opacus.grad_sample.functorch import ft_compute_per_sample_gradient -from opacus.grad_sample.grad_sample_module import ( - GradSampleModule, - create_or_accumulate_grad_sample, - promote_current_grad_sample, +from opacus.grad_sample.grad_sample_hooks_mixin import ( + GradSampleFastGradientClippingMixin, + create_norm_sample, ) -from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN +from opacus.grad_sample.grad_sample_module import GradSampleModule +from opacus.layers import DPGRU, DPLSTM, DPRNN from opacus.utils.module_utils import ( requires_grad, trainable_modules, @@ -38,38 +37,9 @@ logger.disabled = True -def create_norm_sample( - *, param: torch.Tensor, grad_sample: torch.Tensor, max_batch_len: int -) -> None: - """ - Creates a ``_norm_sample`` attribute in the given parameter - - - Args: - param: Parameter to which ``_norm_sample`` will be added - grad_sample: Per-sample gradients tensor. Must be of the same - shape as ``param`` with extra batch dimension - """ - - if param.requires_grad: - if ( - max_batch_len == 0 - ): # To handle the case of empty batch that may arise from Poisson sampling - param._norm_sample = torch.tensor( - [], device=grad_sample.device, dtype=grad_sample.dtype - ) - else: - param._norm_sample = torch.zeros( - torch.Size([max_batch_len, 1]), - device=grad_sample.device, - dtype=grad_sample.dtype, - ) - param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm( - 2, dim=-1 - ) - - -class GradSampleModuleFastGradientClipping(GradSampleModule): +class GradSampleModuleFastGradientClipping( + GradSampleModule, GradSampleFastGradientClippingMixin +): """ Hooks-based implementation of GradSampleModule with Fast Gradient and Ghost Clipping @@ -177,79 +147,7 @@ def capture_activations_hook( "Parameter tying is not supported with Ghost Clipping" ) - def capture_backprops_hook( - self, - module: nn.Module, - _forward_input: torch.Tensor, - forward_output: torch.Tensor, - loss_reduction: str, - batch_first: bool, - ): - """ - Computes norms of per sample gradient given the current backprops and activations - stored by the associated forward hook. Computed per sample gradient norms are - stored in ``norm_sample`` field in each parameter. - - Args: - module: nn.Module, - _forward_input: torch.Tensor, - forward_output: torch.Tensor, - loss_reduction: str, - batch_first: bool, - """ - if not self.hooks_enabled: - return - - backprops = forward_output[0].detach() - - activations, backprops = self.rearrange_grad_samples( - module=module, - backprops=backprops, - loss_reduction=loss_reduction, - batch_first=batch_first, - ) - activations = [ - temp.to_local() if type(temp) is torch.distributed.tensor.DTensor else temp - for temp in activations - ] - - if self.use_ghost_clipping and type(module) in self.NORM_SAMPLERS: - norm_sampler_fn = self.NORM_SAMPLERS[type(module)] - norm_samples = norm_sampler_fn(module, activations, backprops) - - for param, ns in norm_samples.items(): - if param.requires_grad: - param._norm_sample = ns - param._forward_counter -= 1 - - else: - if not self.force_functorch and type(module) in self.GRAD_SAMPLERS: - grad_sampler_fn = self.GRAD_SAMPLERS[type(module)] - else: - grad_sampler_fn = ft_compute_per_sample_gradient - - grad_samples = grad_sampler_fn(module, activations, backprops) - for param, gs in grad_samples.items(): - create_or_accumulate_grad_sample( - param=param, grad_sample=gs, max_batch_len=module.max_batch_len - ) - del grad_samples - # Detect end of current batch processing and switch accumulation - # mode from sum to stacking. Used for RNNs and tied parameters - # (See #417 for details) - for _, p in trainable_parameters(module): - p._forward_counter -= 1 - if p._forward_counter == 0: - promote_current_grad_sample(p) - create_norm_sample( - param=p, - grad_sample=p.grad_sample, - max_batch_len=module.max_batch_len, - ) - p.grad_sample = None - if len(module.activations) == 0: - if hasattr(module, "max_batch_len"): - del module.max_batch_len + # Note: capture_backprops_hook is inherited from GradSampleFastGradientClippingMixin def log_module_gradient_sample_mode( self, module: nn.Module, *, force_functorch=False, use_ghost_clipping=True diff --git a/opacus/grad_sample/gsm_base.py b/opacus/grad_sample/gsm_base.py index ec137789..16077bb8 100644 --- a/opacus/grad_sample/gsm_base.py +++ b/opacus/grad_sample/gsm_base.py @@ -151,16 +151,20 @@ def forbid_grad_accumulation(self): without an optimizer step or clearing out gradients). When set, GradSampleModule will throw a ValueError on the second backward pass. - :return: + + Note: This is a no-op in the base class. Subclasses that support grad accumulation + detection (like GradSampleModule with hooks) should override this. """ - pass + # No-op for modules that don't support grad accumulation detection + return def allow_grad_accumulation(self): """ Unsets a flag to detect gradient accumulation (multiple forward/backward passes without an optimizer step or clearing out gradients). - When set, GradSampleModule will throw a ValueError on the second backward pass. - :return: + Note: This is a no-op in the base class. Subclasses that support grad accumulation + detection (like GradSampleModule with hooks) should override this. """ - pass + # No-op for modules that don't support grad accumulation detection + return diff --git a/opacus/grad_sample/gsm_exp_weights.py b/opacus/grad_sample/gsm_exp_weights.py index 13afac25..7f984881 100644 --- a/opacus/grad_sample/gsm_exp_weights.py +++ b/opacus/grad_sample/gsm_exp_weights.py @@ -32,6 +32,7 @@ def __init__( *, batch_first=True, loss_reduction="mean", + strict: bool = True, ): if not batch_first: raise NotImplementedError @@ -41,6 +42,8 @@ def __init__( batch_first=batch_first, loss_reduction=loss_reduction, ) + # Note: strict parameter is accepted for compatibility but not used + # in ExpandedWeights implementation def forward(self, x: torch.Tensor, *args, **kwargs): from torch.nn.utils._per_sample_grad import call_for_per_sample_grads diff --git a/opacus/grad_sample/utils.py b/opacus/grad_sample/utils.py index 78d58992..9a8dbbec 100644 --- a/opacus/grad_sample/utils.py +++ b/opacus/grad_sample/utils.py @@ -17,6 +17,10 @@ import torch.nn as nn +from .grad_sample_controller import GradSampleController +from .grad_sample_controller_fast_gradient_clipping import ( + GradSampleControllerFastGradientClipping, +) from .grad_sample_module import GradSampleModule from .grad_sample_module_fast_gradient_clipping import ( GradSampleModuleFastGradientClipping, @@ -53,6 +57,8 @@ def decorator(f): for target_class in target_classes: GradSampleModule.GRAD_SAMPLERS[target_class] = f GradSampleModuleFastGradientClipping.GRAD_SAMPLERS[target_class] = f + GradSampleController.GRAD_SAMPLERS[target_class] = f + GradSampleControllerFastGradientClipping.GRAD_SAMPLERS[target_class] = f return f return decorator @@ -79,25 +85,35 @@ def decorator(f): ) for target_class in target_classes: GradSampleModuleFastGradientClipping.NORM_SAMPLERS[target_class] = f + GradSampleControllerFastGradientClipping.NORM_SAMPLERS[target_class] = f return f return decorator -def wrap_model(model: nn.Module, grad_sample_mode: str, *args, **kwargs): - cls = get_gsm_class(grad_sample_mode) - if grad_sample_mode == "functorch": - kwargs["force_functorch"] = True - return cls(model, *args, **kwargs) - - def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]: """ - Returns AbstractGradSampleModule subclass correspinding to the input mode. + Returns AbstractGradSampleModule subclass corresponding to the input mode. + + This is used for the wrapping approach where the model is wrapped in a + GradSampleModule subclass. + See README for detailed comparison between grad sample modes. - :param grad_sample_mode: - :return: + Args: + grad_sample_mode: Mode for computing per-sample gradients. Supported values: + - "hooks": Standard hook-based computation (GradSampleModule) + - "functorch": Functorch-based computation (GradSampleModule with force_functorch=True) + - "ew": Expanded weights approach (GradSampleModuleExpandedWeights) + - "ghost": Ghost clipping with wrapping (GradSampleModuleFastGradientClipping) + - "ghost_fsdp": Ghost clipping with FSDP (GradSampleModuleFastGradientClippingFSDP) + - "no_op": No-op implementation (GradSampleModuleNoOp) + + Returns: + AbstractGradSampleModule subclass + + Raises: + ValueError: If grad_sample_mode is not recognized """ if grad_sample_mode in ["hooks", "functorch"]: return GradSampleModule @@ -112,5 +128,81 @@ def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]: else: raise ValueError( f"Unexpected grad_sample_mode: {grad_sample_mode}. " - f"Allowed values: hooks, ew" + f"Allowed values: hooks, functorch, ew, ghost, ghost_fsdp, no_op" + ) + + +def get_gsc_class(grad_sample_mode: str): + """ + Returns GradSampleController subclass corresponding to the input mode. + + This is used for the controller-based approach where hooks are attached + directly to the model without wrapping. + + See README for a detailed comparison between grad sample modes. + + Args: + grad_sample_mode: Mode for computing per-sample gradients. Supported values: + - "hooks": Standard hook-based computation (GradSampleController) + - "functorch": Functorch-based computation (GradSampleController with force_functorch=True) + - "ghost": Ghost clipping without wrapping (GradSampleControllerFastGradientClipping) + + Returns: + GradSampleController subclass + + Raises: + ValueError: If grad_sample_mode is not recognized or not supported by controllers + """ + if grad_sample_mode in ["hooks", "functorch"]: + return GradSampleController + elif grad_sample_mode == "ghost": + return GradSampleControllerFastGradientClipping + else: + raise ValueError( + f"Unexpected grad_sample_mode: {grad_sample_mode}. " + f"Controller-based approach supports: hooks, functorch, ghost" ) + + +def wrap_model( + model: nn.Module, + grad_sample_mode: str, + use_controller: bool = False, + *args, + **kwargs, +): + """ + Wraps a model for per-sample gradient computation. + + This is a unified interface that supports both wrapping-based and controller-based + approaches for computing per-sample gradients. + + Args: + model: PyTorch module to be wrapped or controlled + grad_sample_mode: Mode for computing per-sample gradients + use_controller: If True, uses controller-based approach (no wrapping). + If False (default), wraps model in GradSampleModule subclass. + *args: Additional positional arguments passed to the wrapper/controller + **kwargs: Additional keyword arguments passed to the wrapper/controller + + Returns: + Either: + - GradSampleModule subclass instance (if use_controller=False) + - GradSampleController instance (if use_controller=True) + + Notes: + - When use_controller=True, the original model is NOT wrapped and can be used + as-is. The controller manages hooks on the side. + - When use_controller=False, the model is wrapped and should be used via the + returned wrapper object. + """ + # Set force_functorch flag for functorch mode + if grad_sample_mode == "functorch": + kwargs["force_functorch"] = True + + if use_controller: + cls = get_gsc_class(grad_sample_mode) + else: + cls = get_gsm_class(grad_sample_mode) + + return cls(model, *args, **kwargs) diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py index eeea6b6a..f008719b 100644 --- a/opacus/privacy_engine.py +++ b/opacus/privacy_engine.py @@ -28,6 +28,7 @@ get_gsm_class, wrap_model, ) +from opacus.grad_sample.grad_sample_controller import GradSampleController from opacus.optimizers import DPOptimizer, get_optimizer_class from opacus.schedulers import _GradClipScheduler, _NoiseScheduler from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping @@ -146,6 +147,7 @@ def _prepare_data_loader( *, poisson_sampling: bool, distributed: bool, + batch_first: bool = True, ) -> DataLoader: if self.dataset is None: self.dataset = data_loader.dataset @@ -161,7 +163,9 @@ def _prepare_data_loader( if poisson_sampling: return DPDataLoader.from_data_loader( - data_loader, generator=self.secure_rng, distributed=distributed + data_loader, + generator=self.secure_rng, + distributed=distributed, ) elif self.secure_mode: return switch_generator(data_loader=data_loader, generator=self.secure_rng) @@ -176,47 +180,80 @@ def _prepare_model( max_grad_norm: Union[float, List[float]] = 1.0, loss_reduction: str = "mean", grad_sample_mode: str = "hooks", - ) -> AbstractGradSampleModule: - # Ideally, validation should have been taken care of by calling - # `get_compatible_module()` - self.validate(module=module, optimizer=None, data_loader=None) - - # wrap - if isinstance(module, AbstractGradSampleModule): - if ( - module.batch_first != batch_first - or module.loss_reduction != loss_reduction - or type(module) is not get_gsm_class(grad_sample_mode) - ): + strict: bool = False, + return_controller: bool = False, + use_ghost_clipping: bool = True, + ) -> Union[AbstractGradSampleModule, GradSampleController]: + """ + Prepares a model for differentially private training. + + Args: + module: PyTorch module to be prepared + batch_first: Flag to indicate if input has batch dimension first + max_grad_norm: Maximum gradient norm for clipping (required for ghost clipping) + loss_reduction: Loss reduction method ("mean" or "sum") + grad_sample_mode: Mode for computing per-sample gradients + strict: If True, validates module strictly + return_controller: If True, uses controller-based approach (no wrapping). + If False, wraps module in GradSampleModule. + use_ghost_clipping: If True and grad_sample_mode="ghost", uses ghost clipping + + Returns: + Either GradSampleModule instance (if return_controller=False) or + GradSampleController instance (if return_controller=True) + """ + # Validate module unless using controller-based approach + # (controller validates internally) + if not return_controller: + # Ideally, validation should have been taken care of by calling + # `get_compatible_module()` + self.validate(module=module, optimizer=None, data_loader=None) + + # Check if already wrapped + if isinstance(module, AbstractGradSampleModule): + if ( + module.batch_first != batch_first + or module.loss_reduction != loss_reduction + or type(module) is not get_gsm_class(grad_sample_mode) + ): + raise ValueError( + f"Pre-existing GradSampleModule doesn't match new arguments. " + f"Got: module.batch_first: {module.batch_first}, module.loss_reduction: {module.loss_reduction}, type(module): {type(module)} " + f"Requested: batch_first:{batch_first}, loss_reduction: {loss_reduction}, grad_sample_mode: {grad_sample_mode} " + f"Please pass vanilla nn.Module instead" + ) + return module + + # Prepare kwargs for wrapping/controller + kwargs = { + "batch_first": batch_first, + "loss_reduction": loss_reduction, + "strict": strict, + } + + # Add ghost clipping specific parameters + if grad_sample_mode in ["ghost", "ghost_fsdp"]: + if max_grad_norm is None: raise ValueError( - f"Pre-existing GradSampleModule doesn't match new arguments." - f"Got: module.batch_first: {module.batch_first}, module.loss_reduction: {module.loss_reduction}, type(module): {type(module)}" - f"Requested: batch_first:{batch_first}, loss_reduction: {loss_reduction}, grad_sample_mode: {grad_sample_mode} " - f"Please pass vanilla nn.Module instead" + "max_grad_norm must be provided when using ghost clipping mode" ) + kwargs["max_grad_norm"] = max_grad_norm + if return_controller: + # Only controllers have use_ghost_clipping parameter + kwargs["use_ghost_clipping"] = use_ghost_clipping - return module - else: - if grad_sample_mode in ["ghost", "ghost_fsdp"]: - return wrap_model( - module, - grad_sample_mode=grad_sample_mode, - batch_first=batch_first, - loss_reduction=loss_reduction, - max_grad_norm=max_grad_norm, - ) - else: - return wrap_model( - module, - grad_sample_mode=grad_sample_mode, - batch_first=batch_first, - loss_reduction=loss_reduction, - ) + # Use unified wrap_model function + return wrap_model( + module, + grad_sample_mode=grad_sample_mode, + use_controller=return_controller, + **kwargs, + ) def _prepare_criterion( self, *, - module: GradSampleModule, + controller_or_module: Union[GradSampleModule, GradSampleController], optimizer: DPOptimizer, criterion=nn.CrossEntropyLoss(), loss_reduction: str = "mean", @@ -224,14 +261,16 @@ def _prepare_criterion( ) -> DPLossFastGradientClipping: """ Args: - module: GradSampleModule used for training, + controller_or_module: GradSampleModule or GradSampleController used for training, optimizer: DPOptimizer used for training, criterion: Loss function used for training, loss_reduction: "mean" or "sum", indicates if the loss reduction (for aggregating the gradients) Prepare the DP loss class, which packages the two backward passes for fast gradient clipping. """ - return DPLossFastGradientClipping(module, optimizer, criterion, loss_reduction) + return DPLossFastGradientClipping( + controller_or_module, optimizer, criterion, loss_reduction + ) def is_compatible( self, @@ -309,10 +348,14 @@ def make_private( clipping: str = "flat", noise_generator=None, grad_sample_mode: str = "hooks", + strict: bool = True, + return_controller: bool = False, **kwargs, ) -> Union[ Tuple[GradSampleModule, DPOptimizer, DataLoader], Tuple[GradSampleModule, DPOptimizer, DPLossFastGradientClipping, DataLoader], + Tuple[nn.Module, DPOptimizer, DataLoader], + Tuple[nn.Module, DPOptimizer, DPLossFastGradientClipping, DataLoader], ]: """ Add privacy-related responsibilities to the main PyTorch training objects: @@ -321,7 +364,7 @@ def make_private( All of the returned objects act just like their non-private counterparts passed as arguments, but with added DP tasks. - - Model is wrapped to also compute per sample gradients. + - Model is wrapped to also compute per sample gradients (or hooks attached directly if return_controller=True). - Optimizer is now responsible for gradient clipping and adding noise to the gradients. - Criterion is a wrapper around the original criterion that packages the two backward passes for fast gradient clipping. - DataLoader is updated to perform Poisson sampling. @@ -362,14 +405,25 @@ def make_private( implementation class for the wrapped ``module``. See :class:`~opacus.grad_sample.gsm_base.AbstractGradSampleModule` for more details + strict: If True, will raise an error if the module is incompatible with + grad_sample_mode and will not attach hooks (only used when return_controller=True). + return_controller: If True, uses controller-based approach (no wrapping). + Returns the original unwrapped module with hooks attached via controller. + Controller is stored at module._opacus_controller for cleanup if needed. + If False (default), wraps module in GradSampleModule. + Recommended for HuggingFace transformers and models with custom __getattr__. Returns: - Tuple of (model, optimizer, data_loader) or (model, optimizer, criterion, data_loader). + If return_controller=False (default): + Tuple of (model, optimizer, data_loader) or (model, optimizer, criterion, data_loader). + Model is a GradSampleModule wrapper around the original model. + + If return_controller=True: + Tuple of (model, optimizer, data_loader) or (model, optimizer, criterion, data_loader). + Model is the UNWRAPPED original model with hooks attached directly. - Model is a wrapper around the original model that also computes per sample - gradients Optimizer is a wrapper around the original optimizer that also does - gradient clipping and noise addition to the gradients + gradient clipping and noise addition to the gradients Criterion is a wrapper around the original criterion that packages the two backward passes for fast gradient clipping. Only returned when grad_sample_mode is "ghost". DataLoader is a brand new DataLoader object, constructed to behave as @@ -391,18 +445,22 @@ def make_private( distributed = isinstance(module, (DPDDP, DDP, FSDPModule)) - module = self._prepare_model( + controller_or_module = self._prepare_model( module, batch_first=batch_first, max_grad_norm=max_grad_norm, loss_reduction=loss_reduction, grad_sample_mode=grad_sample_mode, + strict=strict, + return_controller=return_controller, ) if poisson_sampling: - module.forbid_grad_accumulation() + controller_or_module.forbid_grad_accumulation() data_loader = self._prepare_data_loader( - data_loader, distributed=distributed, poisson_sampling=poisson_sampling + data_loader, + distributed=distributed, + poisson_sampling=poisson_sampling, ) sample_rate = 1 / len(data_loader) @@ -429,18 +487,29 @@ def make_private( optimizer.attach_step_hook( self.accountant.get_optimizer_hook_fn(sample_rate=sample_rate) ) + if "ghost" in grad_sample_mode: criterion = self._prepare_criterion( - module=module, + controller_or_module=controller_or_module, optimizer=optimizer, criterion=criterion, loss_reduction=loss_reduction, **kwargs, ) - return module, optimizer, criterion, data_loader + if return_controller: + # Store controller reference on module for cleanup + module._opacus_controller = controller_or_module + return module, optimizer, criterion, data_loader + else: + return controller_or_module, optimizer, criterion, data_loader - return module, optimizer, data_loader + if return_controller: + # Store controller reference on module for cleanup + module._opacus_controller = controller_or_module + return module, optimizer, data_loader + else: + return controller_or_module, optimizer, data_loader def make_private_with_epsilon( self, @@ -459,10 +528,14 @@ def make_private_with_epsilon( clipping: str = "flat", noise_generator=None, grad_sample_mode: str = "hooks", + strict: bool = True, + return_controller: bool = False, **kwargs, ) -> Union[ Tuple[GradSampleModule, DPOptimizer, DataLoader], Tuple[GradSampleModule, DPOptimizer, DPLossFastGradientClipping, DataLoader], + Tuple[nn.Module, DPOptimizer, DataLoader], + Tuple[nn.Module, DPOptimizer, DPLossFastGradientClipping, DataLoader], ]: """ Version of :meth:`~opacus.privacy_engine.PrivacyEngine.make_private`, @@ -509,8 +582,10 @@ def make_private_with_epsilon( Returns: Tuple of (model, optimizer, data_loader) or (model, optimizer, criterion, data_loader). - Model is a wrapper around the original model that also computes per sample - gradients + If return_controller=True: + Tuple of (model, optimizer, data_loader) or (model, optimizer, criterion, data_loader). + Model is the UNWRAPPED original model with hooks attached directly. + Optimizer is a wrapper around the original optimizer that also does gradient clipping and noise addition to the gradients Criterion is a wrapper around the original criterion that packages the two backward passes for fast gradient clipping. @@ -548,6 +623,8 @@ def make_private_with_epsilon( grad_sample_mode=grad_sample_mode, poisson_sampling=poisson_sampling, clipping=clipping, + strict=strict, + return_controller=return_controller, **kwargs, ) @@ -567,7 +644,7 @@ def save_checkpoint( self, *, path: Union[str, os.PathLike, BinaryIO, IO[bytes]], - module: GradSampleModule, + module: Union[nn.Module, GradSampleModule], optimizer: Optional[DPOptimizer] = None, noise_scheduler: Optional[_NoiseScheduler] = None, grad_clip_scheduler: Optional[_GradClipScheduler] = None, @@ -579,7 +656,7 @@ def save_checkpoint( Saves the state_dict of module, optimizer, and accountant at path. Args: path: Path to save the state dict objects. - module: GradSampleModule to save; wrapped module's state_dict is saved. + module: Module to save (wrapped or unwrapped); module's state_dict is saved. optimizer: DPOptimizer to save; wrapped optimizer's state_dict is saved. noise_scheduler: _NoiseScheduler whose state we should save. grad_clip_scheduler: _GradClipScheduler whose state we should save. @@ -608,7 +685,7 @@ def load_checkpoint( self, *, path: Union[str, os.PathLike, BinaryIO, IO[bytes]], - module: GradSampleModule, + module: Union[nn.Module, GradSampleModule], optimizer: Optional[DPOptimizer] = None, noise_scheduler: Optional[_NoiseScheduler] = None, grad_clip_scheduler: Optional[_GradClipScheduler] = None, diff --git a/opacus/tests/adaptive_clipping_test.py b/opacus/tests/adaptive_clipping_test.py new file mode 100644 index 00000000..d179ca0a --- /dev/null +++ b/opacus/tests/adaptive_clipping_test.py @@ -0,0 +1,440 @@ +import unittest + +import torch +import torch.nn as nn +from opacus.optimizers.adaclipoptimizer import AdaClipDPOptimizer +from opacus.privacy_engine import PrivacyEngine +from torch.utils.data import DataLoader, TensorDataset + + +class SimpleNet(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(10, 20) + self.fc2 = nn.Linear(20, 5) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + return self.fc2(x) + + +class BaseAdaClipTest: + """Base test class for AdaClipDPOptimizer with different privacy engines.""" + + # Subclasses should set this + ENGINE_CLASS = None + USE_CONTROLLER = False # Set to True in controller-based subclass + + def setUp(self): + self.DATA_SIZE = 100 + self.BATCH_SIZE = 10 + self.LR = 0.1 + + # Create simple dataset + self.data = torch.randn(self.DATA_SIZE, 10) + self.labels = torch.randint(0, 5, (self.DATA_SIZE,)) + self.dataset = TensorDataset(self.data, self.labels) + self.dataloader = DataLoader( + self.dataset, batch_size=self.BATCH_SIZE, drop_last=False + ) + + def tearDown(self): + """Clean up controller if needed.""" + if hasattr(self, "controller") and self.controller is not None: + self.controller.cleanup() + + def _make_private(self, model, optimizer, **kwargs): + """ + Wrapper to handle both PrivacyEngine modes. + + Returns: (model, optimizer, dataloader, controller_or_none) + """ + privacy_engine = self.ENGINE_CLASS() + + # Use controller mode if specified + if self.USE_CONTROLLER: + # Controller-based mode with return_controller=True + # When return_controller=True, make_private returns (model, optimizer, dataloader) + # and stores the controller on model._opacus_controller + model, optimizer, dataloader = privacy_engine.make_private( + module=model, + optimizer=optimizer, + data_loader=self.dataloader, + return_controller=True, + **kwargs, + ) + # Extract controller from model + controller = model._opacus_controller + return model, optimizer, dataloader, controller + else: + # Standard wrapped mode + model, optimizer, dataloader = privacy_engine.make_private( + module=model, optimizer=optimizer, data_loader=self.dataloader, **kwargs + ) + return model, optimizer, dataloader, None + + def test_adaclip_optimizer_initialization(self): + """Test that AdaClipDPOptimizer can be initialized.""" + model = SimpleNet() + optimizer = torch.optim.SGD(model.parameters(), lr=self.LR) + + # Make private with AdaClip optimizer + # Note: noise_multiplier must be < 2 * unclipped_num_std (AdaClip constraint) + unclipped_num_std = 1.0 + model, optimizer, dataloader, self.controller = self._make_private( + model=model, + optimizer=optimizer, + noise_multiplier=0.5, # < 2 * unclipped_num_std (1.0) + max_grad_norm=1.0, + poisson_sampling=False, + clipping="adaptive", + target_unclipped_quantile=0.5, + clipbound_learning_rate=0.2, + max_clipbound=10.0, + min_clipbound=0.01, + unclipped_num_std=unclipped_num_std, + ) + + # Verify optimizer is AdaClipDPOptimizer + self.assertIsInstance(optimizer, AdaClipDPOptimizer) + + # Verify AdaClip-specific attributes exist + self.assertTrue(hasattr(optimizer, "target_unclipped_quantile")) + self.assertTrue(hasattr(optimizer, "clipbound_learning_rate")) + self.assertTrue(hasattr(optimizer, "max_clipbound")) + self.assertTrue(hasattr(optimizer, "min_clipbound")) + self.assertTrue(hasattr(optimizer, "unclipped_num")) + self.assertTrue(hasattr(optimizer, "sample_size")) + + def test_adaclip_clipbound_updates(self): + """Test that adaptive clipping actually updates the clipping bound.""" + model = SimpleNet() + optimizer = torch.optim.SGD(model.parameters(), lr=self.LR) + + model, optimizer, dataloader, self.controller = self._make_private( + model=model, + optimizer=optimizer, + noise_multiplier=0.0, # No noise for clearer results + max_grad_norm=1.0, # Initial clip bound + poisson_sampling=False, + clipping="adaptive", + target_unclipped_quantile=0.5, + clipbound_learning_rate=0.2, + max_clipbound=10.0, + min_clipbound=0.01, + unclipped_num_std=0.05, + ) + + criterion = nn.CrossEntropyLoss() + initial_clipbound = optimizer.max_grad_norm + clipbounds = [initial_clipbound] + + # Train for several steps and track clipbound changes + for i, (x, y) in enumerate(dataloader): + if i >= 5: + break + + optimizer.zero_grad() + output = model(x) + loss = criterion(output, y) + loss.backward() + optimizer.step() + + # Record clipbound after step + clipbounds.append(optimizer.max_grad_norm) + + # Verify that clipbound changed during training + unique_clipbounds = set(f"{cb:.6f}" for cb in clipbounds) + self.assertGreater( + len(unique_clipbounds), + 1, + f"Clipbound should change over time. Got values: {clipbounds}", + ) + + # Verify clipbound stays within bounds + for cb in clipbounds: + self.assertGreaterEqual(cb, 0.01) # min_clipbound + self.assertLessEqual(cb, 10.0) # max_clipbound + + def test_adaclip_unclipped_tracking(self): + """Test that AdaClip correctly tracks unclipped gradient counts.""" + model = SimpleNet() + optimizer = torch.optim.SGD(model.parameters(), lr=self.LR) + + unclipped_num_std = 1.0 + model, optimizer, dataloader, self.controller = self._make_private( + model=model, + optimizer=optimizer, + noise_multiplier=0.8, # < 2 * unclipped_num_std (0.8 < 1.0) + max_grad_norm=1.0, + poisson_sampling=False, + clipping="adaptive", + target_unclipped_quantile=0.5, + clipbound_learning_rate=0.2, + max_clipbound=10.0, + min_clipbound=0.01, + unclipped_num_std=unclipped_num_std, + ) + + criterion = nn.CrossEntropyLoss() + + # Train one step + x, y = next(iter(dataloader)) + optimizer.zero_grad() + output = model(x) + loss = criterion(output, y) + loss.backward() + + # Call optimizer.step() which triggers clip_and_accumulate + # that sets sample_size and unclipped_num + optimizer.step() + + # After step, clipbound should have been updated + self.assertIsNotNone(optimizer.max_grad_norm) + + # Verify sample_size is positive + sample_size = ( + float(optimizer.sample_size) + if torch.is_tensor(optimizer.sample_size) + else optimizer.sample_size + ) + self.assertGreater( + sample_size, 0, "Sample size should be positive after training step" + ) + + # Compute unclipped fraction (convert to float if tensor) + # Note: unclipped_num can be negative due to DP noise in AdaClip + unclipped_num = ( + float(optimizer.unclipped_num) + if torch.is_tensor(optimizer.unclipped_num) + else optimizer.unclipped_num + ) + unclipped_frac = unclipped_num / sample_size + # Due to DP noise, unclipped_frac may be slightly outside [0, 1] + # Just verify it's been set and is a reasonable value + self.assertGreater( + unclipped_frac, -0.5, "Unclipped fraction should not be too negative" + ) + self.assertLess( + unclipped_frac, 1.5, "Unclipped fraction should not be too large" + ) + + # Do another step to verify counters work across multiple steps + x, y = next(iter(dataloader)) + optimizer.zero_grad() + output = model(x) + loss = criterion(output, y) + loss.backward() + optimizer.step() + + # After another step, counters should still be valid + sample_size2 = ( + float(optimizer.sample_size) + if torch.is_tensor(optimizer.sample_size) + else optimizer.sample_size + ) + self.assertGreater(sample_size2, 0) + + def test_adaclip_convergence_behavior(self): + """Test that AdaClip converges toward target quantile.""" + torch.manual_seed(42) + model = SimpleNet() + optimizer = torch.optim.SGD(model.parameters(), lr=self.LR) + + target_quantile = 0.7 + unclipped_num_std = 0.5 + model, optimizer, dataloader, self.controller = self._make_private( + model=model, + optimizer=optimizer, + noise_multiplier=0.8, # < 2 * unclipped_num_std (0.8 < 1.0) + max_grad_norm=1.0, + poisson_sampling=False, + clipping="adaptive", + target_unclipped_quantile=target_quantile, + clipbound_learning_rate=0.1, + max_clipbound=10.0, + min_clipbound=0.01, + unclipped_num_std=unclipped_num_std, + ) + + criterion = nn.CrossEntropyLoss() + unclipped_fractions = [] + + # Train for multiple steps + for i, (x, y) in enumerate(dataloader): + if i >= 10: + break + + optimizer.zero_grad() + output = model(x) + loss = criterion(output, y) + loss.backward() + + # Record unclipped fraction before step + if optimizer.sample_size > 0: + unclipped_frac = float(optimizer.unclipped_num) / optimizer.sample_size + unclipped_fractions.append(unclipped_frac) + + optimizer.step() + + # Average unclipped fraction should be reasonably close to target + # (not exact due to noise and limited steps) + if len(unclipped_fractions) > 5: + avg_unclipped = sum(unclipped_fractions[-5:]) / 5 + # Should be within reasonable range of target + self.assertGreater(avg_unclipped, target_quantile - 0.3) + self.assertLess(avg_unclipped, target_quantile + 0.3) + + def test_adaclip_vs_fixed_clipping(self): + """Test that AdaClip behaves differently from fixed clipping.""" + torch.manual_seed(42) + + # Train with AdaClip + model1 = SimpleNet() + optimizer1 = torch.optim.SGD(model1.parameters(), lr=self.LR) + + model1, optimizer1, dataloader1, controller1 = self._make_private( + model=model1, + optimizer=optimizer1, + noise_multiplier=0.0, + max_grad_norm=1.0, + poisson_sampling=False, + clipping="adaptive", + target_unclipped_quantile=0.5, + clipbound_learning_rate=0.2, + max_clipbound=10.0, + min_clipbound=0.01, + unclipped_num_std=0.05, + ) + + # Train with fixed clipping + torch.manual_seed(42) + model2 = SimpleNet() + # Handle both wrapped (GradSampleModule) and unwrapped models + state_dict1 = model1.state_dict() + # If wrapped, state_dict has _module. prefix, need to remove it + if any(key.startswith("_module.") for key in state_dict1.keys()): + state_dict1 = {k.replace("_module.", ""): v for k, v in state_dict1.items()} + model2.load_state_dict(state_dict1) + optimizer2 = torch.optim.SGD(model2.parameters(), lr=self.LR) + + model2, optimizer2, dataloader2, controller2 = self._make_private( + model=model2, + optimizer=optimizer2, + noise_multiplier=0.0, + max_grad_norm=1.0, + poisson_sampling=False, + clipping="flat", # Fixed clipping + ) + + criterion = nn.CrossEntropyLoss() + + # Train both for several steps + for i, ((x1, y1), (x2, y2)) in enumerate(zip(dataloader1, dataloader2)): + if i >= 5: + break + + # AdaClip training + optimizer1.zero_grad() + output1 = model1(x1) + loss1 = criterion(output1, y1) + loss1.backward() + optimizer1.step() + + # Fixed clipping training + optimizer2.zero_grad() + output2 = model2(x2) + loss2 = criterion(output2, y2) + loss2.backward() + optimizer2.step() + + # After training, parameters should differ + # (because AdaClip adjusts clipbound while fixed doesn't) + params_differ = False + for p1, p2 in zip(model1.parameters(), model2.parameters()): + if not torch.allclose(p1, p2, atol=1e-5): + params_differ = True + break + + self.assertTrue( + params_differ, "AdaClip and fixed clipping should produce different results" + ) + + # Cleanup both controllers if they exist + if controller1: + controller1.cleanup() + if controller2: + controller2.cleanup() + # Mark as cleaned up so tearDown doesn't try again + self.controller = None + + def test_adaclip_parameter_validation(self): + """Test that AdaClip validates parameters correctly.""" + model = SimpleNet() + optimizer = torch.optim.SGD(model.parameters(), lr=self.LR) + + # Test: max_clipbound <= min_clipbound should raise error + with self.assertRaises(ValueError): + self._make_private( + model=model, + optimizer=optimizer, + noise_multiplier=0.05, # < 2 * unclipped_num_std + max_grad_norm=1.0, + clipping="adaptive", + target_unclipped_quantile=0.5, # Required param + clipbound_learning_rate=0.2, # Required param + max_clipbound=0.01, # Less than min - should trigger error + min_clipbound=0.1, + unclipped_num_std=0.05, + ) + + def test_adaclip_with_nonzero_noise(self): + """Test AdaClip works with noise (full DP training).""" + model = SimpleNet() + optimizer = torch.optim.SGD(model.parameters(), lr=self.LR) + + unclipped_num_std = 0.5 + model, optimizer, dataloader, self.controller = self._make_private( + model=model, + optimizer=optimizer, + noise_multiplier=0.8, # With noise, < 2 * unclipped_num_std + max_grad_norm=1.0, + poisson_sampling=False, + clipping="adaptive", + target_unclipped_quantile=0.5, + clipbound_learning_rate=0.2, + max_clipbound=10.0, + min_clipbound=0.01, + unclipped_num_std=unclipped_num_std, + ) + + criterion = nn.CrossEntropyLoss() + + # Train one step with noise + x, y = next(iter(dataloader)) + optimizer.zero_grad() + output = model(x) + loss = criterion(output, y) + loss.backward() + optimizer.step() + + # Verify training completed successfully + for param in model.parameters(): + self.assertIsNotNone(param.grad) + + +class AdaClipStandardEngineTest(BaseAdaClipTest, unittest.TestCase): + """Test AdaClipDPOptimizer with standard PrivacyEngine.""" + + ENGINE_CLASS = PrivacyEngine + + +class AdaClipGradSampleControllerEngineTest(BaseAdaClipTest, unittest.TestCase): + """Test AdaClipDPOptimizer with GradSampleController-based PrivacyEngine.""" + + ENGINE_CLASS = PrivacyEngine + USE_CONTROLLER = True # Use controller mode + + +if __name__ == "__main__": + unittest.main() diff --git a/opacus/tests/grad_sample_controller_fast_gradient_clipping_test.py b/opacus/tests/grad_sample_controller_fast_gradient_clipping_test.py new file mode 100644 index 00000000..8c768da4 --- /dev/null +++ b/opacus/tests/grad_sample_controller_fast_gradient_clipping_test.py @@ -0,0 +1,384 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +import torch.nn as nn +from opacus.grad_sample.grad_sample_controller_fast_gradient_clipping import ( + GradSampleControllerFastGradientClipping, +) +from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import ( + GradSampleModuleFastGradientClipping, +) + + +class SimpleModel(nn.Module): + """Simple model for testing""" + + def __init__(self, input_dim=10, hidden_dim=20, output_dim=5): + super().__init__() + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, output_dim) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc1(x)) + x = self.fc2(x) + return x + + +class GradSampleControllerFastGradientClippingTest(unittest.TestCase): + def setUp(self): + self.batch_size = 4 + self.input_dim = 10 + self.hidden_dim = 20 + self.output_dim = 5 + self.max_grad_norm = 1.0 + self.loss_reduction = "mean" + + def test_controller_creation(self): + """Test that controller can be created without wrapping model""" + model = SimpleModel(self.input_dim, self.hidden_dim, self.output_dim) + original_type = type(model) + + controller = GradSampleControllerFastGradientClipping( + model, + batch_first=True, + loss_reduction=self.loss_reduction, + max_grad_norm=self.max_grad_norm, + use_ghost_clipping=True, + ) + + # Model type should be preserved + self.assertEqual(type(model), original_type) + self.assertIsInstance(model, SimpleModel) + + # Controller should have hooks + self.assertTrue(len(controller.autograd_grad_sample_hooks) > 0) + self.assertTrue(controller.hooks_enabled) + + # Clean up + controller.cleanup() + + def test_norm_sample_computation(self): + """Test that norm samples are computed correctly""" + model = SimpleModel(self.input_dim, self.hidden_dim, self.output_dim) + + controller = GradSampleControllerFastGradientClipping( + model, + batch_first=True, + loss_reduction=self.loss_reduction, + max_grad_norm=self.max_grad_norm, + use_ghost_clipping=False, # Use fast gradient clipping for testing + ) + + # Create dummy input and target + x = torch.randn(self.batch_size, self.input_dim) + target = torch.randint(0, self.output_dim, (self.batch_size,)) + + # Forward and backward pass + model.train() + output = model(x) + loss = nn.functional.cross_entropy( + output, target, reduction=self.loss_reduction + ) + loss.backward() + + # Check that norm samples are computed + for param in controller.trainable_parameters: + if param.requires_grad: + self.assertIsNotNone(param._norm_sample) + self.assertEqual(param._norm_sample.shape[0], self.batch_size) + + # Clean up + controller.cleanup() + + def test_clipping_coefficient(self): + """Test that clipping coefficients are computed correctly""" + model = SimpleModel(self.input_dim, self.hidden_dim, self.output_dim) + + controller = GradSampleControllerFastGradientClipping( + model, + batch_first=True, + loss_reduction=self.loss_reduction, + max_grad_norm=self.max_grad_norm, + use_ghost_clipping=False, + ) + + # Create dummy input and target + x = torch.randn(self.batch_size, self.input_dim) + target = torch.randint(0, self.output_dim, (self.batch_size,)) + + # Forward and backward pass + model.train() + output = model(x) + loss = nn.functional.cross_entropy( + output, target, reduction=self.loss_reduction + ) + loss.backward() + + # Get clipping coefficient + coeff = controller.get_clipping_coef() + + # Coefficients should be between 0 and 1 + self.assertTrue(torch.all(coeff >= 0)) + self.assertTrue(torch.all(coeff <= 1)) + self.assertEqual(coeff.shape[0], self.batch_size) + + # Clean up + controller.cleanup() + + def test_hooks_enable_disable(self): + """Test that hooks can be enabled and disabled""" + model = SimpleModel(self.input_dim, self.hidden_dim, self.output_dim) + + controller = GradSampleControllerFastGradientClipping( + model, + batch_first=True, + loss_reduction=self.loss_reduction, + max_grad_norm=self.max_grad_norm, + ) + + # Hooks should be enabled by default + self.assertTrue(controller.hooks_enabled) + + # Disable hooks + controller.disable_hooks() + self.assertFalse(controller.hooks_enabled) + + # Enable hooks + controller.enable_hooks() + self.assertTrue(controller.hooks_enabled) + + # Clean up + controller.cleanup() + + def test_cleanup(self): + """Test that cleanup removes all hooks and attributes""" + model = SimpleModel(self.input_dim, self.hidden_dim, self.output_dim) + + controller = GradSampleControllerFastGradientClipping( + model, + batch_first=True, + loss_reduction=self.loss_reduction, + max_grad_norm=self.max_grad_norm, + ) + + # Verify hooks and attributes exist + self.assertTrue(len(controller.autograd_grad_sample_hooks) > 0) + for param in model.parameters(): + self.assertTrue(hasattr(param, "_forward_counter")) + + # Cleanup + controller.cleanup() + + # Verify hooks are removed + self.assertEqual(len(controller.autograd_grad_sample_hooks), 0) + + # Verify attributes are removed + for param in model.parameters(): + self.assertFalse(hasattr(param, "grad_sample")) + self.assertFalse(hasattr(param, "_forward_counter")) + self.assertFalse(hasattr(param, "_norm_sample")) + + def test_controller_vs_wrapped_equivalence(self): + """Test that controller produces same norms as wrapped module""" + # Create two identical models + model_controller = SimpleModel(self.input_dim, self.hidden_dim, self.output_dim) + model_wrapped = SimpleModel(self.input_dim, self.hidden_dim, self.output_dim) + + # Copy weights + model_wrapped.load_state_dict(model_controller.state_dict()) + + # Create controller and wrapped module + controller = GradSampleControllerFastGradientClipping( + model_controller, + batch_first=True, + loss_reduction=self.loss_reduction, + max_grad_norm=self.max_grad_norm, + use_ghost_clipping=False, # Use FGC for comparison + ) + + wrapped = GradSampleModuleFastGradientClipping( + model_wrapped, + batch_first=True, + loss_reduction=self.loss_reduction, + max_grad_norm=self.max_grad_norm, + use_ghost_clipping=False, + ) + + # Create dummy input and target + torch.manual_seed(42) + x = torch.randn(self.batch_size, self.input_dim) + target = torch.randint(0, self.output_dim, (self.batch_size,)) + + # Forward and backward pass for controller + model_controller.train() + output_controller = model_controller(x.clone()) + loss_controller = nn.functional.cross_entropy( + output_controller, target, reduction=self.loss_reduction + ) + loss_controller.backward() + + # Forward and backward pass for wrapped + wrapped.train() + output_wrapped = wrapped(x.clone()) + loss_wrapped = nn.functional.cross_entropy( + output_wrapped, target, reduction=self.loss_reduction + ) + loss_wrapped.backward() + + # Get norms + norm_controller = controller.get_norm_sample() + norm_wrapped = wrapped.get_norm_sample() + + # Norms should be very close + self.assertTrue( + torch.allclose(norm_controller, norm_wrapped, rtol=1e-4, atol=1e-4) + ) + + # Clean up + controller.cleanup() + + def test_isinstance_preserved(self): + """Test that isinstance checks work after controller attachment""" + model = SimpleModel(self.input_dim, self.hidden_dim, self.output_dim) + + # Before controller + self.assertIsInstance(model, SimpleModel) + self.assertIsInstance(model, nn.Module) + + # Create controller + controller = GradSampleControllerFastGradientClipping( + model, + batch_first=True, + loss_reduction=self.loss_reduction, + max_grad_norm=self.max_grad_norm, + ) + + # After controller - isinstance should still work + self.assertIsInstance(model, SimpleModel) + self.assertIsInstance(model, nn.Module) + + # Clean up + controller.cleanup() + + def test_dp_tensor_arithmetic_operations(self): + """Test that DPTensorFastGradientClipping supports arithmetic operations""" + from opacus.optimizers import DPOptimizerFastGradientClipping + from opacus.utils.fast_gradient_clipping_utils import ( + DPTensorFastGradientClipping, + ) + + model = SimpleModel(self.input_dim, self.hidden_dim, self.output_dim) + optimizer = torch.optim.SGD(model.parameters(), lr=0.001) + + controller = GradSampleControllerFastGradientClipping( + model, + batch_first=True, + loss_reduction=self.loss_reduction, + max_grad_norm=self.max_grad_norm, + ) + + dp_optimizer = DPOptimizerFastGradientClipping( + optimizer=optimizer, + noise_multiplier=1.0, + max_grad_norm=self.max_grad_norm, + expected_batch_size=self.batch_size, + loss_reduction=self.loss_reduction, + ) + + loss_per_sample = torch.randn(self.batch_size) + dp_loss = DPTensorFastGradientClipping( + controller, dp_optimizer, loss_per_sample, self.loss_reduction + ) + + # Test division + divided_loss = dp_loss / 2.0 + self.assertIsInstance(divided_loss, DPTensorFastGradientClipping) + self.assertTrue( + torch.allclose(divided_loss.loss_per_sample, loss_per_sample / 2.0) + ) + + # Test multiplication + multiplied_loss = dp_loss * 3.0 + self.assertIsInstance(multiplied_loss, DPTensorFastGradientClipping) + self.assertTrue( + torch.allclose(multiplied_loss.loss_per_sample, loss_per_sample * 3.0) + ) + + # Test right multiplication + rmultiplied_loss = 3.0 * dp_loss + self.assertIsInstance(rmultiplied_loss, DPTensorFastGradientClipping) + self.assertTrue( + torch.allclose(rmultiplied_loss.loss_per_sample, 3.0 * loss_per_sample) + ) + + # Test addition with scalar + added_loss = dp_loss + 1.0 + self.assertIsInstance(added_loss, DPTensorFastGradientClipping) + self.assertTrue( + torch.allclose(added_loss.loss_per_sample, loss_per_sample + 1.0) + ) + + # Test addition with another DPTensor + loss_per_sample2 = torch.randn(self.batch_size) + dp_loss2 = DPTensorFastGradientClipping( + controller, dp_optimizer, loss_per_sample2, self.loss_reduction + ) + summed_loss = dp_loss + dp_loss2 + self.assertIsInstance(summed_loss, DPTensorFastGradientClipping) + self.assertTrue( + torch.allclose( + summed_loss.loss_per_sample, loss_per_sample + loss_per_sample2 + ) + ) + + # Test subtraction + subtracted_loss = dp_loss - 0.5 + self.assertIsInstance(subtracted_loss, DPTensorFastGradientClipping) + self.assertTrue( + torch.allclose(subtracted_loss.loss_per_sample, loss_per_sample - 0.5) + ) + + # Test right subtraction + rsubtracted_loss = 1.0 - dp_loss + self.assertIsInstance(rsubtracted_loss, DPTensorFastGradientClipping) + self.assertTrue( + torch.allclose(rsubtracted_loss.loss_per_sample, 1.0 - loss_per_sample) + ) + + # Test negation + negated_loss = -dp_loss + self.assertIsInstance(negated_loss, DPTensorFastGradientClipping) + self.assertTrue(torch.allclose(negated_loss.loss_per_sample, -loss_per_sample)) + + # Test item() + item_value = dp_loss.item() + self.assertIsInstance(item_value, float) + + # Test string representations + repr_str = repr(dp_loss) + self.assertIn("DPTensorFastGradientClipping", repr_str) + str_str = str(dp_loss) + self.assertIn("DPTensorFastGradientClipping", str_str) + + # Clean up + controller.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/opacus/tests/grad_sample_controller_test.py b/opacus/tests/grad_sample_controller_test.py new file mode 100644 index 00000000..85ad25e3 --- /dev/null +++ b/opacus/tests/grad_sample_controller_test.py @@ -0,0 +1,437 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +import torch.nn as nn +import torch.nn.functional as F +from opacus.grad_sample import GradSampleController, GradSampleModule +from opacus.grad_sample.grad_sample_controller_fast_gradient_clipping import ( + GradSampleControllerFastGradientClipping, +) +from opacus.grad_sample.linear import compute_linear_grad_sample +from opacus.grad_sample.utils import register_grad_sampler +from torch.testing import assert_close +from torch.utils.data import DataLoader +from torchvision import transforms +from torchvision.datasets import FakeData +from torchvision.models import mobilenet_v3_small + + +class SampleConvNet(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 16, 8, 2, padding=3) + self.conv2 = nn.Conv2d(16, 32, 4, 2) + self.fc1 = nn.Linear(32 * 4 * 4, 32) + self.fc2 = nn.Linear(32, 10) + + def forward(self, x): + # x of shape [B, 3, 28, 28] + x = F.relu(self.conv1(x)) # -> [B, 16, 14, 14] + x = F.max_pool2d(x, 2, 1) # -> [B, 16, 13, 13] + x = F.relu(self.conv2(x)) # -> [B, 32, 5, 5] + x = F.max_pool2d(x, 2, 1) # -> [B, 32, 4, 4] + x = x.view(-1, 32 * 4 * 4) # -> [B, 512] + x = F.relu(self.fc1(x)) # -> [B, 32] + x = self.fc2(x) # -> [B, 10] + return x + + def name(self): + return "SampleConvNet" + + +class GradSampleControllerTest(unittest.TestCase): + """Test GradSampleController - controller-based approach without model wrapping.""" + + CLS = GradSampleController + + def setUp(self): + self.original_model = SampleConvNet() + self.controller_model = SampleConvNet() + self.controller_model.load_state_dict( + self.original_model.state_dict(), strict=True + ) + + self.grad_sample_controller = self.CLS( + self.controller_model, batch_first=True, loss_reduction="mean" + ) + self.DATA_SIZE = 8 + self.setUp_data() + self.criterion = nn.L1Loss() + + def setUp_data(self): + self.ds = FakeData( + size=self.DATA_SIZE, + image_size=(3, 28, 28), + num_classes=10, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ), + ) + self.dl = DataLoader(self.ds, batch_size=self.DATA_SIZE) + + def tearDown(self): + """Clean up controller hooks after each test.""" + if hasattr(self, "grad_sample_controller"): + self.grad_sample_controller.cleanup() + + def test_outputs_unaltered(self): + """ + Test that controller won't alter any outputs. + Model should behave identically with or without controller. + """ + x, _ = next(iter(self.dl)) + self.original_model = self.original_model.eval() + self.controller_model = self.controller_model.eval() + with torch.no_grad(): + normal_out = self.original_model(x) + controller_out = self.controller_model(x) + msg = ( + f"Controller L2 norm = : {controller_out.norm(2)}, ", + f"Original L2 norm = : {normal_out.norm(2)}, ", + f"MSE = {F.mse_loss(controller_out, normal_out)}, ", + f"L1 Loss = {F.l1_loss(controller_out, normal_out)}", + ) + assert_close(controller_out, normal_out, atol=1e-7, rtol=1e-5, msg=msg) + + def test_zero_grad(self): + """Test that zero_grad properly clears grad_sample attributes.""" + x, _ = next(iter(self.dl)) + self.controller_model = self.controller_model.train() + controller_out = self.controller_model(x) + loss = self.criterion(controller_out, torch.zeros_like(controller_out)) + loss.backward() + + # Clear grad_sample manually (controllers don't have zero_grad method) + for p in self.controller_model.parameters(): + if hasattr(p, "grad_sample"): + p.grad_sample = None + + params_with_gs = [ + n + for n, p in self.controller_model.named_parameters() + if hasattr(p, "grad_sample") and p.grad_sample is not None + ] + msg = ( + "After clearing grad_sample, the following parameters still " + f"have a grad_sample: {params_with_gs}" + ) + assert len(params_with_gs) == 0, msg + + def test_model_not_wrapped(self): + """Test that model is NOT wrapped - maintains original type.""" + # Model should be the original type, not wrapped + assert isinstance(self.controller_model, SampleConvNet) + assert not isinstance(self.controller_model, GradSampleModule) + assert type(self.controller_model).__name__ == "SampleConvNet" + + def test_remove_hooks(self): + """ + Test that after calling .remove_hooks() no hooks are left + """ + copy_of_original_model = SampleConvNet() + copy_of_original_model.load_state_dict( + self.original_model.state_dict(), + strict=True, + ) + new_grad_sample_controller = GradSampleController( + copy_of_original_model, batch_first=True, loss_reduction="mean" + ) + new_grad_sample_controller.remove_hooks() + + remaining_forward_hooks = { + module: module._forward_hooks + for module in copy_of_original_model.modules() + if module._forward_hooks + } + assert ( + not remaining_forward_hooks + ), f"Some forward hooks remain after .remove_hooks(): {remaining_forward_hooks}" + + remaining_backward_hooks = { + module: module._backward_hooks + for module in copy_of_original_model.modules() + if module._backward_hooks + } + assert ( + not remaining_backward_hooks + ), f"Some backward hooks remain after .remove_hooks(): {remaining_backward_hooks}" + + # Cleanup + new_grad_sample_controller.cleanup() + + def test_enable_hooks(self): + """Test that hooks can be enabled.""" + self.grad_sample_controller.enable_hooks() + assert self.grad_sample_controller.hooks_enabled + + def test_disable_hooks(self): + """Test that hooks can be disabled.""" + self.grad_sample_controller.disable_hooks() + assert not self.grad_sample_controller.hooks_enabled + + def test_standard_module_validation(self): + """Test validation behavior for standard modules.""" + + class SimpleLinear(nn.Module): + def __init__(self, in_f, out_f): + super().__init__() + self.p = nn.Parameter(torch.Tensor(in_f, out_f)) + + def forward(self, x: torch.Tensor): + return F.linear(x, self.p) + + # Should be handled by functorch + try: + controller = GradSampleController(SimpleLinear(4, 2)) + # Check that functorch is used for this module + self.assertTrue(hasattr(controller.module, "ft_compute_sample_grad")) + controller.cleanup() + except ImportError: + print("Test could not be ran because functorch not available") + + # Should not raise exception if strict=False + try: + controller = GradSampleController(SimpleLinear(4, 2), strict=False) + controller.cleanup() + except ImportError: + print("Test could not be ran because functorch not available") + + # Should not fail after relevant grad sampler has been registered + register_grad_sampler(SimpleLinear)(compute_linear_grad_sample) + controller = GradSampleController(SimpleLinear(4, 2)) + controller.cleanup() + + def test_custom_module_validation(self) -> None: + """Test that unsupported modules raise appropriate errors.""" + from opacus.validators.errors import UnsupportedModuleError + + with self.assertRaises(UnsupportedModuleError): + controller = GradSampleController(mobilenet_v3_small()) + controller.cleanup() + + def test_submodule_access(self) -> None: + """Test that submodules can be accessed directly (no wrapping).""" + # Direct access to submodules - no _module prefix needed + _ = self.controller_model.fc1 + _ = self.controller_model.fc2 + + with self.assertRaises(AttributeError): + _ = self.controller_model.fc3 + + def test_state_dict(self) -> None: + """Test that state_dict has no _module prefix (not wrapped).""" + controller_state_dict = self.controller_model.state_dict() + og_state_dict = self.original_model.state_dict() + + # Controller approach: state dict keys should match exactly (no _module prefix) + self.assertEqual(set(controller_state_dict.keys()), set(og_state_dict.keys())) + + for key in og_state_dict.keys(): + # Keys should be identical, no _module prefix + self.assertTrue(key in controller_state_dict) + assert_close(og_state_dict[key], controller_state_dict[key]) + + def test_load_state_dict(self) -> None: + """Test that state_dict can be loaded without _module prefix.""" + controller_state_dict = self.controller_model.state_dict() + new_model = SampleConvNet() + new_controller = GradSampleController( + new_model, batch_first=False, loss_reduction="mean" + ) + + # Should be able to load directly (no _module prefix) + new_model.load_state_dict(controller_state_dict) + + # Models should match + for key in self.original_model.state_dict().keys(): + self.assertTrue(key in new_model.state_dict()) + assert_close( + self.original_model.state_dict()[key], new_model.state_dict()[key] + ) + + new_controller.cleanup() + + def test_grad_sample_computation(self): + """Test that per-sample gradients are computed correctly.""" + x, _ = next(iter(self.dl)) + self.controller_model.train() + controller_out = self.controller_model(x) + loss = self.criterion(controller_out, torch.zeros_like(controller_out)) + loss.backward() + + # Check that grad_sample was computed for all trainable parameters + for name, param in self.controller_model.named_parameters(): + if param.requires_grad: + self.assertTrue( + hasattr(param, "grad_sample"), + f"Parameter {name} should have grad_sample", + ) + self.assertIsNotNone( + param.grad_sample, + f"Parameter {name} grad_sample should not be None", + ) + # grad_sample should have batch dimension + self.assertEqual( + param.grad_sample.shape[0], + self.DATA_SIZE, + f"Parameter {name} grad_sample batch dimension mismatch", + ) + + def test_cleanup(self): + """Test that cleanup removes all hooks and attributes.""" + x, _ = next(iter(self.dl)) + self.controller_model.train() + controller_out = self.controller_model(x) + loss = self.criterion(controller_out, torch.zeros_like(controller_out)) + loss.backward() + + # Verify grad_sample exists + for param in self.controller_model.parameters(): + if param.requires_grad: + self.assertTrue(hasattr(param, "grad_sample")) + + # Cleanup + self.grad_sample_controller.cleanup() + + # Verify attributes are removed + for param in self.controller_model.parameters(): + self.assertFalse( + hasattr(param, "grad_sample"), "grad_sample should be removed" + ) + self.assertFalse( + hasattr(param, "_forward_counter"), + "_forward_counter should be removed", + ) + + # Verify hooks are removed + remaining_forward_hooks = { + module: module._forward_hooks + for module in self.controller_model.modules() + if module._forward_hooks + } + self.assertFalse(remaining_forward_hooks, "All forward hooks should be removed") + + remaining_backward_hooks = { + module: module._backward_hooks + for module in self.controller_model.modules() + if module._backward_hooks + } + self.assertFalse( + remaining_backward_hooks, "All backward hooks should be removed" + ) + + def test_isinstance_preserved(self): + """Test that isinstance checks work correctly with controller (no wrapping).""" + # Model should still be instance of original class + self.assertIsInstance(self.controller_model, SampleConvNet) + self.assertIsInstance(self.controller_model, nn.Module) + + # Should NOT be instance of GradSampleModule + self.assertNotIsInstance(self.controller_model, GradSampleModule) + + +class GradSampleControllerFastGradientClippingTestUnit(GradSampleControllerTest): + """Test GradSampleControllerFastGradientClipping - controller with ghost clipping.""" + + CLS = GradSampleControllerFastGradientClipping + + def setUp(self): + """Set up with ghost clipping controller.""" + self.original_model = SampleConvNet() + self.controller_model = SampleConvNet() + self.controller_model.load_state_dict( + self.original_model.state_dict(), strict=True + ) + + # Ghost clipping requires max_grad_norm + self.grad_sample_controller = self.CLS( + self.controller_model, + batch_first=True, + loss_reduction="mean", + max_grad_norm=1.0, + use_ghost_clipping=False, # Use fast gradient clipping for these tests + ) + self.DATA_SIZE = 8 + self.setUp_data() + self.criterion = nn.L1Loss() + + def test_norm_sample_computation(self): + """Test that norm samples are computed correctly.""" + x, _ = next(iter(self.dl)) + self.controller_model.train() + controller_out = self.controller_model(x) + loss = self.criterion(controller_out, torch.zeros_like(controller_out)) + loss.backward() + + # Check that norm samples are computed + for name, param in self.controller_model.named_parameters(): + if param.requires_grad: + self.assertTrue( + hasattr(param, "_norm_sample"), + f"Parameter {name} should have _norm_sample", + ) + self.assertIsNotNone( + param._norm_sample, + f"Parameter {name} _norm_sample should not be None", + ) + # _norm_sample should have batch dimension + self.assertEqual( + param._norm_sample.shape[0], + self.DATA_SIZE, + f"Parameter {name} _norm_sample batch dimension mismatch", + ) + + def test_get_norm_sample(self): + """Test that get_norm_sample returns correct per-example norms.""" + x, _ = next(iter(self.dl)) + self.controller_model.train() + controller_out = self.controller_model(x) + loss = self.criterion(controller_out, torch.zeros_like(controller_out)) + loss.backward() + + # Get norms + norms = self.grad_sample_controller.get_norm_sample() + + # Check shape and values + self.assertEqual(norms.shape[0], self.DATA_SIZE) + self.assertTrue(torch.all(norms >= 0), "Norms should be non-negative") + + def test_get_clipping_coef(self): + """Test that clipping coefficients are computed correctly.""" + x, _ = next(iter(self.dl)) + self.controller_model.train() + controller_out = self.controller_model(x) + loss = self.criterion(controller_out, torch.zeros_like(controller_out)) + loss.backward() + + # Get clipping coefficients + coeff = self.grad_sample_controller.get_clipping_coef() + + # Check shape and values + self.assertEqual(coeff.shape[0], self.DATA_SIZE) + self.assertTrue(torch.all(coeff >= 0), "Coefficients should be non-negative") + self.assertTrue(torch.all(coeff <= 1), "Coefficients should be <= 1") + + +if __name__ == "__main__": + unittest.main() diff --git a/opacus/tests/grad_sample_module_test.py b/opacus/tests/grad_sample_module_test.py index 85c4ead4..eea6a042 100644 --- a/opacus/tests/grad_sample_module_test.py +++ b/opacus/tests/grad_sample_module_test.py @@ -226,7 +226,9 @@ def forward(self, x: torch.Tensor): GradSampleModule(SimpleLinear(4, 2)) def test_custom_module_validation(self) -> None: - with self.assertRaises(NotImplementedError): + from opacus.validators.errors import UnsupportedModuleError + + with self.assertRaises(UnsupportedModuleError): GradSampleModule(mobilenet_v3_small()) def test_submodule_access(self) -> None: diff --git a/opacus/tests/privacy_engine_test.py b/opacus/tests/privacy_engine_test.py index 5898d9ea..4217263d 100644 --- a/opacus/tests/privacy_engine_test.py +++ b/opacus/tests/privacy_engine_test.py @@ -69,6 +69,7 @@ def setUp(self): self.criterion = nn.CrossEntropyLoss() self.BATCH_FIRST = True self.GRAD_SAMPLE_MODE = "hooks" + self.RETURN_CONTROLLER = False # Override in subclasses for controller mode torch.manual_seed(42) @@ -143,6 +144,7 @@ def _init_private_training( poisson_sampling=poisson_sampling, clipping=clipping, grad_sample_mode=grad_sample_mode, + return_controller=self.RETURN_CONTROLLER, ) return model, optimizer, poisson_dl, privacy_engine @@ -1020,3 +1022,69 @@ def _init_data(self): def _init_model(self): return ModelWithCustomLinear() + + +# ============================================================================ +# Controller-based tests - Same tests but with return_controller=True +# ============================================================================ + + +class PrivacyEngineConvNetControllerTest(PrivacyEngineConvNetTest): + """Test ConvNet with controller-based approach (no model wrapping).""" + + def setUp(self) -> None: + super().setUp() + self.RETURN_CONTROLLER = True + + def tearDown(self) -> None: + """Clean up controller hooks after each test.""" + # The model might have a controller attached that needs cleanup + pass + + +class PrivacyEngineConvNetFrozenControllerTest(PrivacyEngineConvNetFrozenTest): + """Test ConvNet with frozen layers using controller-based approach.""" + + def setUp(self) -> None: + super().setUp() + self.RETURN_CONTROLLER = True + + def tearDown(self) -> None: + """Clean up controller hooks after each test.""" + pass + + +class PrivacyEngineTextControllerTest(PrivacyEngineTextTest): + """Test text models with controller-based approach.""" + + def setUp(self) -> None: + super().setUp() + self.RETURN_CONTROLLER = True + + def tearDown(self) -> None: + """Clean up controller hooks after each test.""" + pass + + +class PrivacyEngineTiedWeightsControllerTest(PrivacyEngineTiedWeightsTest): + """Test tied weights with controller-based approach.""" + + def setUp(self) -> None: + super().setUp() + self.RETURN_CONTROLLER = True + + def tearDown(self) -> None: + """Clean up controller hooks after each test.""" + pass + + +class PrivacyEngineCustomLayerControllerTest(PrivacyEngineCustomLayerTest): + """Test custom layers with controller-based approach.""" + + def setUp(self) -> None: + super().setUp() + self.RETURN_CONTROLLER = True + + def tearDown(self) -> None: + """Clean up controller hooks after each test.""" + pass diff --git a/opacus/utils/fast_gradient_clipping_utils.py b/opacus/utils/fast_gradient_clipping_utils.py index e051bbe6..4dac31cd 100644 --- a/opacus/utils/fast_gradient_clipping_utils.py +++ b/opacus/utils/fast_gradient_clipping_utils.py @@ -14,9 +14,6 @@ # limitations under the License. import torch -from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import ( - GradSampleModuleFastGradientClipping, -) from opacus.optimizers import DPOptimizerFastGradientClipping @@ -27,7 +24,7 @@ class DPTensorFastGradientClipping: def __init__( self, - module: GradSampleModuleFastGradientClipping, + module, # Union[GradSampleModuleFastGradientClipping, GradSampleControllerFastGradientClipping] optimizer: DPOptimizerFastGradientClipping, loss_per_sample: torch.Tensor, loss_reduction: str = "mean", @@ -35,7 +32,7 @@ def __init__( """ Args: - module: the module to train + module: the module or controller to train (GradSampleModuleFastGradientClipping or GradSampleControllerFastGradientClipping) optimizer: the optimizer used to train the module loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1] @@ -47,10 +44,131 @@ def __init__( self.loss_reduction = loss_reduction def item(self): + return self.detach().item() + + def detach(self): if self.loss_reduction == "mean": - return torch.mean(self.loss_per_sample).detach().item() + return torch.mean(self.loss_per_sample).detach() elif self.loss_reduction == "sum": - return torch.sum(self.loss_per_sample).detach().item() + return torch.sum(self.loss_per_sample).detach() + + def __truediv__(self, other): + """ + Division operation for DPTensorFastGradientClipping. + Enables: loss / scalar + """ + return DPTensorFastGradientClipping( + self.module, + self.optimizer, + self.loss_per_sample / other, + self.loss_reduction, + ) + + def __mul__(self, other): + """ + Multiplication operation for DPTensorFastGradientClipping. + Enables: loss * scalar or scalar * loss + """ + return DPTensorFastGradientClipping( + self.module, + self.optimizer, + self.loss_per_sample * other, + self.loss_reduction, + ) + + def __rmul__(self, other): + """ + Right multiplication operation for DPTensorFastGradientClipping. + Enables: scalar * loss + """ + return self.__mul__(other) + + def __add__(self, other): + """ + Addition operation for DPTensorFastGradientClipping. + Enables: loss + scalar or loss + loss + """ + if isinstance(other, DPTensorFastGradientClipping): + if self.loss_reduction != other.loss_reduction: + raise ValueError( + f"Cannot add losses with different reductions: {self.loss_reduction} vs {other.loss_reduction}" + ) + return DPTensorFastGradientClipping( + self.module, + self.optimizer, + self.loss_per_sample + other.loss_per_sample, + self.loss_reduction, + ) + else: + return DPTensorFastGradientClipping( + self.module, + self.optimizer, + self.loss_per_sample + other, + self.loss_reduction, + ) + + def __radd__(self, other): + """ + Right addition operation for DPTensorFastGradientClipping. + Enables: scalar + loss + """ + return self.__add__(other) + + def __sub__(self, other): + """ + Subtraction operation for DPTensorFastGradientClipping. + Enables: loss - scalar or loss - loss + """ + if isinstance(other, DPTensorFastGradientClipping): + if self.loss_reduction != other.loss_reduction: + raise ValueError( + f"Cannot subtract losses with different reductions: {self.loss_reduction} vs {other.loss_reduction}" + ) + return DPTensorFastGradientClipping( + self.module, + self.optimizer, + self.loss_per_sample - other.loss_per_sample, + self.loss_reduction, + ) + else: + return DPTensorFastGradientClipping( + self.module, + self.optimizer, + self.loss_per_sample - other, + self.loss_reduction, + ) + + def __rsub__(self, other): + """ + Right subtraction operation for DPTensorFastGradientClipping. + Enables: scalar - loss + """ + return DPTensorFastGradientClipping( + self.module, + self.optimizer, + other - self.loss_per_sample, + self.loss_reduction, + ) + + def __neg__(self): + """ + Negation operation for DPTensorFastGradientClipping. + Enables: -loss + """ + return DPTensorFastGradientClipping( + self.module, + self.optimizer, + -self.loss_per_sample, + self.loss_reduction, + ) + + def __repr__(self): + """String representation""" + return f"DPTensorFastGradientClipping(loss_reduction={self.loss_reduction}, shape={self.loss_per_sample.shape})" + + def __str__(self): + """String representation""" + return f"DPTensorFastGradientClipping({self.item():.4f})" def backward(self): """ @@ -84,7 +202,7 @@ class DPLossFastGradientClipping: def __init__( self, - module: GradSampleModuleFastGradientClipping, + module, # Union[GradSampleModuleFastGradientClipping, GradSampleControllerFastGradientClipping] optimizer: DPOptimizerFastGradientClipping, criterion, loss_reduction: str = "mean", diff --git a/tutorials/controller_based_privacy_engine.ipynb b/tutorials/controller_based_privacy_engine.ipynb new file mode 100644 index 00000000..ecaffa24 --- /dev/null +++ b/tutorials/controller_based_privacy_engine.ipynb @@ -0,0 +1,361 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training with Controller-Based Privacy Engine (No Model Wrapping)\n", + "\n", + "This tutorial demonstrates how to use Opacus's controller-based privacy engine (`PrivacyEngine` with `return_controller=True`), which provides better compatibility with transformer models and other complex architectures by **avoiding model wrapping**.\n", + "\n", + "## Why Controller-Based?\n", + "\n", + "The standard `PrivacyEngine` wraps your model in a `GradSampleModule`, which can cause issues with:\n", + "- **Type checking**: `isinstance()` checks fail because the model is wrapped\n", + "- **State dict compatibility**: Wrapped models have `_module.` prefixes that complicate checkpoint loading\n", + "- **Complex architectures**: Models with custom `__getattr__` logic (e.g., HuggingFace transformers)\n", + "\n", + "The controller-based approach attaches hooks directly to your model via a `GradSampleController` **without wrapping it**, keeping your model's type and structure intact." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "First, let's import the necessary libraries and create a simple dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.simplefilter(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import TensorDataset, DataLoader\n", + "import torch\n", + "\n", + "# Create a synthetic dataset\n", + "n_samples = 1000\n", + "n_features = 20\n", + "n_classes = 10\n", + "\n", + "X = torch.randn(n_samples, n_features)\n", + "y = torch.randint(0, n_classes, (n_samples,))\n", + "\n", + "dataset = TensorDataset(X, y)\n", + "dataloader = DataLoader(dataset, batch_size=32, shuffle=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define a Simple Model\n", + "\n", + "Let's create a simple neural network classifier:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torch import nn\n", + "\n", + "\n", + "class SimpleClassifier(nn.Module):\n", + " def __init__(self, input_dim, hidden_dim, output_dim):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(input_dim, hidden_dim)\n", + " self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n", + " self.fc3 = nn.Linear(hidden_dim, output_dim)\n", + " self.relu = nn.ReLU()\n", + " \n", + " def forward(self, x):\n", + " x = self.relu(self.fc1(x))\n", + " x = self.relu(self.fc2(x))\n", + " x = self.fc3(x)\n", + " return x\n", + "\n", + "model = SimpleClassifier(n_features, 64, n_classes)\n", + "print(f\"Model type before: {type(model).__name__}\")\n", + "print(f\"isinstance check before: {isinstance(model, SimpleClassifier)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Standard PrivacyEngine (for comparison)\n", + "\n", + "Let's first see what happens with the standard `PrivacyEngine`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from opacus import PrivacyEngine\n", + "from torch import optim\n", + "\n", + "# Create a fresh model for standard approach\n", + "model_standard = SimpleClassifier(n_features, 64, n_classes)\n", + "optimizer_standard = optim.Adam(model_standard.parameters(), lr=0.001)\n", + "\n", + "privacy_engine = PrivacyEngine()\n", + "model_standard, optimizer_standard, dataloader_standard = privacy_engine.make_private(\n", + " module=model_standard,\n", + " optimizer=optimizer_standard,\n", + " data_loader=dataloader,\n", + " noise_multiplier=1.0,\n", + " max_grad_norm=1.0,\n", + ")\n", + "\n", + "print(f\"\\nStandard PrivacyEngine:\")\n", + "print(f\"Model type after: {type(model_standard).__name__}\")\n", + "print(f\"isinstance check after: {isinstance(model_standard, SimpleClassifier)}\")\n", + "print(f\"State dict keys (first 3): {list(model_standard.state_dict().keys())[:3]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice how the model is now wrapped in `GradSampleModule`, `isinstance` checks fail, and state dict keys have `_module.` prefixes." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Controller-Based PrivacyEngine\n", + "\n", + "Now let's use the controller-based approach:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torch import optim\n", + "\n", + "# Create a fresh model for controller-based approach\n", + "model = SimpleClassifier(n_features, 64, n_classes)\n", + "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", + "\n", + "# Initialize controller-based privacy engine\n", + "privacy_engine = PrivacyEngine()\n", + "\n", + "model, optimizer, dataloader = privacy_engine.make_private(\n", + " module=model,\n", + " optimizer=optimizer,\n", + " data_loader=dataloader,\n", + " noise_multiplier=1.0,\n", + " max_grad_norm=1.0,\n", + " return_controller=True,\n", + ")\n", + "\n", + "print(f\"\\nController-Based PrivacyEngine:\")\n", + "print(f\"Model type after: {type(model).__name__}\")\n", + "print(f\"isinstance check after: {isinstance(model, SimpleClassifier)}\")\n", + "print(f\"State dict keys (first 3): {list(model.state_dict().keys())[:3]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice how the model **keeps its original type**, `isinstance` checks **still work**, and state dict keys are **clean without prefixes**!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training Loop\n", + "\n", + "The training loop is identical to standard PyTorch:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "model = model.to(device)\n", + "criterion = nn.CrossEntropyLoss()\n", + "\n", + "EPOCHS = 3\n", + "DELTA = 1e-5\n", + "\n", + "for epoch in range(EPOCHS):\n", + " model.train()\n", + " total_loss = 0\n", + " \n", + " for batch_idx, (data, target) in enumerate(dataloader):\n", + " data, target = data.to(device), target.to(device)\n", + " \n", + " optimizer.zero_grad()\n", + " output = model(data)\n", + " loss = criterion(output, target)\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " total_loss += loss.item()\n", + " \n", + " epsilon = privacy_engine.get_epsilon(DELTA)\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"Epoch {epoch + 1}/{EPOCHS} | Loss: {avg_loss:.4f} | ε: {epsilon:.2f} (δ={DELTA})\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using `make_private_with_epsilon`\n", + "\n", + "You can also specify a target epsilon and have the privacy engine compute the appropriate noise multiplier:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create fresh instances\n", + "model2 = SimpleClassifier(n_features, 64, n_classes)\n", + "optimizer2 = optim.Adam(model2.parameters(), lr=0.001)\n", + "dataloader2 = DataLoader(dataset, batch_size=32, shuffle=True)\n", + "\n", + "privacy_engine2 = PrivacyEngine()\n", + "\n", + "model2, optimizer2, dataloader2 = privacy_engine2.make_private_with_epsilon(\n", + " module=model2,\n", + " optimizer=optimizer2,\n", + " data_loader=dataloader2,\n", + " target_epsilon=3.0,\n", + " target_delta=1e-5,\n", + " epochs=EPOCHS,\n", + " max_grad_norm=1.0,\n", + " return_controller=True,\n", + ")\n", + "\n", + "print(f\"Target epsilon: 3.0\")\n", + "print(f\"Computed noise multiplier: {privacy_engine2.noise_multiplier:.3f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Checkpoint Saving and Loading\n", + "\n", + "Checkpoints are easier with controller-based approach since there are no `_module.` prefixes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save checkpoint\n", + "privacy_engine.save_checkpoint(\n", + " path=\"checkpoint.pt\",\n", + " module=model,\n", + " optimizer=optimizer,\n", + ")\n", + "print(\"Checkpoint saved!\")\n", + "\n", + "# Load checkpoint\n", + "model_loaded = SimpleClassifier(n_features, 64, n_classes)\n", + "optimizer_loaded = optim.Adam(model_loaded.parameters(), lr=0.001)\n", + "\n", + "privacy_engine_loaded = PrivacyEngine()\n", + "privacy_engine_loaded.load_checkpoint(\n", + " path=\"checkpoint.pt\",\n", + " module=model_loaded,\n", + " optimizer=optimizer_loaded,\n", + ")\n", + "print(\"Checkpoint loaded!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Key Differences Summary\n", + "\n", + "| Feature | Standard PrivacyEngine | Controller-Based PrivacyEngine |\n", + "|---------|------------------------|---------------------------|\n", + "| Model wrapping | Yes (GradSampleModule) | **No** |\n", + "| Type preservation | No | **Yes** |\n", + "| `isinstance()` works | No | **Yes** |\n", + "| State dict prefixes | `_module.` prefix | **Clean** |\n", + "| Direct attribute access | Via forwarding | **Direct** |\n", + "| Transformer compatibility | Can have issues | **Better** |\n", + "| Requires cleanup | No | **Yes** |\n", + "| API | Standard | **Same** |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## When to Use Controller-Based?\n", + "\n", + "Use `PrivacyEngine` with `return_controller=True` when:\n", + "- Working with HuggingFace transformers or other models with complex `__getattr__` logic\n", + "- You need `isinstance()` checks to work correctly\n", + "- You want clean state dicts without `_module.` prefixes\n", + "- You need direct access to model attributes\n", + "\n", + "Use standard `PrivacyEngine` when:\n", + "- You have simple models without complex introspection\n", + "- You don't need the benefits above\n", + "- You prefer the more battle-tested approach" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/website/tutorials.json b/website/tutorials.json index 2f8d9e68..f4264dfc 100644 --- a/website/tutorials.json +++ b/website/tutorials.json @@ -27,6 +27,10 @@ { "id": "ddp_tutorial", "title": "Training on multiple GPUs with DistributedDataParallel" + }, + { + "id": "controller_based_privacy_engine", + "title": "Training with Controller-Based Privacy Engine (No Model Wrapping)" } ] }