|
| 1 | +# Copyright (c) OpenMMLab. All rights reserved. |
| 2 | +import argparse |
| 3 | +import json |
| 4 | +import os.path as osp |
| 5 | + |
| 6 | +import tritonclient.grpc.model_config_pb2 as pb |
| 7 | +from google.protobuf import text_format |
| 8 | + |
| 9 | + |
def parse_args(argv=None):
    """Parse command-line arguments for the config generator.

    Args:
        argv (list[str] | None): argument strings to parse; ``None`` means
            use ``sys.argv[1:]`` (the argparse default).

    Returns:
        argparse.Namespace: with ``model_path``, ``name`` and ``nocopy``.
    """
    parser = argparse.ArgumentParser()
    # Path to the MMDeploy model directory (or a file inside it).
    parser.add_argument('model_path', type=str)
    # Optional override for the generated Triton model name.
    parser.add_argument('--name', type=str)
    # FIX: was `type=bool`, which treats ANY non-empty string (including
    # "False") as True; a presence flag is what was intended.
    parser.add_argument('--nocopy', action='store_true')
    return parser.parse_args(argv)
| 16 | + |
| 17 | + |
# Input specs shared by every task's model config: the raw image plus an
# optional pixel-format selector.  `dims` of -1 mark variable-size axes.
IMAGE_INPUT = [
    dict(
        name='ori_img',
        dtype=pb.TYPE_UINT8,
        format=pb.ModelInput.FORMAT_NHWC,
        # height x width variable, channel count fixed at 3
        dims=[-1, -1, 3]),
    # NOTE(review): presumably selects the pixel format of 'ori_img' —
    # confirm against the serving-side preprocessing code.
    dict(name='pix_fmt', dtype=pb.TYPE_INT32, dims=[1], optional=True)
]
| 26 | + |
# Output tensor specs per MMDeploy task type, keyed by the task name stored
# in deploy.json.  Tasks mapped to an empty list have no outputs defined yet
# (placeholders — see the stub builders below).
TASK_OUTPUT = dict(
    Preprocess=[
        dict(name='img', dtype=pb.TYPE_FP32, dims=[3, -1, -1]),
        dict(name='img_metas', dtype=pb.TYPE_STRING, dims=[1])
    ],
    Classifier=[
        dict(name='scores', dtype=pb.TYPE_FP32, dims=[-1, 1]),
        # NOTE(review): label ids declared as FP32 rather than an int type —
        # looks intentional for Triton compatibility, but worth confirming.
        dict(name='label_ids', dtype=pb.TYPE_FP32, dims=[-1, 1])
    ],
    Detector=[dict(name='dets', dtype=pb.TYPE_FP32, dims=[-1, 1])],
    Segmentor=[
        dict(name='mask', dtype=pb.TYPE_INT32, dims=[-1, -1]),
        dict(name='score', dtype=pb.TYPE_FP32, dims=[-1, -1, -1])
    ],
    Restorer=[
        dict(name='output', dtype=pb.TYPE_FP32, dims=[-1, -1, 3])
    ],
    TextDetector=[],
    TextRecognizer=[],
    PoseDetector=[],
    RotatedDetector=[],
    TextOCR=[],
    DetPose=[])
| 50 | + |
| 51 | + |
def add_input(model_config, params):
    """Append one input tensor spec to a ModelConfig message.

    Args:
        model_config: a ``pb.ModelConfig`` protobuf message, mutated in place.
        params (dict): spec with required keys 'name', 'dtype', 'dims' and
            optional keys 'format' and 'optional' (see ``IMAGE_INPUT``).
    """
    entry = model_config.input.add()
    entry.name = params['name']
    entry.data_type = params['dtype']
    entry.dims.extend(params['dims'])
    # Only some specs carry these optional fields.
    if 'format' in params:
        entry.format = params['format']
    if 'optional' in params:
        entry.optional = params['optional']
| 61 | + |
| 62 | + |
def add_output(model_config, params):
    """Append one output tensor spec to a ModelConfig message.

    Args:
        model_config: a ``pb.ModelConfig`` protobuf message, mutated in place.
        params (dict): spec with keys 'name', 'dtype' and 'dims'
            (see ``TASK_OUTPUT``).
    """
    entry = model_config.output.add()
    entry.name = params['name']
    entry.data_type = params['dtype']
    entry.dims.extend(params['dims'])
| 68 | + |
| 69 | + |
def serialize_model_config(model_config):
    """Render a ModelConfig message in Triton's text-protobuf format.

    Args:
        model_config: a ``pb.ModelConfig`` protobuf message.

    Returns:
        str: the text-format serialization (config.pbtxt content).
    """
    render_options = dict(
        use_short_repeated_primitives=True,
        use_index_order=True,
        print_unknown_fields=True,
    )
    return text_format.MessageToString(model_config, **render_options)
| 76 | + |
| 77 | + |
def create_model_config(name, task, backend=None, platform=None):
    """Build a Triton ModelConfig for one MMDeploy task.

    Args:
        name (str): model name written into the config.
        task (str): task key; must be one of the keys of ``TASK_OUTPUT``.
        backend (str | None): Triton backend (e.g. 'onnxruntime'), if any.
        platform (str | None): Triton platform string, if any.

    Returns:
        pb.ModelConfig: config with the shared image inputs and the
        task-specific outputs attached.

    Raises:
        KeyError: if ``task`` is not a key of ``TASK_OUTPUT``.
    """
    model_config = pb.ModelConfig()
    if backend:
        model_config.backend = backend
    if platform:
        model_config.platform = platform
    model_config.name = name
    # 0 disables Triton's implicit batch dimension; dims are full shapes.
    model_config.max_batch_size = 0

    # FIX: loop variable was named `input`, shadowing the builtin.
    for input_params in IMAGE_INPUT:
        add_input(model_config, input_params)
    for output_params in TASK_OUTPUT[task]:
        add_output(model_config, output_params)
    return model_config
