Skip to content

Commit eecd3a4

Browse files
Added a workaround for the TF bug that tensor.numpy() doesn't always have a dtype (#1250)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent cf8c953 commit eecd3a4

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

tf2onnx/tf_loader.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from distutils.version import LooseVersion
1313

1414
import tensorflow as tf
15+
import numpy as np
1516
from tensorflow.python.ops import lookup_ops
1617

1718
from tf2onnx import utils
@@ -311,6 +312,12 @@ def _remove_non_variable_resources_from_captures(concrete_func):
311312
for i in reversed(range(len(concrete_func._captured_inputs))):
312313
if concrete_func._captured_inputs[i] is val_tensor:
313314
concrete_func._captured_inputs.pop(i)
315+
elif val_tensor.dtype != tf.resource:
316+
npval = val_tensor.numpy()
317+
if not hasattr(npval, 'dtype'):
318+
# Hack around a TF bug until PR is merged: https://github.com/tensorflow/tensorflow/pull/45610
319+
arr = np.array(npval)
320+
val_tensor.numpy = lambda arr=arr: arr
314321
else:
315322
logger.warning(
316323
"Could not search for non-variable resources. Concrete function internal representation may have changed.")

0 commit comments

Comments
 (0)