Skip to content

Commit ab816b3

Browse files
Hack to support > 2GB tensors (#1558)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 9223376 commit ab816b3

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

tf2onnx/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def get_value_attr(self, external_tensor_storage=None):
105105
if external_tensor_storage is None or a.type != AttributeProto.TENSOR:
106106
return a
107107
if np.product(a.t.dims) > external_tensor_storage.external_tensor_size_threshold:
108-
a = copy.copy(a)
108+
a = copy.deepcopy(a)
109109
tensor_name = self.name.strip() + "_" + str(external_tensor_storage.name_counter)
110110
for c in '~"#%&*:<>?/\\{|}':
111111
tensor_name = tensor_name.replace(c, '_')

tf2onnx/tf_loader.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import tensorflow as tf
1010
import numpy as np
1111
from google.protobuf.message import DecodeError
12+
from tensorflow.core.framework import tensor_pb2
1213
from tensorflow.core.protobuf import saved_model_pb2
1314
from tensorflow.python.ops import lookup_ops
1415
from tensorflow.python.util import compat
@@ -127,8 +128,28 @@ def convert_variables_to_constants_large_model(func):
127128
_FunctionConverterData, _replace_variables_by_constants # pylint: disable=protected-access
128129
except ImportError:
129130
_not_implemented_tf_placeholder("_replace_variables_by_constants")()
130-
converter_data = _FunctionConverterData(func=func, lower_control_flow=False, aggressive_inlining=True)
131-
frozen_graph_def, _ = _replace_variables_by_constants(converter_data=converter_data)
131+
132+
from tensorflow.python.framework import tensor_util, tensor_shape
133+
make_tensor_proto_original = tensor_util.make_tensor_proto
134+
# Hack to avoid 2GB check
135+
def make_tensor_proto_wrapped(values, dtype=None, shape=None, verify_shape=False, allow_broadcast=False):
136+
try:
137+
return make_tensor_proto_original(values, dtype, shape, verify_shape, allow_broadcast)
138+
except ValueError:
139+
if dtype is None:
140+
dtype = tf.dtypes.as_dtype(values.dtype).as_datatype_enum
141+
tensor_proto = tensor_pb2.TensorProto(
142+
dtype=dtype,
143+
tensor_shape=tensor_shape.as_shape(values.shape).as_proto())
144+
tensor_proto.tensor_content = values.tobytes()
145+
return tensor_proto
146+
tensor_util.make_tensor_proto = make_tensor_proto_wrapped
147+
148+
try:
149+
converter_data = _FunctionConverterData(func=func, lower_control_flow=False, aggressive_inlining=True)
150+
frozen_graph_def, _ = _replace_variables_by_constants(converter_data=converter_data)
151+
finally:
152+
tensor_util.make_tensor_proto = make_tensor_proto_original
132153
return frozen_graph_def
133154

134155

0 commit comments

Comments
 (0)