Skip to content

Commit 6d55af7

Browse files
committed
fix fwd shardings with replicated q
1 parent 1df9b3b commit 6d55af7

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/flash_attn_jax/flash_sharding.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -32,15 +32,15 @@ def partition_fwd(softmax_scale, is_causal, window_size,
3232
mesh: Mesh,
3333
arg_shapes: List[jax.ShapeDtypeStruct],
3434
result_shape: List[jax.ShapeDtypeStruct]):
35-
result_shardings = [x.sharding for x in result_shape],
36-
arg_shardings = [x.sharding for x in arg_shapes]
35+
result_shardings = tuple([x.sharding for x in result_shape])
36+
arg_shardings = tuple([x.sharding for x in arg_shapes])
3737

3838
q_sharding = arg_shardings[0]
3939
k_sharding = arg_shardings[1]
4040
v_sharding = arg_shardings[2]
4141
assert q_sharding == k_sharding and q_sharding == v_sharding, "Only support q, k, v sharing the same sharding."
4242
if is_replicated(q_sharding):
43-
result_sharding = (q_sharding, q_sharding)
43+
result_shardings = (q_sharding, q_sharding)
4444
elif isinstance(q_sharding, NamedSharding):
4545
mesh = q_sharding.mesh
4646
[n,l,h,d] = q_sharding.spec

0 commit comments

Comments
 (0)