Skip to content

Commit 089e42b

Browse files
committed
[ROCm] Implement RNN support
1 parent 5cda053 commit 089e42b

File tree

7 files changed

+425
-125
lines changed

7 files changed

+425
-125
lines changed

jax/experimental/rnn.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,31 @@ def init_lstm_weight(rng: PRNGKeyArray, input_size: int, hidden_size: int,
175175
return jax.random.uniform(
176176
rng, shape=(param_count,), dtype=jnp.float32, minval=-k, maxval=k)
177177

178+
def swap_lstm_gates(weights, input_size, hidden_size, num_layers, bidirectional):
  """Returns flat LSTM weights with the third and fourth gate blocks swapped.

  The flat parameter vector is walked one parameter tensor at a time, using
  the shapes reported by ``_get_params_shapes_in_lstm``. Each tensor is split
  into four equal blocks along axis 0 and re-emitted in block order
  ``[0, 1, 3, 2]``. Because this is a plain swap, applying the function twice
  yields the original ordering.

  NOTE(review): earlier documentation described the swapped blocks as the
  "input and output" gates. The code only establishes that the blocks at
  index 2 and 3 are exchanged; which gates those are depends on the backend's
  gate layout -- confirm against the cuDNN / MIOpen documentation.

  Args:
    weights: flat (1-D) array-like of LSTM parameters; converted with
      ``jnp.asarray``.
    input_size: forwarded to ``_get_params_shapes_in_lstm``.
    hidden_size: forwarded to ``_get_params_shapes_in_lstm``.
    num_layers: number of LSTM layers.
    bidirectional: whether the LSTM has two directions per layer.

  Returns:
    A new array with the same length and dtype as ``weights``. Elements past
    the last processed parameter tensor (if any) are passed through unchanged.
  """
  weights = jnp.asarray(weights)  # Accept NumPy arrays / lists as well.
  flat_shapes = _get_params_shapes_in_lstm(
      input_size, hidden_size, num_layers, bidirectional)
  num_directions = 2 if bidirectional else 1
  # Four parameter tensors are processed per (layer, direction) pair.
  num_params = num_layers * num_directions * 4

  # Collect the reordered pieces and concatenate once at the end, instead of
  # issuing one functional `.at[].set()` update (a full copy when run eagerly)
  # per parameter tensor.
  pieces = []
  offset = 0
  for index in range(num_params):
    # Indexing (rather than slicing) keeps the IndexError if the helper
    # returns fewer shapes than expected.
    shape = flat_shapes[index]
    num_elems = math.prod(shape)
    tensor = weights[offset:offset + num_elems].reshape(shape)

    first, second, third, fourth = jnp.split(tensor, 4, axis=0)
    # Blocks are contiguous along axis 0, so flattening each block in the new
    # order is the same as flattening the re-concatenated tensor.
    pieces.extend(
        block.reshape(-1) for block in (first, second, fourth, third))
    offset += num_elems

  # Preserve anything that follows the processed parameters; this also keeps
  # `pieces` non-empty when there is nothing to swap.
  pieces.append(weights[offset:])
  return jnp.concatenate(pieces)
202+
178203

179204
def unpack_lstm_weights(
180205
weights: Array, input_size: int, hidden_size: int, num_layers: int,
@@ -437,7 +462,8 @@ def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw):
437462
rnn_fwd_p.def_impl(partial(xla.apply_primitive, rnn_fwd_p))
438463
rnn_fwd_p.def_abstract_eval(rnn_abstract_eval)
439464
if gpu_rnn:
440-
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda')
465+
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_fwd_lowering, platform='cuda')
466+
mlir.register_lowering(rnn_fwd_p, gpu_rnn.miopen_rnn_fwd_lowering, platform='rocm')
441467

442468

443469
def lstm_bwd(input_size: int, hidden_size: int, num_layers: int, dropout: float,
@@ -481,5 +507,7 @@ def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
481507
if gpu_rnn:
482508
mlir.register_lowering(
483509
rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda')
510+
mlir.register_lowering(
511+
rnn_bwd_p, gpu_rnn.miopen_rnn_bwd_lowering, platform='rocm')
484512

485513
lstm.defvjp(lstm_fwd, lstm_bwd)

0 commit comments

Comments
 (0)