+        # When self._qkv_same_embed_dim = True, nn.Transformer uses "in_proj_weight"
+        # rather than the separate "q,k,v" projection weights and takes its fast-path
+        # calculation, which should be avoided here. This is why we force
+        # self._qkv_same_embed_dim = False.
+        self._qkv_same_embed_dim = False
         self.num_heads = num_heads
         self.dropout = dropout
+        self.batch_first = batch_first
         self.head_dim = embed_dim // num_heads
         assert (
             self.head_dim * num_heads == self.embed_dim
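For context, forcing `_qkv_same_embed_dim = False` mirrors what `nn.MultiheadAttention` itself does when `kdim`/`vdim` differ from `embed_dim`: separate q/k/v projection weights are kept and `in_proj_weight` is registered as `None`, which also keeps the module off the fused fast path. A minimal illustration (not part of this patch):

```python
import torch.nn as nn

# With kdim/vdim different from embed_dim, nn.MultiheadAttention sets
# _qkv_same_embed_dim = False, keeps separate q/k/v projection weights,
# and registers in_proj_weight as None.
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, kdim=8, vdim=8)
assert not mha._qkv_same_embed_dim
assert mha.in_proj_weight is None
assert mha.q_proj_weight.shape == (16, 16)  # (embed_dim, embed_dim)
assert mha.k_proj_weight.shape == (16, 8)   # (embed_dim, kdim)
```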
@@ -120,6 +124,10 @@ def __init__(
         self.dropout = nn.Dropout(dropout)

+        # to avoid null pointers in Transformer.forward
+        self.in_proj_weight = None
+        self.in_proj_bias = None
+
     def load_state_dict(self, state_dict):
         r"""
         Loads module from previously saved state.
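These `None` placeholders matter because, per the comment above, `Transformer.forward` in recent PyTorch versions reads `self_attn.in_proj_weight`/`in_proj_bias` while deciding whether its fused fast path applies; a drop-in attention module that lacks the attributes entirely fails with an `AttributeError` instead of simply being routed to the slow path. A sketch of the failure mode, using a hypothetical `StubAttention`:

```python
import torch.nn as nn

class StubAttention(nn.Module):
    """A stand-in attention module that defines no in_proj_* attributes."""

stub = StubAttention()
try:
    _ = stub.in_proj_weight  # roughly what the fast-path inspection does
except AttributeError:
    print("missing attribute: the fast-path inspection would crash here")

stub.in_proj_weight = None  # present but None: the lookup succeeds
stub.in_proj_bias = None    # and the fast path is simply skipped
assert stub.in_proj_weight is None
```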
@@ -178,7 +186,33 @@ def forward(
         key_padding_mask=None,
         need_weights=True,
         attn_mask=None,
+        is_causal=False,
     ):
+        is_batched = query.dim() == 3
+
+        assert is_batched, "query must be 3-dimensional (batched)."
+
+        r"""
+        As per https://github.com/pytorch/opacus/issues/596, we have to include
+        ``is_causal`` as a dummy parameter of the function, since it is used in
+        the ``forward`` function of the enclosing ``nn.TransformerEncoderLayer``.
+        """
+        assert (
+            not is_causal
+        ), "We currently do not support a causal mask. Will fix it in the future."
+
+        r"""
+        Using the same logic as ``nn.MultiheadAttention``
+        (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html).
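With these changes the layer can sit inside `nn.TransformerEncoderLayer`, which passes `is_causal` down to its attention module. A usage sketch, assuming the patched class is importable as `opacus.layers.DPMultiheadAttention` (the module discussed in the issue above):

```python
import torch
from opacus.layers import DPMultiheadAttention  # assumed import path

attn = DPMultiheadAttention(embed_dim=16, num_heads=4)
q = k = v = torch.randn(5, 2, 16)        # (L, N, E): query.dim() == 3, so is_batched holds
out, _ = attn(q, k, v, is_causal=False)  # dummy flag is accepted and ignored
# attn(q, k, v, is_causal=True)          # would trip the causal-mask assertion above
```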