Skip to content

Commit 50dfe61

Browse files
committed
fix
1 parent 598e239 commit 50dfe61

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

torch_to_onnx.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,18 @@ def build_input_data1():
4242

4343
def build_input_data2():
4444
pad_size = config.pad_size
45-
mask = torch.LongTensor([[0]*pad_size]).cuda()
4645
ids = torch.randint(1, 10, (1, pad_size)).cuda()
4746
seq_len = torch.randint(1, 10, (1,)).cuda() # , 不能少
4847
mask = torch.randint(1, 10, (1, pad_size)).cuda()
4948
return [ids, seq_len, mask]
5049

5150

5251
if __name__ == '__main__':
53-
data = build_input_data1()
52+
args = build_input_data2()
5453

5554
input_names = ['ids','seq_len', 'mask']
5655
torch.onnx.export(model,
57-
(data,),
56+
args,
5857
'model.onnx',
5958
export_params = True,
6059
opset_version=11,

0 commit comments

Comments
 (0)