-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathonnx_export.py
More file actions
145 lines (110 loc) · 5.05 KB
/
onnx_export.py
File metadata and controls
145 lines (110 loc) · 5.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import argparse
import torch
import torch.nn.functional as F
from models.bisenet import BiSeNet
class ONNXBiSeNet(torch.nn.Module):
    """Export-friendly wrapper around a trained BiSeNet instance.

    Re-expresses the forward pass using only ops that trace cleanly to
    ONNX: every dynamic-kernel ``avg_pool2d`` in the original model is
    replaced by ``adaptive_avg_pool2d`` to a fixed 1x1 output, which
    exports without shape-dependent attributes. The wrapped model's
    submodules (backbone, ARMs, FFM, heads) are reused unchanged.
    """

    def __init__(self, model):
        super().__init__()
        # Wrapped BiSeNet; only the control flow of forward() is rewritten.
        self.model = model

    def forward(self, x):
        # Delegate to the ONNX-safe re-implementation.
        return self._forward_onnx_safe(x)

    def _forward_onnx_safe(self, x):
        """ONNX-safe forward; returns (main, aux16, aux32) logit maps."""
        in_h, in_w = x.size()[2:]
        spatial, context8, context16 = self._context_path_onnx_safe(x)
        fused = self._ffm_onnx_safe(spatial, context8)
        # Three heads: main output plus two auxiliary supervision heads.
        main_out = self.model.conv_out(fused)
        aux16 = self.model.conv_out16(context8)
        aux32 = self.model.conv_out32(context16)
        # Upsample each head back to the input resolution.
        main_out = F.interpolate(main_out, (in_h, in_w), mode="bilinear", align_corners=True)
        aux16 = F.interpolate(aux16, (in_h, in_w), mode="bilinear", align_corners=True)
        aux32 = F.interpolate(aux32, (in_h, in_w), mode="bilinear", align_corners=True)
        return main_out, aux16, aux32

    def _context_path_onnx_safe(self, x):
        """ONNX-safe context path over the FPN backbone features."""
        feat8, feat16, feat32 = self.model.fpn.backbone(x)
        h8, w8 = feat8.size()[2:]
        h16, w16 = feat16.size()[2:]
        h32, w32 = feat32.size()[2:]
        # Global context vector: 1x1 adaptive pooling traces to a
        # fixed-shape ONNX op, unlike avg_pool2d with a runtime kernel.
        gap = F.adaptive_avg_pool2d(feat32, (1, 1))
        gap = self.model.fpn.conv_avg(gap)
        gap_up = F.interpolate(gap, (h32, w32), mode="nearest")
        # Deepest stage: refine, add global context, upsample to 1/16.
        sum32 = self._arm_onnx_safe(self.model.fpn.arm32, feat32) + gap_up
        up32 = F.interpolate(sum32, (h16, w16), mode="nearest")
        up32 = self.model.fpn.conv_head32(up32)
        # Middle stage: refine, add deeper features, upsample to 1/8.
        sum16 = self._arm_onnx_safe(self.model.fpn.arm16, feat16) + up32
        up16 = F.interpolate(sum16, (h8, w8), mode="nearest")
        up16 = self.model.fpn.conv_head16(up16)
        return feat8, up16, up32

    def _arm_onnx_safe(self, arm_module, x):
        """ONNX-safe Attention Refinement Module: channel-wise gating."""
        refined = arm_module.conv_block(x)
        # 1x1 adaptive pooling replaces the dynamic-kernel avg_pool2d.
        gate = arm_module.attention(F.adaptive_avg_pool2d(refined, (1, 1)))
        return torch.mul(refined, gate)

    def _ffm_onnx_safe(self, fsp, fcp):
        """ONNX-safe Feature Fusion Module for the two input paths."""
        merged = self.model.ffm.conv_block(torch.cat([fsp, fcp], dim=1))
        # Squeeze-and-excite style gate over the fused features.
        gate = F.adaptive_avg_pool2d(merged, (1, 1))
        gate = self.model.ffm.conv1(gate)
        gate = self.model.ffm.relu(gate)
        gate = self.model.ffm.conv2(gate)
        gate = self.model.ffm.sigmoid(gate)
        # Residual connection: gated features added back onto the fusion.
        return torch.mul(merged, gate) + merged
def torch2onnx_export(params):
    """Load a trained BiSeNet checkpoint and export it to ONNX.

    Args:
        params: argparse namespace with ``model`` (backbone name, e.g.
            ``resnet18``) and ``weight`` (path to the ``.pt`` state-dict
            checkpoint).

    The ONNX file is written next to the checkpoint with the ``.pt``
    suffix replaced by ``.onnx``.
    """
    num_classes = 19  # face-parsing label count expected by the checkpoints

    model = BiSeNet(num_classes, backbone_name=params.model)
    # map_location="cpu" so export also works on CPU-only machines when
    # the checkpoint was saved from a CUDA run.
    model.load_state_dict(torch.load(params.weight, map_location="cpu"))
    model.eval()

    # Wrap the model so the traced forward uses only ONNX-safe ops.
    onnx_model = ONNXBiSeNet(model)
    onnx_model.eval()

    onnx_model_path = params.weight.replace(".pt", ".onnx")
    # Gradients are never needed for tracing, so no requires_grad here.
    dummy_input = torch.randn(1, 3, 256, 256)

    # Export the model to ONNX.
    torch.onnx.export(
        onnx_model,
        dummy_input,
        onnx_model_path,
        export_params=True,
        opset_version=19,  # the ONNX version to export to (compatible with ONNX Runtime)
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output', 'output16', 'output32'],
        # Spatial dims are traced as fixed 256x256; only batch is dynamic.
        dynamic_axes={
            'input': {0: 'batch_size'},  # variable length axes
            'output': {0: 'batch_size'},
            'output16': {0: 'batch_size'},
            'output32': {0: 'batch_size'}
        }
    )
def parse_args(argv=None):
    """Parse command-line options for the BiSeNet ONNX export script.

    Args:
        argv: optional list of argument strings; ``None`` (the default)
            falls back to ``sys.argv[1:]``. Accepting an explicit list
            keeps the function unit-testable without touching sys.argv.

    Returns:
        argparse.Namespace with ``model`` and ``weight`` attributes.
    """
    # Description corrected: this script exports to ONNX, it does not run inference.
    parser = argparse.ArgumentParser(description="Face parsing ONNX export")
    parser.add_argument("--model", type=str, default="resnet18", help="model name, i.e resnet18, resnet34")
    parser.add_argument(
        "--weight",
        type=str,
        default="./weights/resnet18.pt",
        help="path to trained model, i.e resnet18/34"
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse CLI options and run the export.
    torch2onnx_export(params=parse_args())
# python onnx_export.py --model resnet18 --weight ./weights/resnet18.pt
# python onnx_export.py --model resnet34 --weight ./weights/resnet34.pt
# python onnx_export.py --model efficientnet_b0 --weight ./weights/efficientnet_b0.pt
# python onnx_export.py --model efficientnet_b1 --weight ./weights/efficientnet_b1.pt
# python onnx_export.py --model efficientnet_b2 --weight ./weights/efficientnet_b2.pt