Skip to content

Commit 8a3c573

Browse files
committed
fix lstm bug with peephole
1 parent d441206 commit 8a3c573

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

tf2onnx/rewriter/lstm_rewriter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,14 @@ def _get_weight_and_bias_for_lstm_cell(self, context):
125125
def parse_attributes(self, context):
126126
if self.lstm_cell_type == RNNUnitType.LSTMBlockCell:
127127
lstm_block_cell = context.cell_match.get_op("lstm_block_cell")
128-
clip = float(lstm_block_cell.get_attr("cell_clip").f)
128+
clip = lstm_block_cell.get_attr_value("cell_clip")
129129
# current LSTM op cannot handle clip
130130
if clip > 0:
131131
return False
132+
133+
use_peephole = lstm_block_cell.get_attr_value("use_peephole")
134+
if use_peephole:
135+
return False
132136
return True
133137

134138
def _ct_variable_finder(self, context):

0 commit comments

Comments
 (0)