Skip to content

Commit c89eb58

Browse files
lzhangzzirexyc
authored andcommitted
wip
1 parent 7f85148 commit c89eb58

File tree

2 files changed

+148
-1
lines changed

2 files changed

+148
-1
lines changed

demo/python/to_triton_model.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import argparse
3+
import json
4+
import os.path as osp
5+
6+
import tritonclient.grpc.model_config_pb2 as pb
7+
from google.protobuf import text_format
8+
9+
10+
def parse_args():
11+
parser = argparse.ArgumentParser()
12+
parser.add_argument('model_path', type=str)
13+
parser.add_argument('--name', type=str)
14+
parser.add_argument('--nocopy', type=bool)
15+
return parser.parse_args()
16+
17+
18+
IMAGE_INPUT = [
19+
dict(
20+
name='ori_img',
21+
dtype=pb.TYPE_UINT8,
22+
format=pb.ModelInput.FORMAT_NHWC,
23+
dims=[-1, -1, 3]),
24+
dict(name='pix_fmt', dtype=pb.TYPE_INT32, dims=[1], optional=True)
25+
]
26+
27+
TASK_OUTPUT = dict(
28+
Preprocess=[
29+
dict(name='img', dtype=pb.TYPE_FP32, dims=[3, -1, -1]),
30+
dict(name='img_metas', dtype=pb.TYPE_STRING, dims=[1])
31+
],
32+
Classifier=[
33+
dict(name='scores', dtype=pb.TYPE_FP32, dims=[-1, 1]),
34+
dict(name='label_ids', dtype=pb.TYPE_FP32, dims=[-1, 1])
35+
],
36+
Detector=[dict(name='dets', dtype=pb.TYPE_FP32, dims=[-1, 1])],
37+
Segmentor=[
38+
dict(name='mask', dtype=pb.TYPE_INT32, dims=[-1, -1]),
39+
dict(name='score', dtype=pb.TYPE_FP32, dims=[-1, -1, -1])
40+
],
41+
Restorer=[
42+
dict(name='output', dtype=pb.TYPE_FP32, dims=[-1, -1, 3])
43+
],
44+
TextDetector=[],
45+
TextRecognizer=[],
46+
PoseDetector=[],
47+
RotatedDetector=[],
48+
TextOCR=[],
49+
DetPose=[])
50+
51+
52+
def add_input(model_config, params):
53+
p = model_config.input.add()
54+
p.name = params['name']
55+
p.data_type = params['dtype']
56+
p.dims.extend(params['dims'])
57+
if 'format' in params:
58+
p.format = params['format']
59+
if 'optional' in params:
60+
p.optional = params['optional']
61+
62+
63+
def add_output(model_config, params):
64+
p = model_config.output.add()
65+
p.name = params['name']
66+
p.data_type = params['dtype']
67+
p.dims.extend(params['dims'])
68+
69+
70+
def serialize_model_config(model_config):
71+
return text_format.MessageToString(
72+
model_config,
73+
use_short_repeated_primitives=True,
74+
use_index_order=True,
75+
print_unknown_fields=True)
76+
77+
78+
def create_model_config(name, task, backend=None, platform=None):
79+
model_config = pb.ModelConfig()
80+
if backend:
81+
model_config.backend = backend
82+
if platform:
83+
model_config.platform = platform
84+
model_config.name = name
85+
model_config.max_batch_size = 0
86+
87+
for input in IMAGE_INPUT:
88+
add_input(model_config, input)
89+
for output in TASK_OUTPUT[task]:
90+
add_output(model_config, output)
91+
return model_config
92+
93+
94+
def create_preprocess_model():
95+
pass
96+
97+
98+
def get_onnx_io_names(detail_info):
99+
onnx_config = detail_info['onnx_config']
100+
return onnx_config['input_names'], onnx_config['output_names']
101+
102+
103+
def create_inference_model(deploy_info, pipeline_info, detail_info):
104+
if 'pipeline' in pipeline_info:
105+
# old-style pipeline specification
106+
pipeline = pipeline_info['pipeline']['tasks']
107+
else:
108+
pipeline = pipeline_info['tasks']
109+
110+
for task_cfg in pipeline:
111+
if task_cfg['module'] == 'Net':
112+
input_names, output_names = get_onnx_io_names(detail_info)
113+
114+
115+
def create_postprocess_model():
116+
pass
117+
118+
119+
def create_pipeline_model():
120+
pass
121+
122+
123+
def create_ensemble_model(deploy_cfg, pipeline_cfg):
124+
inference_model_config = create_inference_model(deploy_cfg, pipeline_cfg)
125+
preprocess_model_config = create_preprocess_model()
126+
postprocess_model_config = create_postprocess_model()
127+
pipeline_model_config = create_pipeline_model()
128+
129+
130+
def main():
131+
args = parse_args()
132+
model_path = args.model_path
133+
if not osp.isdir(model_path):
134+
model_path = osp.split(model_path)[-2]
135+
if osp.isdir(model_path):
136+
with open(osp.join(model_path, 'deploy.json'), 'r') as f:
137+
deploy_cfg = json.load(f)
138+
with open(osp.join(model_path, 'pipeline.json'), 'r') as f:
139+
pipeline_cfg = json.load(f)
140+
task = deploy_cfg['task']
141+
model_config = create_model_config('model', task, 'onnxruntime')
142+
data = serialize_model_config(model_config)
143+
print(data)
144+
145+
146+
if __name__ == '__main__':
147+
main()

demo/python/triton_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def main():
105105
Classifier=(('scores', 'labels'), vis_cls),
106106
Detector=(('bboxes', 'labels'), vis_det),
107107
TextOCR=(('dets', 'text', 'text_score'), vis_ocr),
108-
Restorer=(('output',), lambda _, hires: hires),
108+
Restorer=(('output', ), lambda _, hires: hires),
109109
Segmentor=(('mask', 'score'), vis_seg),
110110
RotatedDetector=(('bboxes', 'labels'), None),
111111
DetPose=(('bboxes', 'keypoints'), vis_pose))

0 commit comments

Comments
 (0)