Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions jax/experimental/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,31 @@ def init_lstm_weight(rng: PRNGKeyArray, input_size: int, hidden_size: int,
return jax.random.uniform(
rng, shape=(param_count,), dtype=jnp.float32, minval=-k, maxval=k)

def swap_lstm_gates(weights, input_size, hidden_size, num_layers, bidirectional):
"""Swaps the weights for the input and output gates for an LSTM model."""
weights = jnp.asarray(weights) # Ensure weights are JAX arrays
flat_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers, bidirectional)
num_directions = 2 if bidirectional else 1

w_offsets = 0
for l in range(num_layers):
for direction in range(num_directions):
# Iterate through all weight and bias gate names to swap gates in both weights and biases
for gate_name in ["W_ih", "W_hh", "b_ih", "b_hh"]:
shape = flat_shapes.pop(0) # Get the current shape and remove it from the list
num_elems = math.prod(shape)
matrix = weights[w_offsets:w_offsets + num_elems].reshape(shape)

# Swap between the input and output gates (third and fourth gates)
gates = jnp.split(matrix, 4, axis=0)
swapped_matrix = jnp.concatenate([gates[0], gates[1], gates[3], gates[2]], axis=0)

# Update the weights with swapped matrix
weights = weights.at[w_offsets:w_offsets + num_elems].set(swapped_matrix.flatten())
w_offsets += num_elems

return weights


def unpack_lstm_weights(
weights: Array, input_size: int, hidden_size: int, num_layers: int,
Expand Down Expand Up @@ -438,6 +463,8 @@ def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw):
rnn_fwd_p.def_abstract_eval(rnn_abstract_eval)
if gpu_rnn:
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda')
if hasattr(gpu_rnn, "miopen_rnn_fwd_lowering"):
mlir.register_lowering(rnn_fwd_p, gpu_rnn.miopen_rnn_lowering, platform='rocm')


def lstm_bwd(input_size: int, hidden_size: int, num_layers: int, dropout: float,
Expand Down Expand Up @@ -481,5 +508,8 @@ def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
if gpu_rnn:
mlir.register_lowering(
rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda')
if hasattr(gpu_rnn, "miopen_rnn_bwd_lowering"):
mlir.register_lowering(
rnn_bwd_p, gpu_rnn.miopen_rnn_bwd_lowering, platform='rocm')

lstm.defvjp(lstm_fwd, lstm_bwd)
Loading
Loading