
Commit 42cc52b

david-stanev authored and gri243 committed
Add GradSampleControllerFastGradientClipping for controller-based privacy engine
1 parent bbc5f68 commit 42cc52b

File tree

6 files changed, +935 -20 lines changed


opacus/grad_sample/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,9 @@
 from .embedding import compute_embedding_grad_sample  # noqa
 from .embedding_norm_sample import compute_embedding_norm_sample  # noqa
 from .grad_sample_controller import GradSampleController  # noqa
+from .grad_sample_controller_fast_gradient_clipping import (  # noqa
+    GradSampleControllerFastGradientClipping,
+)
 from .grad_sample_module import GradSampleModule, create_or_accumulate_grad_sample
 from .grad_sample_module_fast_gradient_clipping import (  # noqa
     GradSampleModuleFastGradientClipping,
@@ -47,6 +50,7 @@
 
 __all__ = [
     "GradSampleController",
+    "GradSampleControllerFastGradientClipping",
     "GradSampleModule",
     "GradSampleModuleFastGradientClipping",
     "GradSampleModuleFastGradientClippingFSDP",
opacus/grad_sample/grad_sample_controller_fast_gradient_clipping.py

Lines changed: 334 additions & 0 deletions
@@ -0,0 +1,334 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
GradSampleControllerFastGradientClipping: Controller-based Fast Gradient and Ghost Clipping.

This module provides a GradSampleModule-less approach with ghost clipping support,
combining the benefits of:
- Controller-based hook management (no model wrapping)
- Ghost clipping (memory-efficient gradient norm computation)
"""

import logging
from typing import List

import torch
import torch.nn as nn
from opacus.grad_sample.functorch import ft_compute_per_sample_gradient
from opacus.grad_sample.grad_sample_controller import GradSampleController
from opacus.grad_sample.grad_sample_module import (
    create_or_accumulate_grad_sample,
    promote_current_grad_sample,
)
from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN
from opacus.utils.module_utils import trainable_modules, trainable_parameters


logger = logging.getLogger(__name__)
logger.disabled = True

def create_norm_sample(
    *, param: torch.Tensor, grad_sample: torch.Tensor, max_batch_len: int
) -> None:
    """
    Creates a ``_norm_sample`` attribute in the given parameter.

    Args:
        param: Parameter to which ``_norm_sample`` will be added
        grad_sample: Per-sample gradients tensor. Must be of the same
            shape as ``param`` with extra batch dimension
        max_batch_len: Expected batch size; a value of 0 indicates an empty batch
    """

    if param.requires_grad:
        if (
            max_batch_len == 0
        ):  # To handle the case of empty batch that may arise from Poisson sampling
            param._norm_sample = torch.tensor(
                [], device=grad_sample.device, dtype=grad_sample.dtype
            )
        else:
            param._norm_sample = torch.zeros(
                torch.Size([max_batch_len, 1]),
                device=grad_sample.device,
                dtype=grad_sample.dtype,
            )
            param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(
                2, dim=-1
            )

class GradSampleControllerFastGradientClipping(GradSampleController):
    """
    Controller for managing privacy hooks with Fast Gradient and Ghost Clipping support.

    Extends GradSampleController to add ghost clipping support for memory-efficient
    gradient norm computation. Supports both:
    - Ghost Clipping: Direct norm computation without materializing full gradients
    - Fast Gradient Clipping: Full gradient computation followed by norm computation

    This class attaches hooks directly to model modules and manages their lifecycle,
    providing an alternative to GradSampleModule wrapping that's more compatible
    with transformers and other complex models.
    """

    NORM_SAMPLERS = {}

    def __init__(
        self,
        m: nn.Module,
        *,
        batch_first=True,
        loss_reduction="mean",
        strict: bool = True,
        force_functorch=False,
        max_grad_norm=1,
        use_ghost_clipping=True,
    ):
        """

        Args:
            m: nn.Module to attach hooks to
            batch_first: Flag to indicate if the input tensor to the corresponding module
                has the first dimension representing the batch. If set to True, dimensions
                on input tensor are expected to be ``[batch_size, ...]``, otherwise
                ``[K, batch_size, ...]``
            loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
                is a sum or a mean operation. Can take values "sum" or "mean"
            max_grad_norm: The value at which gradients are to be clipped.
            strict: If set to ``True``, the input module will be validated to make sure
                that none of its submodules includes buffers, which is not currently
                supported by Opacus. If set to ``False``, per-sample gradients will be
                computed on a "best effort" basis - they will be available where possible
                and set to None otherwise. This is not recommended, because some
                unsupported modules (e.g. BatchNorm) affect other parameters and
                invalidate the concept of per-sample gradients for the entire model.
            force_functorch: If set to ``True``, will use functorch to compute
                all per sample gradients. Otherwise, functorch will be used only
                for layers without registered grad sampler methods.
            use_ghost_clipping: If set to ``True``, Ghost Clipping will be used for
                clipping gradients of supported layers. If ``False``, Fast Gradient
                Clipping will be used for all layers.

        Raises:
            NotImplementedError
                If ``strict`` is set to ``True`` and module ``m`` (or any of its
                submodules) includes a buffer.
        """
        # Call parent constructor
        super().__init__(
            m,
            batch_first=batch_first,
            loss_reduction=loss_reduction,
            strict=strict,
            force_functorch=force_functorch,
        )

        # Add ghost clipping specific attributes
        self.max_grad_norm = max_grad_norm
        self.use_ghost_clipping = use_ghost_clipping
        self._per_sample_gradient_norms = None

        # Initialize _norm_sample attribute for parameters
        for _, p in trainable_parameters(self.module):
            p._norm_sample = None

        self.trainable_parameters = [p for _, p in trainable_parameters(self.module)]

        if logger.isEnabledFor(logging.INFO):
            self.log_module_gradient_sample_mode(
                module=m,
                force_functorch=force_functorch,
                use_ghost_clipping=use_ghost_clipping,
            )

    def get_clipping_coef(self) -> torch.Tensor:
        """Get per-example gradient scaling factor for clipping."""
        norm_sample = self.get_norm_sample()
        return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)

    def get_norm_sample(self) -> torch.Tensor:
        """Get per-example gradient norms."""
        norm_sample = torch.stack(
            [param._norm_sample for param in self.trainable_parameters], dim=0
        ).norm(2, dim=0)
        self.per_sample_gradient_norms = norm_sample
        return norm_sample

    def capture_activations_hook(
        self,
        module: nn.Module,
        forward_input: List[torch.Tensor],
        _forward_output: torch.Tensor,
    ):
        """
        Override parent method to add parameter tying check for ghost clipping.
        """
        # Call parent implementation
        super().capture_activations_hook(module, forward_input, _forward_output)

        # Add ghost clipping specific check for parameter tying
        if self.hooks_enabled:
            for _, p in trainable_parameters(module):
                if (
                    self.use_ghost_clipping
                    and p._forward_counter > 1
                    and type(module) in self.NORM_SAMPLERS
                ):
                    raise NotImplementedError(
                        "Parameter tying is not supported with Ghost Clipping"
                    )

    def capture_backprops_hook(
        self,
        module: nn.Module,
        _forward_input: torch.Tensor,
        forward_output: torch.Tensor,
        loss_reduction: str,
        batch_first: bool,
    ):
        """
        Computes per sample gradient norms given the current backprops and activations
        stored by the associated forward hook. Computed per sample gradient norms are
        stored in ``_norm_sample`` field in each parameter.

        For non-recurrent layers the process is straightforward: for each
        ``loss.backward()`` call this hook will be called exactly once. For recurrent
        layers, however, this is more complicated and the hook will be called multiple
        times, while still processing the same batch of data.

        For this reason we first accumulate the gradients from *the same batch* in
        ``p._current_grad_sample`` and then, when we detect the end of a full backward
        pass - we store the accumulated result on ``p.grad_sample`` (for fast gradient
        clipping) or ``p._norm_sample`` (for ghost clipping).

        Args:
            module: nn.Module the hook is attached to
            _forward_input: module input (unused)
            forward_output: backpropagated gradients with respect to the module output
            loss_reduction: "sum" or "mean"
            batch_first: True if the batch dimension is the first dimension
        """
        if not self.hooks_enabled:
            return

        backprops = forward_output[0].detach()
        activations, backprops = self.rearrange_grad_samples(
            module=module,
            backprops=backprops,
            loss_reduction=loss_reduction,
            batch_first=batch_first,
        )

        # Handle DTensor if needed
        activations = [
            temp.to_local() if type(temp) is torch.distributed.tensor.DTensor else temp
            for temp in activations
        ]

        if self.use_ghost_clipping and type(module) in self.NORM_SAMPLERS:
            # Ghost clipping: compute norms directly
            norm_sampler_fn = self.NORM_SAMPLERS[type(module)]
            norm_samples = norm_sampler_fn(module, activations, backprops)

            for param, ns in norm_samples.items():
                if param.requires_grad:
                    param._norm_sample = ns
                    param._forward_counter -= 1

        else:
            # Fast gradient clipping: compute full gradients then norms
            if not self.force_functorch and type(module) in self.GRAD_SAMPLERS:
                grad_sampler_fn = self.GRAD_SAMPLERS[type(module)]
            else:
                grad_sampler_fn = ft_compute_per_sample_gradient

            grad_samples = grad_sampler_fn(module, activations, backprops)
            for param, gs in grad_samples.items():
                create_or_accumulate_grad_sample(
                    param=param, grad_sample=gs, max_batch_len=module.max_batch_len
                )
            del grad_samples

            # Detect end of current batch processing and switch accumulation
            # mode from sum to stacking. Used for RNNs and tied parameters
            # (See #417 for details)
            for _, p in trainable_parameters(module):
                p._forward_counter -= 1
                if p._forward_counter == 0:
                    promote_current_grad_sample(p)
                    create_norm_sample(
                        param=p,
                        grad_sample=p.grad_sample,
                        max_batch_len=module.max_batch_len,
                    )
                    p.grad_sample = None

        if len(module.activations) == 0:
            if hasattr(module, "max_batch_len"):
                del module.max_batch_len

    def log_module_gradient_sample_mode(
        self, module: nn.Module, *, force_functorch=False, use_ghost_clipping=True
    ):
        """
        Add logs to track gradient sample mode for each part of the module, including
        1) Ghost Clipping, 2) Fast Gradient Clipping (hook mode), and
        3) Fast Gradient Clipping (functorch mode).

        Args:
            module: nn.Module to be checked
            force_functorch: If set to ``True``, will use functorch to compute
                all per sample gradients. Otherwise, functorch will be used only
                for layers without registered grad sampler methods.
            use_ghost_clipping: If set to ``True``, Ghost Clipping will be used for
                clipping gradients of supported layers. If ``False``, Fast Gradient
                Clipping will be used for all layers.
        """
        for m_name, m in trainable_modules(module):
            if type(m) in [DPRNN, DPLSTM, DPGRU]:
                logger.info(
                    f"Module name: {m_name}, module type: {type(m)}. No hook or functorch is added."
                )

            elif use_ghost_clipping and type(m) in self.NORM_SAMPLERS:
                logger.info(
                    f"Module name: {m_name}, module type: {type(m)}, under Ghost Clipping."
                )

            else:
                if not force_functorch and type(m) in self.GRAD_SAMPLERS:
                    # When functorch is not enforced, use FGC (hook mode) if the layer
                    # has a registered grad_sampler (supported). Otherwise, use FGC
                    # (functorch mode).
                    logger.info(
                        f"Module name: {m_name}, module type: {type(m)}, under Fast Gradient Clipping (hook mode)."
                    )
                else:
                    logger.info(
                        f"Module name: {m_name}, module type: {type(m)}, under Fast Gradient Clipping (functorch mode)."
                    )

    @property
    def per_sample_gradient_norms(self) -> torch.Tensor:
        """Returns per sample gradient norms. Note that these are not privatized and
        should only be used for debugging purposes or in non-private settings."""
        if self._per_sample_gradient_norms is not None:
            return self._per_sample_gradient_norms
        else:
            raise AttributeError(
                "per_sample_gradient_norms is not set. Please call forward and backward on the model before accessing this property."
            )

    @per_sample_gradient_norms.setter
    def per_sample_gradient_norms(self, value):
        self._per_sample_gradient_norms = value

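For context, here is a rough sketch of how the controller could sit in a fast gradient clipping training step. It is illustrative only: the model, data, optimizer, and the two-pass loss handling are assumptions for the sketch (mirroring how Opacus's existing ghost clipping flow rescales per-sample losses by the clipping coefficients and backpropagates a second time); the controller-based privacy engine wiring, noise addition, and accounting are not part of this diff. The hooks_enabled flag toggled below is the one checked by capture_backprops_hook above, assumed here to be a plain attribute on GradSampleController; the real engine presumably exposes an explicit enable/disable API.

    import torch
    import torch.nn as nn
    from opacus.grad_sample import GradSampleControllerFastGradientClipping

    # Toy setup (illustrative only)
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
    controller = GradSampleControllerFastGradientClipping(
        model, batch_first=True, loss_reduction="mean", max_grad_norm=1.0
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss(reduction="none")  # keep per-sample losses

    x = torch.randn(8, 10)
    y = torch.randint(0, 2, (8,))

    # First backward pass: the controller's hooks populate per-parameter
    # _norm_sample (ghost clipping or fast gradient clipping, per layer).
    per_sample_loss = criterion(model(x), y)
    per_sample_loss.mean().backward(retain_graph=True)
    optimizer.zero_grad()  # discard the unclipped gradients

    # Per-sample clipping coefficients: max_grad_norm / (norm + 1e-6), capped at 1
    coeff = controller.get_clipping_coef()  # shape [batch_size]

    # Second backward pass on the rescaled loss yields the clipped aggregate
    # gradients. Hooks are switched off so the per-sample machinery does not
    # run again on this pass.
    controller.hooks_enabled = False
    (per_sample_loss * coeff.detach()).mean().backward()
    controller.hooks_enabled = True
    optimizer.step()

    # Non-private diagnostics only, as the property docstring warns
    print(controller.per_sample_gradient_norms)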