Skip to content

Commit 88c5d0b

Browse files
committed
set bias to 0
1 parent c5872a6 commit 88c5d0b

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

tf2onnx/rewriter/lstm_rewriter.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,18 @@ def _get_weight_and_bias_for_lstm_cell(self, context):
9898
# check https://www.tensorflow.org/versions/r1.8/api_docs/cc/class/tensorflow/ops/bias-add
9999
# for bias_add data format
100100
bias_add = match.get_op("bias_add")
101-
if bias_add.data_format != "NHWC":
101+
if bias_add is not None and bias_add.data_format != "NHWC":
102102
logger.debug("BiasAdd data_format is not NHWC, SKIP")
103103
return None
104104

105105
b_e = match.get_op("cell_bias")
106-
b = get_weights_from_const_node(self.g, b_e)
107-
if b is None or b.shape[0] != w.shape[1]:
108-
logger.warning("cell_kernel and cell_bias's dimensions does not match, skip")
109-
return None
106+
if b_e is None:
107+
b = np.array([0 for i in range(len(w[0]))]).astype(w.dtype)
108+
else:
109+
b = get_weights_from_const_node(self.g, b_e)
110+
if b is None or b.shape[0] != w.shape[1]:
111+
logger.warning("cell_kernel and cell_bias's dimensions does not match, skip")
112+
return None
110113

111114
ft_bias_node = match.get_op("ft_bias")
112115
ft_bias = get_weights_from_const_node(self.g, ft_bias_node)

tf2onnx/rewriter/rnn_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ class REWRITER_RESULT(Enum):
5151
OpTypePattern('Mul', name='ht', inputs=[
5252
OpTypePattern("Sigmoid", name="ot", inputs=[xc_pattern]),
5353
OpTypePattern('Tanh', inputs=[
54-
OpTypePattern("Add", name="ct", inputs=[
54+
OpTypePattern("Add|AddV2", name="ct", inputs=[
5555
OpTypePattern("Mul", name="ct_identity_consumer", inputs=[
5656
OpTypePattern("Sigmoid", name="ft", inputs=[
57-
OpTypePattern("Add", inputs=[
57+
OpTypePattern("Add|AddV2", inputs=[
5858
xc_pattern,
5959
OpTypePattern("*", name="ft_bias"),
6060
]),

0 commit comments

Comments
 (0)