
Commit 51f5111

Merge branch 'main' into main
2 parents: 255e371 + da05f77

File tree: 6 files changed, +56 −15 lines

6 files changed

+56
-15
lines changed

.circleci/config.yml

Lines changed: 3 additions & 4 deletions
@@ -343,7 +343,7 @@ jobs:
   integrationtest_py39_torch_release_cuda:
     machine:
       resource_class: gpu.nvidia.small.multi
-      image: ubuntu-2004-cuda-11.4:202110-01
+      image: linux-cuda-12:default
     steps:
       - checkout
       - py_3_9_setup
@@ -363,7 +363,7 @@ jobs:
   micro_benchmarks_py39_torch_release_cuda:
     machine:
       resource_class: gpu.nvidia.small.multi
-      image: ubuntu-2004-cuda-11.4:202110-01
+      image: linux-cuda-12:default
     steps:
       - checkout
       - py_3_9_setup
@@ -447,7 +447,7 @@ jobs:
   unittest_multi_gpu:
     machine:
       resource_class: gpu.nvidia.medium.multi
-      image: ubuntu-2004-cuda-11.4:202110-01
+      image: linux-cuda-12:default
     steps:
       - checkout
       - py_3_9_setup
@@ -515,4 +515,3 @@ workflows:
           filters: *exclude_ghpages
       - micro_benchmarks_py39_torch_release_cuda:
          filters: *exclude_ghpages
-

opacus/accountants/analysis/prv/prvs.py

Lines changed: 6 additions & 2 deletions
@@ -96,11 +96,15 @@ def mean(self) -> float:
         """
         Calculate the mean using numerical integration.
         """
+        # determine points based on t_min and t_max
+        lower_exponent = int(np.log10(np.abs(self.t_min)))
+        upper_exponent = int(np.log10(self.t_max))
         points = np.concatenate(
             [
                 [self.t_min],
-                -np.logspace(-5, -1, 5)[::-1],
-                np.logspace(-5, -1, 5),
+                -np.logspace(start=lower_exponent, stop=-5, num=10),
+                [0],
+                np.logspace(start=-5, stop=upper_exponent, num=10),
                 [self.t_max],
             ]
         )
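
For intuition, here is a self-contained sketch of the new integration grid (the bounds are hypothetical stand-ins for `self.t_min` / `self.t_max`, not values from the patch). The old grid was fixed to ±[1e-5, 0.1], so wider domains lost probability mass; the new grid is log-spaced out to the actual domain edges and includes 0 explicitly:

```python
import numpy as np

# Hypothetical domain bounds standing in for self.t_min / self.t_max.
t_min, t_max = -3.0, 5.0

# Log-spaced abscissae concentrate resolution near 0 while still reaching
# the domain edges; the explicit [t_min] / [t_max] endpoints cover whatever
# lies beyond the truncated outer decades.
lower_exponent = int(np.log10(np.abs(t_min)))  # 0 for |t_min| = 3
upper_exponent = int(np.log10(t_max))          # 0 for t_max = 5
points = np.concatenate(
    [
        [t_min],
        -np.logspace(start=lower_exponent, stop=-5, num=10),
        [0],
        np.logspace(start=-5, stop=upper_exponent, num=10),
        [t_max],
    ]
)
print(points.size)  # 23 strictly increasing integration points
```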

opacus/layers/dp_multihead_attention.py

Lines changed: 39 additions & 2 deletions
@@ -89,17 +89,21 @@ def __init__(
         add_zero_attn=False,
         kdim=None,
         vdim=None,
+        batch_first=False,
         device=None,
         dtype=None,
     ):
         super(DPMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
         self.vdim = vdim if vdim is not None else embed_dim
-        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+        # When self._qkv_same_embed_dim = True, nn.Transformer uses "in_proj_weight" rather than "q,k,v_weight" and takes its fast-path calculation, which must be avoided here. This is why we force self._qkv_same_embed_dim = False.
+        self._qkv_same_embed_dim = False

         self.num_heads = num_heads
         self.dropout = dropout
+        self.batch_first = batch_first
         self.head_dim = embed_dim // num_heads
         assert (
             self.head_dim * num_heads == self.embed_dim
@@ -120,6 +124,10 @@ def __init__(

         self.dropout = nn.Dropout(dropout)

+        # to avoid null pointers in Transformer.forward
+        self.in_proj_weight = None
+        self.in_proj_bias = None
+
     def load_state_dict(self, state_dict):
         r"""
         Loads module from previously saved state.
@@ -178,7 +186,33 @@ def forward(
         key_padding_mask=None,
         need_weights=True,
         attn_mask=None,
+        is_causal=False,
     ):
+        is_batched = query.dim() == 3
+
+        assert is_batched, "The query must be 3-dimensional."
+
+        r"""
+        As per https://github.com/pytorch/opacus/issues/596, we have to include ``is_causal`` as a dummy parameter,
+        since it is used in the ``forward`` function of the parent class ``nn.TransformerEncoderLayer``.
+        """
+        assert (
+            not is_causal
+        ), "We currently do not support causal masks; this will be fixed in the future."
+
+        r"""
+        Using the same logic as ``nn.MultiheadAttention`` (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html).
+        """
+        if self.batch_first:
+            if key is value:
+                if query is key:
+                    query = key = value = query.transpose(1, 0)
+                else:
+                    query, key = [x.transpose(1, 0) for x in (query, key)]
+                    value = key
+            else:
+                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
+
         tgt_len, bsz, embed_dim = query.size()
         if embed_dim != self.embed_dim:
             raise ValueError(
@@ -323,6 +357,9 @@ def forward(
             )
         attn_output = self.out_proj(attn_output)

+        if self.batch_first:
+            attn_output = attn_output.transpose(1, 0)
+
         if need_weights:
             # average attention weights over heads
             attn_output_weights = attn_output_weights.view(
@@ -361,7 +398,7 @@ def state_dict(self, destination=None, prefix="", keep_vars=False):
             keep_vars=keep_vars,
         )

-        if self._qkv_same_embed_dim:
+        if (self.kdim == self.embed_dim) and (self.vdim == self.embed_dim):
             destination_alter[prefix + "in_proj_weight"] = torch.cat(
                 (
                     destination[prefix + "qlinear.weight"],
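
To make the new behavior concrete, a minimal usage sketch (toy shapes; not part of the commit). With `batch_first=True`, `forward` transposes inputs from `(batch, seq_len, embed_dim)` to the internal `(seq_len, batch, embed_dim)` layout and transposes the output back, mirroring `nn.MultiheadAttention`:

```python
import torch
from opacus.layers.dp_multihead_attention import DPMultiheadAttention

# Toy configuration; any embed_dim divisible by num_heads works.
attn = DPMultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
x = torch.randn(2, 5, 16)  # (batch=2, seq_len=5, embed_dim=16)

# Self-attention: query, key and value are the same tensor, so forward()
# takes the `key is value` / `query is key` aliasing branch and performs
# a single transpose.
out, attn_weights = attn(x, x, x)
assert out.shape == (2, 5, 16)
```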

opacus/tests/batch_memory_manager_test.py

Lines changed: 3 additions & 3 deletions
@@ -16,7 +16,7 @@

 import torch
 import torch.nn as nn
-from hypothesis import given, settings
+from hypothesis import HealthCheck, given, settings
 from hypothesis import strategies as st
 from opacus import PrivacyEngine
 from opacus.utils.batch_memory_manager import BatchMemoryManager
@@ -59,7 +59,7 @@ def _init_training(self, batch_size=10, **data_loader_kwargs):
         batch_size=st.sampled_from([8, 16, 64]),
         max_physical_batch_size=st.sampled_from([4, 8]),
     )
-    @settings(deadline=10000)
+    @settings(suppress_health_check=list(HealthCheck), deadline=10000)
     def test_basic(
         self,
         num_workers: int,
@@ -119,7 +119,7 @@ def test_basic(
         num_workers=st.integers(0, 4),
         pin_memory=st.booleans(),
     )
-    @settings(deadline=10000)
+    @settings(suppress_health_check=list(HealthCheck), deadline=10000)
     def test_empty_batch(
         self,
         num_workers: int,
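
For reference, a standalone sketch of what the new `settings` line does (the test below is illustrative, not from the suite). `HealthCheck` is an enum, so `list(HealthCheck)` yields all of its members, and passing that list to `suppress_health_check` disables every health check at once; the per-example `deadline` still applies. The same pattern is used in `privacy_engine_test.py` below:

```python
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st

# list(HealthCheck) enumerates every member (too_slow, data_too_large,
# filter_too_much, ...), so all health checks are suppressed while the
# 10-second deadline per example remains in force.
@given(n=st.integers(0, 4))
@settings(suppress_health_check=list(HealthCheck), deadline=10000)
def test_example(n: int) -> None:
    assert 0 <= n <= 4
```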

opacus/tests/privacy_engine_test.py

Lines changed: 4 additions & 4 deletions
@@ -26,7 +26,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from hypothesis import given, settings
+from hypothesis import HealthCheck, given, settings
 from opacus import PrivacyEngine
 from opacus.layers.dp_multihead_attention import DPMultiheadAttention
 from opacus.optimizers.optimizer import _generate_noise
@@ -266,7 +266,7 @@ def _compare_to_vanilla(
         use_closure=st.booleans(),
         max_steps=st.sampled_from([1, 4]),
     )
-    @settings(deadline=None)
+    @settings(suppress_health_check=list(HealthCheck), deadline=None)
     def test_compare_to_vanilla(
         self,
         do_clip: bool,
@@ -552,7 +552,7 @@ def test_parameters_match(self):
         has_noise_scheduler=st.booleans(),
         has_grad_clip_scheduler=st.booleans(),
     )
-    @settings(deadline=None)
+    @settings(suppress_health_check=list(HealthCheck), deadline=None)
     def test_checkpoints(
         self, has_noise_scheduler: bool, has_grad_clip_scheduler: bool
     ):
@@ -659,7 +659,7 @@ def test_checkpoints(
         max_steps=st.integers(8, 10),
         secure_mode=st.just(False),  # TODO: enable after fixing torchcsprng build
     )
-    @settings(deadline=None)
+    @settings(suppress_health_check=list(HealthCheck), deadline=None)
     def test_noise_level(
         self,
         noise_multiplier: float,

opacus/validators/multihead_attention.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ def fix(module: nn.MultiheadAttention) -> DPMultiheadAttention:
         add_zero_attn=module.add_zero_attn,
         kdim=module.kdim,
         vdim=module.vdim,
+        batch_first=module.batch_first,
     )
     dp_attn.load_state_dict(module.state_dict())
     return dp_attn
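
A short sketch of the effect, assuming the usual `ModuleValidator.fix` replacement flow (`ToyModel` is illustrative, not from the commit). Before this change the substituted `DPMultiheadAttention` fell back to the default `batch_first=False` even when the original module was batch-first; the flag is now carried over:

```python
import torch.nn as nn
from opacus.validators import ModuleValidator

class ToyModel(nn.Module):  # illustrative wrapper, not from the commit
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)

    def forward(self, x):
        out, _ = self.attn(x, x, x)
        return out

# fix() swaps nn.MultiheadAttention for DPMultiheadAttention and loads
# the original weights; batch_first now survives the swap.
model = ModuleValidator.fix(ToyModel())
assert model.attn.batch_first
```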
