Skip to content

Commit 5892425

Browse files
Fix tfjs quantized weight parsing for fp16 quant (#1625)
Signed-off-by: Tom Wildenhain <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent 101fbc4 commit 5892425

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

tf2onnx/tfjs_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def resolve_output(output, op_info, func_name=None):
108108

109109
def get_output_names_and_dtypes(op_type, tf_attr):
110110
"""Parses the tf documentation to determine the names and dtypes of the outputs of the op"""
111+
# TODO: ['Prelu', 'Conv1D', 'DepthwiseConv2d', 'FusedDepthwiseConv2dNative', 'Ones', 'Zeros']
111112
try:
112113
tf_op_def = tf_api_def_map.get_op_def(op_type)
113114
except ValueError:
@@ -264,7 +265,10 @@ def read_tfjs_weight(weight, weights_data, offset):
264265
q_dtype = np.dtype(q_info['dtype'])
265266
np_arr = np.frombuffer(weights_data, dtype=q_dtype, count=count, offset=offset)
266267
num_bytes = np_arr.nbytes
267-
np_arr = np_arr.astype(np_dtype) * q_info['scale'] + q_info['min']
268+
if 'scale' in q_info:
269+
np_arr = np_arr.astype(np_dtype) * q_info['scale'] + q_info['min']
270+
else:
271+
np_arr = np_arr.astype(np_dtype)
268272
else:
269273
np_arr = np.frombuffer(weights_data, dtype=np_dtype, count=count, offset=offset)
270274
num_bytes = np_arr.nbytes

0 commit comments

Comments
 (0)