@@ -176,6 +176,50 @@ def init_lstm_weight(rng: PRNGKeyArray, input_size: int, hidden_size: int,
176176 rng , shape = (param_count ,), dtype = jnp .float32 , minval = - k , maxval = k )
177177
178178
179+ def swap_lstm_gates (weights , input_size , hidden_size , num_layers , bidirectional ):
180+ """
181+ Swaps the weights for the input and output gates in an LSTM model's parameters.
182+
183+ This function is specifically designed for compatibility with MIOpen, where the gate ordering
184+ differs from CuDNN. In CuDNN, the gates are ordered as:
185+ - 0: Forget gate (f)
186+ - 1: Input gate (i)
187+ - 2: New memory gate (g)
188+ - 3: Output gate (o)
189+
190+ However, in MIOpen, the ordering of the new memory (g) and output (o) gates is swapped:
191+ - 0: Forget gate (f)
192+ - 1: Input gate (i)
193+ - 2: Output gate (o)
194+ - 3: New memory gate (g)
195+
196+ This function rearranges the weights and biases for the gates to ensure that the model
197+ operates correctly with MIOpen by swapping the third (new memory) and fourth (output) gates.
198+ """
199+ weights = jnp .asarray (weights )
200+ flat_shapes = _get_params_shapes_in_lstm (input_size , hidden_size , num_layers , bidirectional )
201+ num_directions = 2 if bidirectional else 1
202+
203+ w_offsets = 0
204+ for l in range (num_layers ):
205+ for direction in range (num_directions ):
206+ # Iterate through all weight and bias gate names to swap gates in both weights and biases.
207+ for gate_name in ["W_ih" , "W_hh" , "b_ih" , "b_hh" ]:
208+ shape = flat_shapes .pop (0 )
209+ num_elems = math .prod (shape )
210+ matrix = weights [w_offsets :w_offsets + num_elems ].reshape (shape )
211+
212+ # Swap between the input and output gates (third and fourth gates).
213+ gates = jnp .split (matrix , 4 , axis = 0 )
214+ swapped_matrix = jnp .concatenate ([gates [0 ], gates [1 ], gates [3 ], gates [2 ]], axis = 0 )
215+
216+ # Update the weights with swapped matrix.
217+ weights = weights .at [w_offsets :w_offsets + num_elems ].set (swapped_matrix .flatten ())
218+ w_offsets += num_elems
219+
220+ return weights
221+
222+
179223def unpack_lstm_weights (
180224 weights : Array , input_size : int , hidden_size : int , num_layers : int ,
181225 bidirectional : bool
@@ -437,7 +481,8 @@ def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw):
437481rnn_fwd_p .def_impl (partial (xla .apply_primitive , rnn_fwd_p ))
438482rnn_fwd_p .def_abstract_eval (rnn_abstract_eval )
439483if gpu_rnn :
440- mlir .register_lowering (rnn_fwd_p , gpu_rnn .cudnn_rnn_lowering , platform = 'cuda' )
484+ mlir .register_lowering (rnn_fwd_p , gpu_rnn .cudnn_rnn_fwd_lowering , platform = 'cuda' )
485+ mlir .register_lowering (rnn_fwd_p , gpu_rnn .miopen_rnn_fwd_lowering , platform = 'rocm' )
441486
442487
443488def lstm_bwd (input_size : int , hidden_size : int , num_layers : int , dropout : float ,
@@ -481,5 +526,7 @@ def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
481526if gpu_rnn :
482527 mlir .register_lowering (
483528 rnn_bwd_p , gpu_rnn .cudnn_rnn_bwd_lowering , platform = 'cuda' )
529+ mlir .register_lowering (
530+ rnn_bwd_p , gpu_rnn .miopen_rnn_bwd_lowering , platform = 'rocm' )
484531
485532lstm .defvjp (lstm_fwd , lstm_bwd )
0 commit comments