
Commit 609ec62

[ROCm] Implement RNN support

1 parent 5cda053

File tree

7 files changed: +446, -125 lines

jax/experimental/rnn.py

Lines changed: 48 additions & 1 deletion
@@ -176,6 +176,50 @@ def init_lstm_weight(rng: PRNGKeyArray, input_size: int, hidden_size: int,
       rng, shape=(param_count,), dtype=jnp.float32, minval=-k, maxval=k)
 
 
+def swap_lstm_gates(weights, input_size, hidden_size, num_layers, bidirectional):
+  """
+  Swaps the weights for the new memory (g) and output (o) gates in an LSTM model's parameters.
+
+  This function is specifically designed for compatibility with MIOpen, where the gate ordering
+  differs from CuDNN. In CuDNN, the gates are ordered as:
+  - 0: Forget gate (f)
+  - 1: Input gate (i)
+  - 2: New memory gate (g)
+  - 3: Output gate (o)
+
+  In MIOpen, however, the ordering of the new memory (g) and output (o) gates is swapped:
+  - 0: Forget gate (f)
+  - 1: Input gate (i)
+  - 2: Output gate (o)
+  - 3: New memory gate (g)
+
+  This function rearranges the weights and biases for the gates so that the model
+  operates correctly with MIOpen, by swapping the third (new memory) and fourth (output) gates.
+  """
+  weights = jnp.asarray(weights)
+  flat_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers, bidirectional)
+  num_directions = 2 if bidirectional else 1
+
+  w_offsets = 0
+  for l in range(num_layers):
+    for direction in range(num_directions):
+      # Iterate over all weight and bias names so the gates are swapped in both weights and biases.
+      for gate_name in ["W_ih", "W_hh", "b_ih", "b_hh"]:
+        shape = flat_shapes.pop(0)
+        num_elems = math.prod(shape)
+        matrix = weights[w_offsets:w_offsets + num_elems].reshape(shape)
+
+        # Swap the new memory and output gates (third and fourth gates).
+        gates = jnp.split(matrix, 4, axis=0)
+        swapped_matrix = jnp.concatenate([gates[0], gates[1], gates[3], gates[2]], axis=0)
+
+        # Write the swapped block back into the flat weight buffer.
+        weights = weights.at[w_offsets:w_offsets + num_elems].set(swapped_matrix.flatten())
+        w_offsets += num_elems
+
+  return weights
+
+
 def unpack_lstm_weights(
     weights: Array, input_size: int, hidden_size: int, num_layers: int,
     bidirectional: bool
@@ -437,7 +481,8 @@ def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw):
 rnn_fwd_p.def_impl(partial(xla.apply_primitive, rnn_fwd_p))
 rnn_fwd_p.def_abstract_eval(rnn_abstract_eval)
 if gpu_rnn:
-  mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda')
+  mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_fwd_lowering, platform='cuda')
+  mlir.register_lowering(rnn_fwd_p, gpu_rnn.miopen_rnn_fwd_lowering, platform='rocm')
 
 
 def lstm_bwd(input_size: int, hidden_size: int, num_layers: int, dropout: float,
@@ -481,5 +526,7 @@ def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
 if gpu_rnn:
   mlir.register_lowering(
       rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda')
+  mlir.register_lowering(
+      rnn_bwd_p, gpu_rnn.miopen_rnn_bwd_lowering, platform='rocm')
 
 lstm.defvjp(lstm_fwd, lstm_bwd)
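
The forward and backward registrations above follow the same pattern: one primitive gets one lowering rule per platform, and JAX picks the rule matching the backend it compiles for. Below is a hedged, self-contained sketch of that mechanism with a toy primitive; the toy_double primitive and the mlir.lower_fun fallback rules are hypothetical stand-ins for the real cuDNN/MIOpen rules, and the core import location varies across JAX versions (newer releases expose it as jax.extend.core).

# Hypothetical toy primitive illustrating per-platform lowering registration;
# this is not how the RNN primitives in this commit are implemented.
from jax import core  # on newer JAX versions: from jax.extend import core
from jax.interpreters import mlir

toy_p = core.Primitive("toy_double")
toy_p.def_impl(lambda x: 2 * x)        # eager (op-by-op) implementation
toy_p.def_abstract_eval(lambda x: x)   # output aval matches input aval

# Each platform gets its own rule; here both are simple fallbacks built from a
# traceable function, whereas the RNN rules call into cuDNN / MIOpen kernels.
mlir.register_lowering(
    toy_p, mlir.lower_fun(lambda x: 2 * x, multiple_results=False),
    platform='cuda')
mlir.register_lowering(
    toy_p, mlir.lower_fun(lambda x: 2 * x, multiple_results=False),
    platform='rocm')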
