-
Notifications
You must be signed in to change notification settings - Fork 178
[one-cmds] Allow to specialize shape of input/output tensors during ONNX-Circle conversion #13638
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
3c6e783
2991c62
e5f6197
dc32280
508689d
39157fd
e66fe0f
11606fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,7 @@ import sys | |
| import tempfile | ||
| import onnx | ||
| import onnx_tf | ||
| from onnx.tools import update_model_dims | ||
|
|
||
| # ONNX legalizer is an optional feature | ||
| # It enables conversion of some operations, but in experimental phase for now | ||
|
|
@@ -187,6 +188,26 @@ def _get_parser(): | |
| action='store_true', | ||
| help='Experimental disable BatchMatMul unfold') | ||
|
|
||
| # set static input shape | ||
| parser.add_argument( | ||
| '--input_shapes', | ||
| type=str, | ||
| help= | ||
| 'Set static shape for input tensors in comma-separated list format, like \'[1,2,3]\'.' | ||
| 'If the model has multiple inputs, tensor names should be provided as well.' | ||
| 'In such a case, the argument should be in the following format: \'a[1,2,3],b[4,5,6],c[7,8]\'.' | ||
| ) | ||
|
|
||
| # set static output shape | ||
| parser.add_argument( | ||
| '--output_shapes', | ||
| type=str, | ||
| help= | ||
| 'Set static shape for output tensors in comma-separated list format, like \'[1,2,3]\'.' | ||
| 'If the model has multiple outputs, tensor names should be provided as well.' | ||
| 'In such a case, the argument should be in the following format: \'a[1,2,3],b[4,5,6],c[7,8]\'.' | ||
| ) | ||
|
|
||
| return parser | ||
|
|
||
|
|
||
|
|
@@ -257,6 +278,52 @@ def _check_ext(args): | |
| return None | ||
|
|
||
|
|
||
| def _extract_origin_tensor_shapes(tensors): | ||
| shapes_map = {} | ||
| for tensor in tensors: | ||
| shapes_map[tensor.name] = [] | ||
| for dim_proto in tensor.type.tensor_type.shape.dim: | ||
| if dim_proto.HasField("dim_value"): | ||
| shapes_map[tensor.name].append(dim_proto.dim_value) | ||
| elif dim_proto.HasField("dim_param"): | ||
| shapes_map[tensor.name].append(dim_proto.dim_param) | ||
| else: | ||
| shapes_map[tensor.name].append(-1) # dynamic | ||
| return shapes_map | ||
|
|
||
|
|
||
| def _parse_shapes(shape_str, origin_tensor_shapes_map): | ||
| user_shapes_map = {} | ||
| for idx, single_shape in enumerate(shape_str.split(']')): | ||
| if single_shape: | ||
| tensor_name, single_shape = single_shape.split('[') | ||
| if not tensor_name: | ||
| # using tensor name model allowed only for single input models | ||
| if idx == 0 and len(origin_tensor_shapes_map) == 1: | ||
| tensor_name = next(iter(origin_tensor_shapes_map.keys())) | ||
| else: | ||
| raise ValueError( | ||
| 'You must provide tenors name for the model with multiple inputs/outputs' | ||
| ) | ||
| tensor_name = tensor_name.replace(',', '') | ||
| user_shapes_map[tensor_name] = [] | ||
| for dim in single_shape.split(','): | ||
| user_shapes_map[tensor_name].append(int(dim)) | ||
|
|
||
| for tensor_name in user_shapes_map.keys(): | ||
| if tensor_name not in origin_tensor_shapes_map: | ||
| raise ValueError( | ||
| f'Tensor with name={tensor_name} do NOT match any input/output of the model' | ||
| ) | ||
| if len(user_shapes_map[tensor_name]) != len( | ||
| origin_tensor_shapes_map[tensor_name]): | ||
| raise ValueError( | ||
| f'Rank of provided shape must be compatible with the origin tensor={tensor_name} from the model' | ||
| ) | ||
|
|
||
| return user_shapes_map | ||
|
|
||
|
|
||
| def _convert(args): | ||
| _apply_verbosity(args.verbose) | ||
|
|
||
|
|
@@ -272,6 +339,22 @@ def _convert(args): | |
| # convert onnx to tf saved model | ||
| onnx_model = onnx.load(getattr(args, 'input_path')) | ||
| _sanitize_io_names(onnx_model) | ||
| input_shape_provided = oneutils.is_valid_attr(args, 'input_shapes') | ||
| output_shape_provided = oneutils.is_valid_attr(args, 'output_shapes') | ||
| if input_shape_provided or output_shape_provided: # any shape arg provided | ||
| input_shapes_map = _extract_origin_tensor_shapes(onnx_model.graph.input) | ||
| if input_shape_provided: | ||
| input_shapes_map = _parse_shapes( | ||
| getattr(args, 'input_shapes'), input_shapes_map) | ||
| delattr(args, 'input_shapes' | ||
| ) # avoid argument colision with tf2tflite conversion step | ||
| output_shapes_map = _extract_origin_tensor_shapes(onnx_model.graph.output) | ||
| if output_shape_provided: | ||
| output_shapes_map = _parse_shapes( | ||
| getattr(args, 'output_shapes'), output_shapes_map) | ||
| onnx_model = update_model_dims.update_inputs_outputs_dims( | ||
| onnx_model, input_shapes_map, output_shapes_map) | ||
| onnx_model = onnx.shape_inference.infer_shapes(onnx_model) | ||
|
||
| if _onnx_legalizer_enabled: | ||
| options = onnx_legalizer.LegalizeOptions | ||
| options.unroll_rnn = oneutils.is_valid_attr(args, 'unroll_rnn') | ||
|
|
||

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.