@@ -175,6 +175,31 @@ def init_lstm_weight(rng: PRNGKeyArray, input_size: int, hidden_size: int,
175175 return jax .random .uniform (
176176 rng , shape = (param_count ,), dtype = jnp .float32 , minval = - k , maxval = k )
177177
178+ def swap_lstm_gates (weights , input_size , hidden_size , num_layers , bidirectional ):
179+ """Swaps the weights for the input and output gates for an LSTM model."""
180+ weights = jnp .asarray (weights ) # Ensure weights are JAX arrays
181+ flat_shapes = _get_params_shapes_in_lstm (input_size , hidden_size , num_layers , bidirectional )
182+ num_directions = 2 if bidirectional else 1
183+
184+ w_offsets = 0
185+ for l in range (num_layers ):
186+ for direction in range (num_directions ):
187+ # Iterate through all weight and bias gate names to swap gates in both weights and biases
188+ for gate_name in ["W_ih" , "W_hh" , "b_ih" , "b_hh" ]:
189+ shape = flat_shapes .pop (0 ) # Get the current shape and remove it from the list
190+ num_elems = math .prod (shape )
191+ matrix = weights [w_offsets :w_offsets + num_elems ].reshape (shape )
192+
193+ # Swap between the input and output gates (third and fourth gates)
194+ gates = jnp .split (matrix , 4 , axis = 0 )
195+ swapped_matrix = jnp .concatenate ([gates [0 ], gates [1 ], gates [3 ], gates [2 ]], axis = 0 )
196+
197+ # Update the weights with swapped matrix
198+ weights = weights .at [w_offsets :w_offsets + num_elems ].set (swapped_matrix .flatten ())
199+ w_offsets += num_elems
200+
201+ return weights
202+
178203
179204def unpack_lstm_weights (
180205 weights : Array , input_size : int , hidden_size : int , num_layers : int ,
@@ -437,7 +462,8 @@ def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw):
437462rnn_fwd_p .def_impl (partial (xla .apply_primitive , rnn_fwd_p ))
438463rnn_fwd_p .def_abstract_eval (rnn_abstract_eval )
439464if gpu_rnn :
440- mlir .register_lowering (rnn_fwd_p , gpu_rnn .cudnn_rnn_lowering , platform = 'cuda' )
465+ mlir .register_lowering (rnn_fwd_p , gpu_rnn .cudnn_rnn_fwd_lowering , platform = 'cuda' )
466+ mlir .register_lowering (rnn_fwd_p , gpu_rnn .miopen_rnn_fwd_lowering , platform = 'rocm' )
441467
442468
443469def lstm_bwd (input_size : int , hidden_size : int , num_layers : int , dropout : float ,
@@ -481,5 +507,7 @@ def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
481507if gpu_rnn :
482508 mlir .register_lowering (
483509 rnn_bwd_p , gpu_rnn .cudnn_rnn_bwd_lowering , platform = 'cuda' )
510+ mlir .register_lowering (
511+ rnn_bwd_p , gpu_rnn .miopen_rnn_bwd_lowering , platform = 'rocm' )
484512
485513lstm .defvjp (lstm_fwd , lstm_bwd )
0 commit comments