Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/circle-operator/driver/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ int entry(int argc, char **argv)
"circle-operator allows users to retrieve operator information from a Circle model file"};
arser.add_argument("--name").nargs(0).help("Dump operators name in circle file");
arser.add_argument("--code").nargs(0).help("Dump operators code in circle file");
arser.add_argument("--shapes").nargs(0).help("Dump shapes");
arser.add_argument("--output_path").help("Save output to file (default output is console)");
arser.add_argument("circle").help("Circle file to dump");

Expand All @@ -59,6 +60,7 @@ int entry(int argc, char **argv)
cirops::DumpOption option;
option.names = arser["--name"];
option.codes = arser["--code"];
option.shapes = arser["--shapes"];

std::ofstream oFstream;
std::ostream *oStream = &std::cout;
Expand Down
45 changes: 34 additions & 11 deletions compiler/circle-operator/src/Dump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,29 @@
namespace
{

// TODO handle multiple outputs case
const circle::Tensor *get_output_tensor(mio::circle::Reader &reader, const circle::Operator *op)
{
  // Resolve the operator's first output index into the subgraph tensor table.
  const auto first_output_index = reader.outputs(op).at(0);
  return reader.tensors()->Get(first_output_index);
}

/**
 * @brief Print a tensor shape to the stream as a bracketed, comma-separated
 *        list, e.g. "[1,2,3]".
 *
 * @param os    stream to write to
 * @param shape flatbuffers vector of dimensions; may be null when the model
 *              stores no shape for the tensor — printed as "[]"
 */
void dump_shape(std::ostream &os, const ::flatbuffers::Vector<int32_t> *shape)
{
  os << "[";
  // A FlatBuffers vector accessor returns nullptr when the field is absent,
  // so guard before iterating.
  if (shape != nullptr)
  {
    for (uint32_t i = 0; i < shape->size(); ++i)
    {
      if (i > 0)
      {
        os << ",";
      }
      os << shape->Get(i);
    }
  }
  os << "]";
}

void dump_ops(std::ostream &os, mio::circle::Reader &reader, const cirops::DumpOption &option)
{
auto ops = reader.operators();
Expand All @@ -44,19 +67,19 @@ void dump_ops(std::ostream &os, mio::circle::Reader &reader, const cirops::DumpO
const auto op_name = reader.opcode_name(op);
os << op_name;
}

if (option.names)
{
// TODO multiple outputs?
const auto tensors = reader.tensors();
const auto output_tensors = reader.outputs(op);
const auto output = output_tensors.at(0);
const auto tensor = tensors->Get(output);
const std::string name = mio::circle::tensor_name(tensor);
if (option.codes)
{
os << ",";
}
os << name;
const std::string name = mio::circle::tensor_name(get_output_tensor(reader, op));
os << (option.codes ? "," : "") << name;
}

if (option.shapes)
{
os << (option.codes || option.names ? "," : "");
auto const out_tensor = get_output_tensor(reader, op);
dump_shape(os, (nullptr == out_tensor->shape_signature()) ? out_tensor->shape()
: out_tensor->shape());
}
os << std::endl;
}
Expand Down
1 change: 1 addition & 0 deletions compiler/circle-operator/src/Dump.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ struct DumpOption
bool names = false;
bool codes = false;
bool all_graphs = false;
bool shapes = false;
};

class DumpOperators
Expand Down
83 changes: 83 additions & 0 deletions compiler/one-cmds/one-import-onnx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import sys
import tempfile
import onnx
import onnx_tf
from onnx.tools import update_model_dims

# ONNX legalizer is an optional feature
# It enables conversion of some operations, but in experimental phase for now
Expand Down Expand Up @@ -187,6 +188,26 @@ def _get_parser():
action='store_true',
help='Experimental disable BatchMatMul unfold')

# set static input shape
parser.add_argument(
'--input_shapes',
type=str,
help=
'Set static shape for input tensors in comma-separated list format, like \'[1,2,3]\'.'
'If the model has multiple inputs, tensor names should be provided as well.'
'In such a case, the argument should be in the following format: \'a[1,2,3],b[4,5,6],c[7,8]\'.'
)

# set static output shape
parser.add_argument(
'--output_shapes',
type=str,
help=
'Set static shape for output tensors in comma-separated list format, like \'[1,2,3]\'.'
'If the model has multiple inputs, tensor names should be provided as well.'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
'If the model has multiple inputs, tensor names should be provided as well.'
'If the model has multiple outputs, tensor names should be provided as well.'

'In such a case, the argument should be in the following format: \'a[1,2,3],b[4,5,6],c[7,8]\'.'
)

return parser


Expand Down Expand Up @@ -257,6 +278,52 @@ def _check_ext(args):
return None


