Skip to content

Commit da05f77

Browse files
HuanyuZhangfacebook-github-bot
authored andcommitted
Fixing bugs for DP MultiheadAttention (#598)
Summary: Pull Request resolved: #598 Fixing the null pointers in calling DP MultiheadAttention by transform.forward Reviewed By: karthikprasad Differential Revision: D47405312 fbshipit-source-id: c323503ed5ecf2e8f0fc8e5d588cee563d972a4a
1 parent ad084da commit da05f77

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

opacus/layers/dp_multihead_attention.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,17 +89,21 @@ def __init__(
8989
add_zero_attn=False,
9090
kdim=None,
9191
vdim=None,
92+
batch_first=False,
9293
device=None,
9394
dtype=None,
9495
):
9596
super(DPMultiheadAttention, self).__init__()
9697
self.embed_dim = embed_dim
9798
self.kdim = kdim if kdim is not None else embed_dim
9899
self.vdim = vdim if vdim is not None else embed_dim
99-
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
100+
101+
# when self._qkv_same_embed_dim = True, "in_proj_weight" rather than "q,k,v_weight" and fast path calculation will be used in "nn.transformer", which should be avoided. This is why we force self._qkv_same_embed_dim = False.
102+
self._qkv_same_embed_dim = False
100103

101104
self.num_heads = num_heads
102105
self.dropout = dropout
106+
self.batch_first = batch_first
103107
self.head_dim = embed_dim // num_heads
104108
assert (
105109
self.head_dim * num_heads == self.embed_dim
@@ -120,6 +124,10 @@ def __init__(
120124

121125
self.dropout = nn.Dropout(dropout)
122126

127+
# to avoid null pointers in Transformer.forward
128+
self.in_proj_weight = None
129+
self.in_proj_bias = None
130+
123131
def load_state_dict(self, state_dict):
124132
r"""
125133
Loads module from previously saved state.
@@ -178,7 +186,33 @@ def forward(
178186
key_padding_mask=None,
179187
need_weights=True,
180188
attn_mask=None,
189+
is_causal=False,
181190
):
191+
is_batched = query.dim() == 3
192+
193+
assert is_batched == True, "The query must have a dimension of 3."
194+
195+
r"""
196+
As per https://github.com/pytorch/opacus/issues/596, we have to include ``is_causal`` as a dummy parameter of the function,
197+
since it is used in the ``forward`` function of parent class ``nn.TransformerEncoderLayer``.
198+
"""
199+
assert (
200+
is_causal == False
201+
), "We currently do not support causal mask. Will fix it in the future."
202+
203+
r"""
204+
Using the same logic with ``nn.MultiheadAttention`` (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html).
205+
"""
206+
if self.batch_first:
207+
if key is value:
208+
if query is key:
209+
query = key = value = query.transpose(1, 0)
210+
else:
211+
query, key = [x.transpose(1, 0) for x in (query, key)]
212+
value = key
213+
else:
214+
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
215+
182216
tgt_len, bsz, embed_dim = query.size()
183217
if embed_dim != self.embed_dim:
184218
raise ValueError(
@@ -323,6 +357,9 @@ def forward(
323357
)
324358
attn_output = self.out_proj(attn_output)
325359

360+
if self.batch_first:
361+
attn_output = attn_output.transpose(1, 0)
362+
326363
if need_weights:
327364
# average attention weights over heads
328365
attn_output_weights = attn_output_weights.view(
@@ -361,7 +398,7 @@ def state_dict(self, destination=None, prefix="", keep_vars=False):
361398
keep_vars=keep_vars,
362399
)
363400

364-
if self._qkv_same_embed_dim:
401+
if (self.kdim == self.embed_dim) and (self.vdim == self.embed_dim):
365402
destination_alter[prefix + "in_proj_weight"] = torch.cat(
366403
(
367404
destination[prefix + "qlinear.weight"],

opacus/validators/multihead_attention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def fix(module: nn.MultiheadAttention) -> DPMultiheadAttention:
4545
add_zero_attn=module.add_zero_attn,
4646
kdim=module.kdim,
4747
vdim=module.vdim,
48+
batch_first=module.batch_first,
4849
)
4950
dp_attn.load_state_dict(module.state_dict())
5051
return dp_attn

0 commit comments

Comments
 (0)