
Commit d2fa9fd

commenting out v_ and q_ biases as they are always const
1 parent: b301b46

2 files changed, +9 -5 lines

DeBERTa/apps/train.py

Lines changed: 2 additions & 1 deletion
@@ -265,9 +265,10 @@ def run_onnx_training(args, model, device, train_data, prefix=None):
   for step, batch in enumerate(AsyncDataLoader(train_dataloader, 100)):
     #import pdb
     #pdb.set_trace()
+    lr = torch.tensor([0.0000000e+00]).to(device)
     batch = batch_to(batch, device)
     with torch.no_grad():
-      trainer.train_step(batch['input_ids'], batch['type_ids'], batch['position_ids'], batch['input_mask'], batch['labels'])
+      trainer.train_step(batch['input_ids'], batch['type_ids'], batch['position_ids'], batch['input_mask'], batch['labels'], lr)
     # conversion fails now with:
     # site-packages/torch/onnx/utils.py:617: UserWarning: ONNX export failed on ATen operator broadcast_tensors
     # because torch.onnx.symbolic_opset10.broadcast_tensors does not exist
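
The new lr argument feeds the learning rate into trainer.train_step as a 1-element tensor on the training device; in this commit it is hard-coded to zero. Below is a minimal sketch of how that tensor could instead be driven by a warmup/decay schedule. The helper name lr_tensor_for_step and the schedule parameters are assumptions for illustration, not part of the repository; only the requirement that train_step receives the learning rate as a device tensor comes from the diff above.

import torch

def lr_tensor_for_step(step, total_steps, base_lr, warmup=0.1, device='cpu'):
  # Linear warmup followed by linear decay, returned as a 1-element float tensor,
  # matching how the train_step call in the diff above takes the learning rate.
  progress = step / max(1, total_steps)
  scale = progress / warmup if progress < warmup else max(0.0, (1.0 - progress) / (1.0 - warmup))
  return torch.tensor([base_lr * scale], dtype=torch.float32, device=device)

# Hypothetical usage inside the loop shown above:
#   lr = lr_tensor_for_step(step, total_steps=10000, base_lr=2e-5, device=device)
#   trainer.train_step(batch['input_ids'], batch['type_ids'], batch['position_ids'],
#                      batch['input_mask'], batch['labels'], lr)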

DeBERTa/deberta/disentangled_attention.py

Lines changed: 7 additions & 4 deletions
@@ -77,8 +77,9 @@ def __init__(self, config):
     self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
     self.all_head_size = self.num_attention_heads * self.attention_head_size
     self.in_proj = torch.nn.Linear(config.hidden_size, self.all_head_size*3, bias=False)
-    self.q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
-    self.v_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
+    # Looks like params below are never updated and const, so removing them
+    #self.q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
+    #self.v_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
     self.pos_att_type = [x.strip() for x in getattr(config, 'pos_att_type', 'none').lower().split('|')] # c2p|p2c

     self.relative_attention = getattr(config, 'relative_attention', False)
@@ -148,8 +149,10 @@ def linear(w,b,x):
     k,v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1,3)]
     query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q,k,v]]

-    query_layer += self.transpose_for_scores(self.q_bias.unsqueeze(0).unsqueeze(0))
-    value_layer += self.transpose_for_scores(self.v_bias.unsqueeze(0).unsqueeze(0))
+    q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
+    v_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
+    query_layer += self.transpose_for_scores(q_bias.unsqueeze(0).unsqueeze(0))
+    value_layer += self.transpose_for_scores(v_bias.unsqueeze(0).unsqueeze(0))

     rel_att = None
     # Take the dot product between "query" and "key" to get the raw attention scores.
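
Because q_bias and v_bias are initialized to zeros and, per the commit message, never updated, the two additions above are arithmetically no-ops; recreating the zero tensors inside the forward pass merely keeps the corresponding add nodes in the traced/exported graph. A small self-contained sketch (shapes are hypothetical, not taken from the repo) showing that a broadcast add of a zero bias leaves the value layer unchanged:

import torch

# Hypothetical sizes for illustration only.
all_head_size, num_heads, seq_len, batch = 768, 12, 128, 2
head_size = all_head_size // num_heads
value_layer = torch.randn(batch, num_heads, seq_len, head_size)

# Zero bias split per head and broadcast over batch and sequence adds nothing.
v_bias = torch.zeros(all_head_size).view(1, num_heads, 1, head_size)
assert torch.equal(value_layer + v_bias, value_layer)

Two caveats, stated as assumptions about the surrounding code rather than observed behavior: torch.zeros defaults to a CPU float32 tensor, so if the model runs on a GPU the freshly created biases would need a matching device and dtype (or the additions could simply be dropped); and checkpoints saved before this change may still contain q_bias/v_bias entries, in which case loading them into the modified module would need load_state_dict(..., strict=False) or filtering of those keys.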
