
tf2onnx failed: google.protobuf.message.DecodeError: Error parsing message with type 'onnx.AttributeProto' #2408

@wushandinghua


Describe the bug

I have the inference function and params of a JAX model, and I converted them to a TensorFlow SavedModel. When I convert that SavedModel to an ONNX model with tf2onnx, the conversion fails with the error below. How can I solve it?
tf2onnx output:

<frozen runpy>:128: RuntimeWarning: 'tf2onnx.convert' found in sys.modules after import of package 'tf2onnx', but prior to execution of 'tf2onnx.convert'; this may result in unpredictable behaviour
2025-08-27 18:19:57,445 - WARNING - tf2onnx.tf_loader: '--tag' not specified for saved_model. Using --tag serve
2025-08-27 18:20:01,172 - INFO - tf2onnx.tf_loader: Signatures found in model: [serving_default].
2025-08-27 18:20:01,172 - WARNING - tf2onnx.tf_loader: '--signature_def' not specified, using first signature: serving_default
2025-08-27 18:20:01,172 - INFO - tf2onnx.tf_loader: Output names: ['output_0']
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1756290001.225592    5292 devices.cc:76] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0 (Note: TensorFlow was not compiled with CUDA or ROCm support)
I0000 00:00:1756290001.225854    5292 single_machine.cc:376] Starting new session
2025-08-27 18:20:11,084 - INFO - tf2onnx: inputs: ['inputs_0:0', 'inputs_1:0', 'inputs_2:0', 'inputs_3:0', 'inputs_4:0', 'inputs_5:0', 'inputs_6:0']
2025-08-27 18:20:11,084 - INFO - tf2onnx: outputs: ['Identity:0']
2025-08-27 18:20:14,362 - INFO - tf2onnx.tfonnx: Using tensorflow=2.20.0, onnx=1.17.0, tf2onnx=1.16.1/15c810
2025-08-27 18:20:14,363 - INFO - tf2onnx.tfonnx: Using opset <onnx, 21>
2025-08-27 18:20:20,789 - ERROR - tf2onnx.tf_utils: pass1 convert failed for name: "unknown_43"
op: "Const"
attr {
  key: "value"
  value {
    tensor {
      dtype: DT_HALF
      tensor_shape {
        dim {
          size: 18
        }
        dim {
          size: 2
        }
        dim {
          size: 2048
        }
        dim {
          size: 16384
        }
      }
    }
  }
}
attr {
  key: "dtype"
  value {
    type: DT_HALF
  }
}
, ex=Error parsing message with type 'onnx.AttributeProto'
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/convert.py", line 714, in <module>
    main()
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/convert.py", line 273, in main
    model_proto, _ = _convert_common(
                     ^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/convert.py", line 168, in _convert_common
    g = process_tf_graph(tf_graph, const_node_values=const_node_values,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tfonnx.py", line 459, in process_tf_graph
    main_g, subgraphs = graphs_from_tf(tf_graph, input_names, output_names, shape_override, const_node_values,
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tfonnx.py", line 474, in graphs_from_tf
    ordered_func = resolve_functions(tf_graph)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tf_loader.py", line 784, in resolve_functions
    _, _, _, _, _, functions = tflist_to_onnx(tf_graph, {})
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tf_utils.py", line 463, in tflist_to_onnx
    onnx_node = utils.make_onnx_node_with_attr(node_type, input_names, output_names, name=node.name, **attr)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/utils.py", line 207, in make_onnx_node_with_attr
    onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **valid_attrs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/onnx/helper.py", line 175, in make_node
    node.attribute.extend(
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/onnx/helper.py", line 175, in <genexpr>
    node.attribute.extend(
                         ^
google.protobuf.message.DecodeError: Error parsing message with type 'onnx.AttributeProto'
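One plausible explanation, not confirmed here: the constant `unknown_43` in the log is a float16 (`DT_HALF`, 2 bytes per element) tensor of shape [18, 2, 2048, 16384], which is about 2.25 GiB, while protobuf cannot serialize or parse a single message of 2 GiB or more. The traceback shows the failure inside `tflist_to_onnx` while the ONNX node is still being built, presumably before `--large_model` gets a chance to move tensors to external data. The sketch below only reproduces the size arithmetic; `PROTOBUF_MAX_BYTES` and `tensor_nbytes` are hypothetical names, and reading the leading 18 as a stack of per-layer weights is an assumption.

```python
import math

# Assumption: protobuf rejects any single serialized message larger than 2**31 - 1 bytes.
PROTOBUF_MAX_BYTES = 2**31 - 1


def tensor_nbytes(shape, itemsize):
    """Bytes needed to store a dense tensor with the given shape and element size."""
    return math.prod(shape) * itemsize


# Shape and dtype copied from the failing Const node "unknown_43" (DT_HALF = 2 bytes).
shape = (18, 2, 2048, 16384)
total = tensor_nbytes(shape, itemsize=2)
# One slice along the leading axis (assumed to be one of 18 stacked layers).
per_slice = tensor_nbytes(shape[1:], itemsize=2)

print(f"elements:   {math.prod(shape):,}")  # 1,207,959,552
print(f"bytes:      {total:,} ({total / 2**30:.2f} GiB)")  # 2,415,919,104 (2.25 GiB)
print(f"over limit: {total > PROTOBUF_MAX_BYTES}")  # True
print(f"per slice:  {per_slice:,} bytes ({per_slice / 2**20:.0f} MiB)")  # 134,217,728 bytes (128 MiB)
```

Under the same assumption, a single 128 MiB slice would fit comfortably, which suggests the limit is hit because the weights are stored as one stacked tensor.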

tf2onnx command:

python -m tf2onnx.convert --saved-model /dev/shm/tmp/tf_model --output /dev/shm/pi0_galaxea_lora.onnx --opset 21 --large_model --verbose

Script used to convert the JAX model to a TF SavedModel:

import jax
import numpy as np
import tensorflow as tf
from flax import nnx
from jax.experimental import jax2tf


def jax2tf_saved_model(inference_fn, params, save_path, batch_size, action_dim, max_token_len):
    """Convert a JAX inference function and its params to a TensorFlow SavedModel.

    The SavedModel is converted to ONNX separately with tf2onnx.
    """
    # Unwrap nnx.State / VariableState containers into plain nested dicts of arrays.
    def extract_value(p):
        if isinstance(p, (dict, nnx.State)):
            return {k: extract_value(v) for k, v in p.items()}
        elif isinstance(p, nnx.variablelib.VariableState):
            return p.value
        return p

    params_plain = extract_value(params)

    # print("params_plain:", params_plain)
    print("get value finished")

    # Alternative to the tf.nest.map_structure call below (currently unused).
    def to_tf_variable(x):
        if isinstance(x, (float, int, bool, list, tuple)):
            return tf.Variable(x)
        elif isinstance(x, dict):
            return {k: to_tf_variable(v) for k, v in x.items()}
        elif isinstance(x, jax.Array):
            # Plain np.asarray(x): copy=False is rejected by NumPy < 2 and raises
            # in NumPy 2 whenever a copy cannot be avoided.
            return tf.Variable(tf.convert_to_tensor(np.asarray(x)))
        return x
    # params_vars = to_tf_variable(params_plain)
    params_vars = tf.nest.map_structure(tf.Variable, params_plain)
    del params_plain
    print(params_vars)
    print("to tf variable finished")

    input_specs = [
        tf.TensorSpec([2], tf.uint32),  # rng
        tf.TensorSpec([batch_size, 480, 640, 3], tf.float32),  # base image
        tf.TensorSpec([batch_size, 480, 640, 3], tf.float32),  # left image
        tf.TensorSpec([batch_size, 480, 640, 3], tf.float32),  # right image
        tf.TensorSpec([batch_size, action_dim], tf.float32),  # state
        tf.TensorSpec([batch_size, max_token_len], tf.int32),  # tokens
        tf.TensorSpec([batch_size, max_token_len], tf.bool),  # token mask
    ]
    my_model = tf.Module()
    # Attach the variables to the module so they are tracked and saved with it.
    my_model._variables = tf.nest.flatten(params_vars)
    # Convert once; the tf.function below closes over params_vars.
    tf_fn = jax2tf.convert(inference_fn, native_serialization=False, with_gradient=False)
    my_model.f = tf.function(lambda *inputs: tf_fn(params_vars, *inputs),
                             jit_compile=True, autograph=False, input_signature=input_specs)
    tf.saved_model.save(my_model, f'{save_path}/tf_model',
                        options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
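To catch an oversized weight before running tf2onnx, one option is to scan the unwrapped parameters for leaves that would not fit in a single protobuf message. This is a minimal sketch assuming `params_plain` (as produced by `extract_value`) is a nested mapping of arrays that expose `.shape` and `.dtype`; `find_oversized_leaves` and `PROTOBUF_MAX_BYTES` are hypothetical names, not part of tf2onnx or jax2tf, and the 2 GiB cap is the assumed protobuf per-message limit.

```python
import math
from collections.abc import Mapping

import numpy as np

# Assumed protobuf per-message cap: 2**31 - 1 bytes.
PROTOBUF_MAX_BYTES = 2**31 - 1


def _shape_and_dtype(leaf):
    """Return (shape, dtype) without copying arrays that already expose them."""
    if hasattr(leaf, "shape") and hasattr(leaf, "dtype"):
        return tuple(leaf.shape), np.dtype(leaf.dtype)
    arr = np.asarray(leaf)  # Python scalars, lists, etc.
    return arr.shape, arr.dtype


def find_oversized_leaves(tree, limit=PROTOBUF_MAX_BYTES, path=()):
    """Yield (path, shape, dtype name, nbytes) for every leaf larger than `limit` bytes."""
    if isinstance(tree, Mapping):
        for key, value in tree.items():
            yield from find_oversized_leaves(value, limit, path + (str(key),))
        return
    shape, dtype = _shape_and_dtype(tree)
    nbytes = math.prod(shape) * dtype.itemsize
    if nbytes > limit:
        yield "/".join(path), shape, dtype.name, nbytes


if __name__ == "__main__":
    # Toy tree with a tiny limit so the example runs instantly.
    toy = {"layer": {"w": np.zeros((10, 10), np.float32), "b": np.zeros(3, np.float16)}, "scale": 1.0}
    for hit in find_oversized_leaves(toy, limit=100):
        print(hit)  # ('layer/w', (10, 10), 'float32', 400)
```

In the export script, calling `find_oversized_leaves(params_plain)` right after `extract_value` would, if the stacked float16 weight is a single leaf there, be expected to report it before any TensorFlow variables are created.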

System information

  • OS Platform and Distribution: NVIDIA JetPack 6.1
  • TensorFlow version: 2.20.0
  • Python version: 3.11
  • ONNX version: 1.17.0
  • ONNXRuntime version: none
  • tf2onnx version: 1.16.1

Metadata

Labels: bug (An unexpected problem or unintended behavior)
