Skip to content

Commit 2ad402e

Browse files
handle DecodeError in tf_loader (#1466)
Signed-off-by: Calvin McCarter <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent 7922978 commit 2ad402e

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

tf2onnx/tf_loader.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313

1414
import tensorflow as tf
1515
import numpy as np
16+
from google.protobuf.message import DecodeError
17+
from tensorflow.core.protobuf import saved_model_pb2
1618
from tensorflow.python.ops import lookup_ops
19+
from tensorflow.python.util import compat
1720

1821
from tf2onnx import utils
1922
from tf2onnx.tf_utils import get_tf_version, tflist_to_onnx, get_hash_table_info, replace_placeholders_with_tables
@@ -194,6 +197,11 @@ def from_graphdef(model_path, input_names, output_names):
194197
"Unable to load file '{}'.".format(model_path)) from e
195198
try:
196199
graph_def.ParseFromString(content)
200+
except DecodeError:
201+
content_as_bytes = compat.as_bytes(content)
202+
saved_model = saved_model_pb2.SavedModel()
203+
saved_model.ParseFromString(content_as_bytes)
204+
graph_def = saved_model.meta_graphs[0].graph_def
197205
except Exception as e:
198206
raise RuntimeError(
199207
"Unable to parse file '{}'.".format(model_path)) from e

0 commit comments

Comments
 (0)