| 92 | + |
| 93 | + |
def create_preprocess_model():
    """Build the preprocessing model config.

    TODO: not implemented yet; currently returns None.
    """
    pass
| 96 | + |
| 97 | + |
def get_onnx_io_names(detail_info):
    """Look up the ONNX graph I/O tensor names from a detail-info dict.

    Args:
        detail_info (dict): parsed detail info containing an 'onnx_config'
            entry with 'input_names' and 'output_names'.

    Returns:
        tuple: ``(input_names, output_names)``.
    """
    onnx_cfg = detail_info['onnx_config']
    input_names = onnx_cfg['input_names']
    output_names = onnx_cfg['output_names']
    return input_names, output_names
| 101 | + |
| 102 | + |
def create_inference_model(deploy_info, pipeline_info, detail_info=None):
    """Collect ONNX I/O names for every 'Net' task in the pipeline.

    FIX: ``create_ensemble_model`` calls this with only two arguments, which
    raised TypeError; ``detail_info`` now defaults to None and is only
    consulted when a 'Net' task is present.  The original also computed the
    I/O names and discarded them; they are now collected and returned.

    Args:
        deploy_info (dict): parsed deploy.json (currently unused here).
        pipeline_info (dict): parsed pipeline.json; either new-style with a
            top-level 'tasks' list or old-style with 'pipeline' -> 'tasks'.
        detail_info (dict | None): parsed detail info with 'onnx_config'.

    Returns:
        list[tuple]: one ``(input_names, output_names)`` pair per 'Net' task
        for which ``detail_info`` was available.
    """
    if 'pipeline' in pipeline_info:
        # old-style pipeline specification nests tasks one level deeper
        tasks = pipeline_info['pipeline']['tasks']
    else:
        tasks = pipeline_info['tasks']

    io_names = []
    for task_cfg in tasks:
        if task_cfg['module'] == 'Net' and detail_info is not None:
            io_names.append(get_onnx_io_names(detail_info))
    return io_names
| 113 | + |
| 114 | + |
def create_postprocess_model():
    """Build the postprocessing model config.

    TODO: not implemented yet; currently returns None.
    """
    pass
| 117 | + |
| 118 | + |
def create_pipeline_model():
    """Build the pipeline (ensemble glue) model config.

    TODO: not implemented yet; currently returns None.
    """
    pass
| 121 | + |
| 122 | + |
def create_ensemble_model(deploy_cfg, pipeline_cfg, detail_cfg=None):
    """Assemble the set of model configs that make up a Triton ensemble.

    FIX: ``create_inference_model`` was called without its third argument,
    which raised TypeError; it is now forwarded explicitly.  The four built
    configs were also computed and dropped; they are returned so callers can
    actually use them.  NOTE(review): the preprocess/postprocess/pipeline
    builders are still stubs that return None.

    Args:
        deploy_cfg (dict): parsed deploy.json.
        pipeline_cfg (dict): parsed pipeline.json.
        detail_cfg (dict | None): parsed detail info for the inference model.

    Returns:
        dict: the four configs keyed by 'inference', 'preprocess',
        'postprocess' and 'pipeline'.
    """
    inference_model_config = create_inference_model(deploy_cfg, pipeline_cfg,
                                                    detail_cfg)
    preprocess_model_config = create_preprocess_model()
    postprocess_model_config = create_postprocess_model()
    pipeline_model_config = create_pipeline_model()
    return dict(
        inference=inference_model_config,
        preprocess=preprocess_model_config,
        postprocess=postprocess_model_config,
        pipeline=pipeline_model_config)
| 128 | + |
| 129 | + |
def main():
    """Read deploy.json/pipeline.json from the model directory and print the
    Triton model config in text-protobuf form.
    """
    args = parse_args()
    model_path = args.model_path
    if not osp.isdir(model_path):
        # A file inside the model directory was given; use its parent.
        # FIX: was `osp.split(model_path)[-2]` — an obscure spelling of [0].
        model_path = osp.dirname(model_path)
    if not osp.isdir(model_path):
        # FIX: previously fell through with deploy_cfg undefined -> NameError.
        raise SystemExit(
            f'{args.model_path} is not an MMDeploy model directory')
    with open(osp.join(model_path, 'deploy.json'), 'r') as f:
        deploy_cfg = json.load(f)
    # pipeline.json is loaded for parity with deploy.json; it is not yet
    # consumed here (the ensemble builders are still stubs).
    with open(osp.join(model_path, 'pipeline.json'), 'r') as f:
        pipeline_cfg = json.load(f)
    task = deploy_cfg['task']
    # FIX: the --name option was parsed but silently ignored.
    name = args.name or 'model'
    model_config = create_model_config(name, task, 'onnxruntime')
    print(serialize_model_config(model_config))
| 144 | + |
| 145 | + |
# Run only when executed as a script, not on import.
if __name__ == '__main__':
    main()
0 commit comments