Status: Open
Labels: bug (An unexpected problem or unintended behavior)
Description
Describe the bug
I have the inference function and params of a JAX model and converted them to a TF SavedModel. I hit an error when converting the SavedModel to an ONNX model. How can I solve it?
tf2onnx error log:
```
<frozen runpy>:128: RuntimeWarning: 'tf2onnx.convert' found in sys.modules after import of package 'tf2onnx', but prior to execution of 'tf2onnx.convert'; this may result in unpredictable behaviour
2025-08-27 18:19:57,445 - WARNING - tf2onnx.tf_loader: '--tag' not specified for saved_model. Using --tag serve
2025-08-27 18:20:01,172 - INFO - tf2onnx.tf_loader: Signatures found in model: [serving_default].
2025-08-27 18:20:01,172 - WARNING - tf2onnx.tf_loader: '--signature_def' not specified, using first signature: serving_default
2025-08-27 18:20:01,172 - INFO - tf2onnx.tf_loader: Output names: ['output_0']
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1756290001.225592 5292 devices.cc:76] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0 (Note: TensorFlow was not compiled with CUDA or ROCm support)
I0000 00:00:1756290001.225854 5292 single_machine.cc:376] Starting new session
2025-08-27 18:20:11,084 - INFO - tf2onnx: inputs: ['inputs_0:0', 'inputs_1:0', 'inputs_2:0', 'inputs_3:0', 'inputs_4:0', 'inputs_5:0', 'inputs_6:0']
2025-08-27 18:20:11,084 - INFO - tf2onnx: outputs: ['Identity:0']
2025-08-27 18:20:14,362 - INFO - tf2onnx.tfonnx: Using tensorflow=2.20.0, onnx=1.17.0, tf2onnx=1.16.1/15c810
2025-08-27 18:20:14,363 - INFO - tf2onnx.tfonnx: Using opset <onnx, 21>
2025-08-27 18:20:20,789 - ERROR - tf2onnx.tf_utils: pass1 convert failed for name: "unknown_43"
op: "Const"
attr {
  key: "value"
  value {
    tensor {
      dtype: DT_HALF
      tensor_shape {
        dim {
          size: 18
        }
        dim {
          size: 2
        }
        dim {
          size: 2048
        }
        dim {
          size: 16384
        }
      }
    }
  }
}
attr {
  key: "dtype"
  value {
    type: DT_HALF
  }
}
, ex=Error parsing message with type 'onnx.AttributeProto'
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/convert.py", line 714, in <module>
    main()
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/convert.py", line 273, in main
    model_proto, _ = _convert_common(
                     ^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/convert.py", line 168, in _convert_common
    g = process_tf_graph(tf_graph, const_node_values=const_node_values,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tfonnx.py", line 459, in process_tf_graph
    main_g, subgraphs = graphs_from_tf(tf_graph, input_names, output_names, shape_override, const_node_values,
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tfonnx.py", line 474, in graphs_from_tf
    ordered_func = resolve_functions(tf_graph)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tf_loader.py", line 784, in resolve_functions
    _, _, _, _, _, functions = tflist_to_onnx(tf_graph, {})
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tf_utils.py", line 463, in tflist_to_onnx
    onnx_node = utils.make_onnx_node_with_attr(node_type, input_names, output_names, name=node.name, **attr)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/utils.py", line 207, in make_onnx_node_with_attr
    onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **valid_attrs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/onnx/helper.py", line 175, in make_node
    node.attribute.extend(
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/onnx/helper.py", line 175, in <genexpr>
    node.attribute.extend(
    ^
google.protobuf.message.DecodeError: Error parsing message with type 'onnx.AttributeProto'
```
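A possible cause, not confirmed in this thread: the constant that fails in pass 1 (`unknown_43`, `DT_HALF`, shape `[18, 2, 2048, 16384]`) may simply be too large for a single protobuf message. Protocol Buffers cannot serialize a message of 2 GiB or more, and the traceback shows the failure while `onnx.helper.make_node` builds an `onnx.AttributeProto` from this constant inside `resolve_functions`, even though `--large_model` was passed. The arithmetic below uses only the shape and dtype from the log; the 2 GiB figure is protobuf's general limit, not a value read from tf2onnx.

```python
import math

import numpy as np

# Shape and dtype of the failing "unknown_43" Const node, copied from the log.
shape = (18, 2, 2048, 16384)
dtype = np.dtype(np.float16)  # TensorFlow DT_HALF is IEEE float16 (2 bytes)

num_elements = math.prod(shape)
num_bytes = num_elements * dtype.itemsize

# Protocol Buffers cannot serialize a single message of 2 GiB or more.
PROTOBUF_LIMIT_BYTES = 2**31 - 1

print(f"{num_elements:,} elements, {num_bytes:,} bytes ({num_bytes / 2**30:.2f} GiB)")
# → 1,207,959,552 elements, 2,415,919,104 bytes (2.25 GiB)
print("exceeds protobuf limit:", num_bytes > PROTOBUF_LIMIT_BYTES)
# → exceeds protobuf limit: True
```

At about 2.25 GiB, this one tensor is over the limit on its own, which would be consistent with the `DecodeError` at the end of the traceback.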
tf2onnx command:
```shell
python -m tf2onnx.convert --saved-model /dev/shm/tmp/tf_model --output /dev/shm/pi0_galaxea_lora.onnx --opset 21 --large_model --verbose
```
Script used to convert the JAX model to a TF SavedModel:
```python
import jax
import numpy as np
import tensorflow as tf
from flax import nnx
from jax.experimental import jax2tf


def jax2tf_saved_model(inference_fn, params, save_path, batch_size, action_dim, max_token_len):
    """Convert a JAX inference function and its params to a TensorFlow SavedModel."""
    # This function is not used in the final export, but can be useful for debugging.
    def extract_value(p):
        # Unwrap nnx.State / VariableState containers into plain nested dicts of arrays.
        if isinstance(p, (dict, nnx.State)):
            return {k: extract_value(v) for k, v in p.items()}
        elif isinstance(p, nnx.variablelib.VariableState):
            return p.value
        return p

    params_plain = extract_value(params)
    # print("params_plain:", params_plain)
    print("get value finished")

    def to_tf_variable(x):
        if isinstance(x, (float, int, bool, list, tuple)):
            return tf.Variable(x)
        elif isinstance(x, dict):
            return {k: to_tf_variable(v) for k, v in x.items()}
        elif isinstance(x, jax.Array):
            return tf.Variable(tf.convert_to_tensor(np.asarray(x, copy=False)))
        return x

    # params_vars = to_tf_variable(params_plain)
    params_vars = tf.nest.map_structure(tf.Variable, params_plain)
    del params_plain
    print(params_vars)
    print("to tf variable finished")

    input_specs = [
        tf.TensorSpec([2], tf.uint32),  # rng
        tf.TensorSpec([batch_size, 480, 640, 3], tf.float32),  # base image
        tf.TensorSpec([batch_size, 480, 640, 3], tf.float32),  # left image
        tf.TensorSpec([batch_size, 480, 640, 3], tf.float32),  # right image
        tf.TensorSpec([batch_size, action_dim], tf.float32),  # state
        tf.TensorSpec([batch_size, max_token_len], tf.int32),  # tokens
        tf.TensorSpec([batch_size, max_token_len], tf.bool),  # token mask
    ]

    my_model = tf.Module()
    # Track the variables on the module so they are saved with the SavedModel.
    my_model._variables = tf.nest.flatten(params_vars)
    prediction_tf = lambda *inputs: jax2tf.convert(inference_fn, native_serialization=False, with_gradient=False)(params_vars, *inputs)
    my_model.f = tf.function(prediction_tf, jit_compile=True, autograph=False, input_signature=input_specs)
    tf.saved_model.save(my_model, f'{save_path}/tf_model', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
```
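To check before exporting whether any parameter is larger than protobuf's 2 GiB message limit, a small helper can walk the nested dict that `extract_value` returns and report every leaf over a byte threshold. `find_large_leaves` and `PROTOBUF_LIMIT_BYTES` are hypothetical names introduced here, not part of openpi, JAX, or tf2onnx; the sketch only assumes that array leaves expose `.nbytes`, `.shape` and `.dtype`, which both NumPy arrays and `jax.Array` do, so large device arrays are not copied to the host just to measure them.

```python
import numpy as np

# Protocol Buffers cannot serialize a single message of 2 GiB or more.
PROTOBUF_LIMIT_BYTES = 2**31 - 1


def find_large_leaves(tree, threshold_bytes=PROTOBUF_LIMIT_BYTES, prefix=""):
    """Hypothetical helper: list (path, shape, dtype, nbytes) for every leaf of a
    nested dict whose size in bytes exceeds `threshold_bytes`."""
    found = []
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Recurse into sub-dicts, extending the slash-separated path.
            found.extend(find_large_leaves(value, threshold_bytes, path))
            continue
        if not hasattr(value, "nbytes"):
            # Python scalars and lists have no .nbytes; wrap them as NumPy arrays.
            value = np.asarray(value)
        if value.nbytes > threshold_bytes:
            found.append((path, tuple(value.shape), np.dtype(value.dtype), value.nbytes))
    return found


if __name__ == "__main__":
    # Small synthetic tree; a 16-byte threshold stands in for the real 2 GiB limit.
    demo = {
        "layer": {"big": np.zeros((4, 4), dtype=np.float16), "small": np.zeros(2, dtype=np.float16)},
        "bias": 1.0,
    }
    for path, shape, dtype, nbytes in find_large_leaves(demo, threshold_bytes=16):
        print(path, shape, dtype, nbytes)  # → layer/big (4, 4) float16 32
```

For example, calling `find_large_leaves(params_plain)` inside `jax2tf_saved_model`, before `del params_plain`, and printing the results would list any weight that exceeds the limit with the default threshold.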
System information
- OS Platform and Distribution: NVIDIA JetPack 6.1
- TensorFlow version: 2.20
- Python version: 3.11
- ONNX version: 1.17.0
- ONNXRuntime version: none