I have a specific question about the code below. We define `ring_attention_sharded` with `shard_map` around `ring_attention_standard`: does each device in the mesh run this function on its own shard, and is the inner `lax.scan` over the key/value blocks executed sequentially?
ring_attention_sharded = shard_map(
    # calls ring_attention_standard in shard_map
    partial(ring_attention_standard, axis_name="sp"),
    mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),
    in_specs=(
        PS(("dp", "fsdp"), q_sp_dim, "tp", None),
        PS(("dp", "fsdp"), "sp", "tp", None),
        PS(("dp", "fsdp"), "sp", "tp", None),
        PS(("dp", "fsdp"), None, q_sp_dim, None)
    ),
    out_specs=PS(("dp", "fsdp"), q_sp_dim, "tp", None),
    check_rep=False
)
@partial(jax.custom_vjp, nondiff_argnums=[4, 5])
def ring_attention_standard(q, k, v, attn_mask, axis_name, float32_logits=True):
    y, _ = _ring_attention_standard_fwd(q, k, v, attn_mask, axis_name, float32_logits)
    return y

def _ring_attention_standard_fwd(q, k, v, attn_mask, axis_name, float32_logits):
    if float32_logits:
        q, k = q.astype(jnp.float32), k.astype(jnp.float32)
    batch, q_len, num_heads, _ = q.shape
    batch, kv_len, num_heads, dim_per_head = k.shape
    # running accumulators for the online softmax (numerator, denominator, max score)
    numerator = jnp.zeros((batch, q_len, num_heads, dim_per_head)).astype(q.dtype)
    denominator = jnp.zeros((batch, num_heads, q_len)).astype(q.dtype)
    axis_size = lax.psum(1, axis_name)  # number of devices on the "sp" ring
    scale = jnp.sqrt(q.shape[-1])

    def scan_kv_block(carry, idx):
        prev_max_score, numerator, denominator, k, v = carry
        # slice out the mask columns corresponding to the k/v block currently held
        mask = lax.dynamic_slice_in_dim(attn_mask,
            (lax.axis_index(axis_name) - idx) % axis_size * kv_len, kv_len, axis=-1)
        attn_weights = jnp.einsum("bqhd,bkhd->bhqk", q, k) / scale
        attn_weights = jnp.where(mask, attn_weights, jnp.finfo(attn_weights.dtype).min)
        max_score = jnp.maximum(prev_max_score, jnp.max(attn_weights, axis=-1))
        exp_weights = jnp.exp(attn_weights - max_score[..., None])
        correction = rearrange(jnp.exp(prev_max_score - max_score), 'b h q -> b q h')[..., None]
        numerator = numerator * correction + jnp.einsum("bhqk,bkhd->bqhd", exp_weights, v)
        denominator = denominator * jnp.exp(prev_max_score - max_score) + jnp.sum(exp_weights, axis=-1)
        # rotate k and v to the next device in the ring
        k, v = map(lambda x: lax.ppermute(x, axis_name,
            perm=[(i, (i + 1) % axis_size) for i in range(axis_size)]), (k, v))
        return (max_score, numerator, denominator, k, v), None

    prev_max_score = jnp.full((batch, num_heads, q_len), -jnp.inf).astype(q.dtype)
    # calls scan here
    (max_score, numerator, denominator, _, _), _ = lax.scan(scan_kv_block,
        init=(prev_max_score, numerator, denominator, k, v), xs=jnp.arange(0, axis_size))
    output = numerator / rearrange(denominator, 'b h q -> b q h')[..., None]
    return output.astype(v.dtype), (output, q, k, v, attn_mask, numerator, denominator, max_score)

Thank you!
Yes. Each device in the mesh will run the shard_mapped function (`ring_attention_standard` in this case). In other words, you pass the "per-device" function to `shard_map`, and exactly what's in the function will be run on each device (unlike `jit`, to which you pass a "global" function that is then automatically rewritten to run across multiple devices). So the scan is run on each device, although it won't be automatically distributed; it just runs as specified.
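To make the per-device semantics concrete, here is a minimal toy sketch (my own, not code from this repo) contrasting the local view a function gets under `shard_map` with the global view it gets under `jit`; the mesh, axis name, and array sizes are arbitrary placeholders:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# One mesh axis named "sp" over all available devices (placeholder setup).
mesh = Mesh(np.array(jax.devices()), axis_names=("sp",))
x = jnp.arange(4.0 * jax.device_count())  # global array, length divisible by the device count

def per_device_body(blk):
    # blk is only this device's shard of x, e.g. shape (4,)
    print("inside shard_map, local shape:", blk.shape)
    return blk * 2.0

sharded_fn = shard_map(per_device_body, mesh=mesh, in_specs=P("sp"), out_specs=P("sp"))
sharded_fn(x)  # per_device_body runs on every device, each with its own block

@jax.jit
def global_body(arr):
    # under jit the function sees the full array; partitioning happens automatically
    print("inside jit, global shape:", arr.shape)
    return arr * 2.0

global_body(x)
```

On a multi-device host the first print shows the local block shape, the second the full global shape, which is exactly the distinction above.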
It is sequential. The devices can run independently of each other though, except if there's a collective operation (i.e. communication across all devices). The collective operation in this scan is the `lax.ppermute` that rotates `k` and `v` to the next device in the ring, so the devices have to exchange data at each step.
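To make that collective concrete, here is a toy ring reduction (again my own sketch, not the repo's code) with the same structure as the scan above: each device adds the block it currently holds, then `lax.ppermute` rotates the block to its neighbour, so after `axis_size` steps every device has seen every block:

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("sp",))

@partial(shard_map, mesh=mesh, in_specs=P("sp"), out_specs=P("sp"), check_rep=False)
def ring_sum(block):
    axis_size = lax.psum(1, "sp")  # number of devices on the ring (static, same trick as above)

    def step(carry, _):
        acc, blk = carry
        acc = acc + blk  # "process" the block currently held (stands in for one attention block)
        # pass the block to the next device; this is where the devices must communicate
        blk = lax.ppermute(blk, "sp",
                           perm=[(i, (i + 1) % axis_size) for i in range(axis_size)])
        return (acc, blk), None

    (acc, _), _ = lax.scan(step, (jnp.zeros_like(block), block), xs=None, length=axis_size)
    return acc  # every device ends up with the elementwise sum over all blocks

x = jnp.arange(4.0 * jax.device_count())
print(ring_sum(x))
```

This mirrors what happens to `k` and `v` in `scan_kv_block`: the `lax.ppermute` inside the scan is the one point where the otherwise independent devices synchronize.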
The same as described for 1. See https://jax.readthedocs.io/en/latest/notebooks/shard_map.html for more on `shard_map`, including the collective operations.