Skip to content

Commit dbf6d40

Browse files
typo
1 parent d5accff commit dbf6d40

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tf2onnx/rewriter/gru_rewriter.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,18 +61,18 @@ def get_weight_and_bias(self, context):
6161

6262
def _state_variable_finder(self, context):
6363
if self.gru_cell_type == RNNUnitType.GRUCell:
64-
lstm_cell = context.cell_match
64+
gru_cell = context.cell_match
6565
return self._find_state_variable_with_select(
6666
context,
67-
lstm_cell.get_op("cell_output").output[0],
68-
[lstm_cell.get_op("cell_inputs")]
67+
gru_cell.get_op("cell_output").output[0],
68+
[gru_cell.get_op("cell_inputs")]
6969
)
7070
if self.gru_cell_type == RNNUnitType.GRUBlockCell:
71-
lstm_block_cell = context.cell_match.get_op("gru_block_cell")
71+
gru_block_cell = context.cell_match.get_op("gru_block_cell")
7272
return self._find_state_variable_with_select(
7373
context,
74-
lstm_block_cell.output[3],
75-
[lstm_block_cell]
74+
gru_block_cell.output[3],
75+
[gru_block_cell]
7676
)
7777
return None
7878

@@ -97,7 +97,7 @@ def is_valid(self, context):
9797
# output should be no more than 1
9898
outputs = context.loop_properties.scan_outputs_exits
9999
if len(outputs) > 1:
100-
log.debug("found %d outputs for lstm: %s", len(outputs), outputs)
100+
log.debug("found %d outputs for gru: %s", len(outputs), outputs)
101101
return False
102102
return True
103103

0 commit comments

Comments
 (0)