 _flash_mha_fwd_hlo_sharded = custom_partitioning(_flash_mha_fwd_hlo, static_argnums=(3,4,5))
 _flash_mha_bwd_hlo_sharded = custom_partitioning(_flash_mha_bwd_hlo, static_argnums=(6,7,8))
 
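The two wrappers above are jax.experimental.custom_partitioning objects, and the partition_fwd rule changed below only takes effect once it is registered on the forward wrapper. A minimal sketch of that registration, assuming a companion infer_sharding_fwd rule that is not shown in this hunk:

    # Hook the SPMD rules into the custom_partitioning wrapper. `partition_fwd` is
    # the rule defined below; `infer_sharding_fwd` is a hypothetical name for the
    # output-sharding inference rule that lives elsewhere in this file.
    _flash_mha_fwd_hlo_sharded.def_partition(
        partition=partition_fwd,
        infer_sharding_from_operands=infer_sharding_fwd,
    )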
+from jax._src.ad_checkpoint import _optimization_barrier
+
+def ring_fwd(softmax_scale, is_causal, axis_name, axis_size, q,k,v):
+    [n,l,h,d] = q.shape
+
+    # Index of the query block owned by this device, and of the k/v block it
+    # currently holds (k_ix travels around the ring together with k and v).
+    q_ix = jax.lax.axis_index(axis_name)
+    k_ix = jax.lax.axis_index(axis_name)
+
+    o = jnp.zeros([n,l,h,d], jnp.float32)
+    lse = jnp.full([n,h,l], float('-inf'), jnp.float32)
+
+    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
+    def f(c, a):
+        (k, v, o, lse, k_ix) = c
+
+        o1, lse1 = o, lse
+        if is_causal:
+            # Branch index: 0 if k_ix > q_ix (block fully masked out), 1 if
+            # k_ix == q_ix (diagonal block, causal mask), 2 if k_ix < q_ix (no mask).
+            o2, lse2 = jax.lax.switch((k_ix < q_ix).astype(jnp.int32) + (k_ix <= q_ix).astype(jnp.int32),
+                                      [
+                                          lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype), jnp.full([n,h,l], float('-inf'), jnp.float32)),
+                                          lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
+                                          lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1)),
+                                      ], q, k, v)
+        else:
+            o2, lse2 = _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+        o2 = o2.astype(jnp.float32)
+
+        # Combine the running result with this block's result via a stable log-sum-exp.
+        mx = jnp.maximum(lse1,lse2)
+        mn = jnp.minimum(lse1,lse2)
+        lse = jnp.log1p(jnp.exp(mn-mx)) + mx
+
+        o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
+             o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))
+
+        # Rotate k/v (and the block index) one step around the ring.
+        k2 = jax.lax.ppermute(k, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+        v2 = jax.lax.ppermute(v, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+        k_ix = jax.lax.ppermute(k_ix, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+
+        return ((k2, v2, o, lse, k_ix), None)
+    acc = (k,v,o,lse,k_ix)
+    # We sadly have to unroll this loop by hand: scan breaks the axis context,
+    # which prevents us from using ppermute (unroll=axis_size doesn't help either).
+    # The optimization barrier prevents instruction reordering, so that ppermute
+    # and flash_mha execute concurrently.
+    for _ in range(axis_size):
+        acc, _ = f(acc, None)
+        acc = _optimization_barrier(acc)
+    (_,_,o,lse,_) = acc
+    # (_,_,o,lse), _ = jax.lax.scan(f,init,None,axis_size)
+    return o.astype(q.dtype), lse
+
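The accumulation above merges the running result with each block's partial result using the per-query log-sum-exp statistics, so the blocks can be combined in any order and the final output matches full softmax attention. A minimal standalone sketch of the same combination rule (the merge_partials helper is hypothetical, written with plain jax.numpy and einops for illustration):

    import jax.numpy as jnp
    from einops import rearrange

    def merge_partials(o1, lse1, o2, lse2):
        # o1, o2: [n, l, h, d] partial attention outputs;
        # lse1, lse2: [n, h, l] log-sum-exp of the scores that produced them.
        mx = jnp.maximum(lse1, lse2)
        mn = jnp.minimum(lse1, lse2)
        lse = jnp.log1p(jnp.exp(mn - mx)) + mx  # log(exp(lse1) + exp(lse2)), computed stably
        w1 = rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1')
        w2 = rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1')
        return o1 * w1 + o2 * w2, lse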
 def partition_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
     result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
     arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
 
     q_sharding = arg_shardings[0]
     if isinstance(q_sharding, PositionalSharding):
-        if not is_causal and window_size == (-1,-1):
-            # We can handle Q that's sharded across the L dimension
-            # without replicating Q by executing it as a cross
-            # attention:
-            #
-            #    q : n [L/devices] h d
-            #   kv : n L h d
-            # -> o : n [L/devices] h d
-            #
-            # TODO: We could handle q sharded across L even with
-            # causal/local if we could communicate the slice offset
-            # (of q in kv) to the c++ driver. But it's unclear how to
-            # do that since the HLO has to be identical (SPMD).
-            q_sharding = q_sharding.replicate(3)
-            kv_sharding = q_sharding.replicate(1)
-            (n,l,h,d) = q_sharding.shape
-            result_shardings = q_sharding, q_sharding.reshape((n,l,h)).transpose(0,2,1) # n h l
-            arg_shardings = q_sharding, kv_sharding, kv_sharding
-        else:
-            # We need to replicate d always.
-            q_sharding = q_sharding.replicate((1,3))
-            (n,l,h,d) = q_sharding.shape # l=1, d=1
-            result_shardings = q_sharding, q_sharding.reshape((n,l,h)).transpose(0,2,1)
-            arg_shardings = q_sharding, q_sharding, q_sharding
+        (n,l,h,d) = q_sharding.shape
+        assert d == 1, "Sharding across `d` won't be efficient, so it's not supported."
+        assert l == 1, "For ring attention, use `with Mesh(...) as mesh` and NamedSharding."
+        result_shardings = q_sharding, q_sharding.reshape((n,h,1)) # n h l
+        arg_shardings = q_sharding, q_sharding, q_sharding
     elif isinstance(q_sharding, NamedSharding):
         mesh = q_sharding.mesh
         [n,l,h,d] = q_sharding.spec
-        if not is_causal and window_size == (-1,-1):
-            q_sharding = NamedSharding(mesh, P(n,l,h,None))
-            kv_sharding = NamedSharding(mesh, P(n,None,h,None))
-            lse_sharding = NamedSharding(mesh, P(n,h,l))
+        assert d == None, "Sharding across `d` won't be efficient, so it's not supported."
+        if l != None:
+            # assert not is_causal and window_size == (-1,-1), "Ring attention doesn't support causal or local masking yet."
+            assert window_size == (-1,-1), "Ring attention doesn't support local masking yet."
+            result_shardings = q_sharding, NamedSharding(mesh, P(n,h,l))
+            arg_shardings = q_sharding, q_sharding, q_sharding
+            axis_name = l
+            axis_size = mesh.shape[axis_name]
+            # ring attention
+            return mesh, partial(ring_fwd, softmax_scale, is_causal, axis_name, axis_size), result_shardings, arg_shardings
         else:
-            q_sharding = NamedSharding(mesh, P(n,None,h,None))
-            kv_sharding = q_sharding
-            lse_sharding = NamedSharding(mesh, P(n,h,None))
-        result_sharding = (q_sharding, lse_sharding)
-        arg_shardings = (q_sharding, kv_sharding, kv_sharding)
+            result_shardings = q_sharding, NamedSharding(mesh, P(n,h,l))
+            arg_shardings = q_sharding, q_sharding, q_sharding
     def fwd(q,k,v):
         return _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
     return mesh, fwd, result_shardings, arg_shardings
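With these rules in place, ring attention is selected purely by how the inputs are sharded: a NamedSharding whose PartitionSpec names a mesh axis for the sequence dimension (l) routes the call through ring_fwd, while sharding only the batch dimension keeps the ordinary per-device kernel. A usage sketch under stated assumptions (the public wrapper is assumed to be called flash_mha, and the mesh and axis names are illustrative):

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
    from flash_attn_jax import flash_mha  # assumed entry point

    with Mesh(jax.devices(), axis_names=('seq',)) as mesh:
        # Shard q/k/v along the sequence (L) axis: partition_fwd then sees
        # spec == P(None, 'seq', None, None) and dispatches to ring_fwd.
        sharding = NamedSharding(mesh, P(None, 'seq', None, None))
        q = jax.device_put(jnp.zeros((2, 4096, 8, 64), jnp.float16), sharding)
        k = jax.device_put(jnp.zeros((2, 4096, 8, 64), jnp.float16), sharding)
        v = jax.device_put(jnp.zeros((2, 4096, 8, 64), jnp.float16), sharding)
        out = jax.jit(lambda q, k, v: flash_mha(q, k, v, is_causal=True))(q, k, v)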