def _extract_origin_tensor_shapes(tensors):
shapes_map = {}
for tensor in tensors:
shapes_map[tensor.name] = []
for dim_proto in tensor.type.tensor_type.shape.dim:
if dim_proto.HasField("dim_value"):
shapes_map[tensor.name].append(dim_proto.dim_value)
elif dim_proto.HasField("dim_param"):
shapes_map[tensor.name].append(dim_proto.dim_param)
else:
shapes_map[tensor.name].append(-1) # dynamic
return shapes_map


def _parse_shapes(shape_str, origin_tensor_shapes_map):
user_shapes_map = {}
for idx, single_shape in enumerate(shape_str.split(']')):
if single_shape:
tensor_name, single_shape = single_shape.split('[')
if not tensor_name:
# using tensor name model allowed only for single input models
if idx == 0 and len(origin_tensor_shapes_map) == 1:
tensor_name = next(iter(origin_tensor_shapes_map.keys()))
else:
raise ValueError(
'You must provide tenors name for the model with multiple inputs/outputs'
)
tensor_name = tensor_name.replace(',', '')
user_shapes_map[tensor_name] = []
for dim in single_shape.split(','):
user_shapes_map[tensor_name].append(int(dim))

for tensor_name in user_shapes_map.keys():
if tensor_name not in origin_tensor_shapes_map:
raise ValueError(
f'Tensor with name={tensor_name} do NOT match any input/output of the model'
)
if len(user_shapes_map[tensor_name]) != len(
origin_tensor_shapes_map[tensor_name]):
raise ValueError(
f'Rank of provided shape must be compatible with the origin tensor={tensor_name} from the model'
)

return user_shapes_map


def _convert(args):
_apply_verbosity(args.verbose)

Expand All @@ -272,6 +339,22 @@ def _convert(args):
# convert onnx to tf saved model
onnx_model = onnx.load(getattr(args, 'input_path'))
_sanitize_io_names(onnx_model)
input_shape_provided = oneutils.is_valid_attr(args, 'input_shapes')
output_shape_provided = oneutils.is_valid_attr(args, 'output_shapes')
if input_shape_provided or output_shape_provided: # any shape arg provided
input_shapes_map = _extract_origin_tensor_shapes(onnx_model.graph.input)
if input_shape_provided:
input_shapes_map = _parse_shapes(
getattr(args, 'input_shapes'), input_shapes_map)
delattr(args, 'input_shapes'
) # avoid argument colision with tf2tflite conversion step
output_shapes_map = _extract_origin_tensor_shapes(onnx_model.graph.output)
if output_shape_provided:
output_shapes_map = _parse_shapes(
getattr(args, 'output_shapes'), output_shapes_map)
onnx_model = update_model_dims.update_inputs_outputs_dims(
onnx_model, input_shapes_map, output_shapes_map)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we use strict_mode and handle errors?

https://onnx.ai/onnx/api/shape_inference.html

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good observation, but IMHO infer_shapes from onnx is something optional here. Even if something goes wrong we still have shape inference provided by ONE itself. My assumption was to treat it as an improvement for some edge cases ;-)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is about not just dying when something goes wrong, but recognizing and dealing with the situation. If someone is running a multi-step toolchain and it crashes, make sure that the person using the toolchain knows why it crashed and what to do about it.

Remember that not everyone using the toolchain is a brilliant programmer like you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about enabling strict_mode, catching exception and printing warning if shapes calculated by onnx lib were not applied?

If someone is running a multi-step toolchain and it crashes, make sure that the person using the toolchain knows why it crashed and what to do about it.

I am also OK with stopping conversion if onnx shape inference fails. You are right that I haven't a big experience with the whole related toolchain ;)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about enabling strict_mode, catching exception and printing warning if shapes calculated by onnx lib were not applied?

Is it worth continuing execution of the toolchain after this? Is there any chance that the toolchain execution will end successfully?

If not, wouldn't it be better to display an error message explaining the problem and abort execution?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completely yes. onnx.shape_inference.infer_shapes is not needed for models dedicated to be supported by this PR.

My proposition is to remove it now and add separately if really needed ;)

Copy link
Member

@lemmaa lemmaa Aug 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about enabling strict_mode, catching exception and printing warning if shapes calculated by onnx lib were not applied?

Is it worth continuing execution of the toolchain after this? Is there any chance that the toolchain execution will end successfully?

To clarify the understanding,

I agree with using strict_mode. However, when catching an exception, it is appropriate to terminate with an error rather than proceeding with a warning. Also, it is necessary to clearly inform the user of the cause of the error.

If you agree, please proceed carefully to avoid regression. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, let's proceed it in this way ;)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lemmaa Is it ok for you now? Can I start creating PRs > ;-)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@seanshpark Can I start introduce this feature? Is the current design acceptable for you ;-) ?

if _onnx_legalizer_enabled:
options = onnx_legalizer.LegalizeOptions
options.unroll_rnn = oneutils.is_valid_attr(args, 'unroll_rnn')
Expand Down
Loading