@@ -61,18 +61,18 @@ def get_weight_and_bias(self, context):
61
61
62
62
def _state_variable_finder (self , context ):
63
63
if self .gru_cell_type == RNNUnitType .GRUCell :
64
- lstm_cell = context .cell_match
64
+ gru_cell = context .cell_match
65
65
return self ._find_state_variable_with_select (
66
66
context ,
67
- lstm_cell .get_op ("cell_output" ).output [0 ],
68
- [lstm_cell .get_op ("cell_inputs" )]
67
+ gru_cell .get_op ("cell_output" ).output [0 ],
68
+ [gru_cell .get_op ("cell_inputs" )]
69
69
)
70
70
if self .gru_cell_type == RNNUnitType .GRUBlockCell :
71
- lstm_block_cell = context .cell_match .get_op ("gru_block_cell" )
71
+ gru_block_cell = context .cell_match .get_op ("gru_block_cell" )
72
72
return self ._find_state_variable_with_select (
73
73
context ,
74
- lstm_block_cell .output [3 ],
75
- [lstm_block_cell ]
74
+ gru_block_cell .output [3 ],
75
+ [gru_block_cell ]
76
76
)
77
77
return None
78
78
@@ -97,7 +97,7 @@ def is_valid(self, context):
97
97
# output should be no more than 1
98
98
outputs = context .loop_properties .scan_outputs_exits
99
99
if len (outputs ) > 1 :
100
- log .debug ("found %d outputs for lstm : %s" , len (outputs ), outputs )
100
+ log .debug ("found %d outputs for gru : %s" , len (outputs ), outputs )
101
101
return False
102
102
return True
103
103
0 commit comments