@@ -67,13 +67,18 @@ class MaskInfo(NamedTuple):
6767 q_sequence: A i32[q_sequence_length] NumPy array. When using causal masking,
6868 this contains the list of indices that correspond to q tokens. For plain
6969 causal this is just np.arange(q_sequence_length).
70+ is_dynamic_mask: A bool indicating whether the mask is dynamic or static.
71+ When True, the leading dimensions of `partial_mask_blocks` (num_heads,
72+ q_blocks, kv_blocks) are not collapsed, allowing us to shard it along
73+ those dimensions.
7074 """
7175
7276 data_next : np .ndarray | jax .Array | None
7377 mask_next : np .ndarray | jax .Array | None
7478 block_mask : np .ndarray | jax .Array | None
7579 partial_mask_blocks : np .ndarray | jax .Array | None
7680 q_sequence : np .ndarray | None
 81+ is_dynamic_mask : bool = False
7782
7883
7984def _downcast_to_small_type (array : np .ndarray ) -> np .ndarray :
@@ -168,7 +173,7 @@ def __eq__(self, other: object) -> bool:
168173def _get_mask_info_for_shard (
169174 output_shape : tuple [int , int , int ],
170175 has_mask_next : bool ,
171- mask : mask_lib .MultiHeadMask ,
176+ mask : mask_lib .MultiHeadMask | jax . Array ,
172177 block_shape : tuple [int , int ],
173178 coords_to_partial_mask_block_index : dict [tuple [int , int , int ], int ],
174179 masks_per_head_shard : int ,
@@ -338,7 +343,8 @@ def _process_dynamic_mask(
338343 launched.
339344 q_seq_shards: Number of Q sequence shards of the mesh in which the kernel is
340345 launched.
341- shrink_grid: Whether or not we should apply the grid shrinking optimization. This is currently ignored.
346+ shrink_grid: Whether or not we should apply the grid shrinking optimization.
347+ This is currently ignored.
342348
343349 Returns:
344350 `MaskInfo`, a sparse representation of the dense mask.
@@ -349,11 +355,6 @@ def _process_dynamic_mask(
349355 """
350356
351357 del shrink_grid
352-
353- # TODO(pobudzey): Properly support sharding.
354- if head_shards != 1 or q_seq_shards != 1 :
355- raise ValueError ('Dynamic mask processing does not support sharding.' )
356-
357358 if len (mask .shape ) != 3 :
358359 raise ValueError (f'Expected a 3-dim mask, instead got: { mask .shape } .' )
359360
@@ -370,6 +371,18 @@ def _process_dynamic_mask(
370371 if kv_mod != 0 :
371372 raise ValueError (f'{ kv_block_size = } should divide { kv_seq_len = } .' )
372373
374+ q_seq_len_per_shard , mod = divmod (q_seq_len , q_seq_shards )
375+ if mod != 0 :
376+ raise ValueError (f'{ q_seq_shards = } should divide { q_seq_len = } .' )
377+
378+ q_blocks_per_shard , mod = divmod (q_seq_len_per_shard , q_block_size )
379+ if mod != 0 :
380+ raise ValueError (f'{ q_block_size = } should divide { q_seq_len_per_shard = } .' )
381+
382+ heads_per_shard , mod = divmod (head_count , head_shards )
383+ if mod != 0 :
384+ raise ValueError (f'{ head_shards = } should divide { head_count = } .' )
385+
373386 block_mask_shape = (
374387 head_count ,
375388 q_blocks_count ,
@@ -398,26 +411,66 @@ def _process_dynamic_mask(
398411 block_mask = jnp .where (is_full_mask , 2 , block_mask )
399412 block_mask = jnp .where (is_empty_mask , 0 , block_mask )
400413
401- # TODO(pobudzey): Return the next valid mask index instead of 0 for a more efficient pipeline.
402- mask_next = jnp .where (
403- jnp .logical_or (is_empty_mask , is_full_mask ),
404- 0 ,
405- jnp .arange (math .prod (block_mask_shape ), dtype = np .int32 ).reshape (
406- block_mask_shape
407- ),
408- )
414+ q_sequence_axis = 1
415+ head_axis = 0
409416
410- # data_next stores the index of the next non-empty data block in the sequence.
411- # The indices of empty blocks are set to 0 to avoid copying extra data when
412- # pipeling.
413- if is_dkv :
414- data_next = jnp .arange (q_blocks_count , dtype = np .int32 )[None , :, None ]
415- else :
416- data_next = jnp .arange (kv_blocks_count , dtype = np .int32 )[None , None , :]
417- data_next = jnp .broadcast_to (data_next , block_mask_shape )
418- data_next = jnp .where (is_empty_mask , 0 , data_next )
417+ # Each iteration of the loop processes a slice of the mask info
418+ # tensors of this shape:
419+ mask_info_slice_shape = (heads_per_shard , q_blocks_per_shard , kv_blocks_count )
420+
 421+ # Collect mask_info shards along the head dimension, concatenate (or
422+ # broadcast) them after the loop.
423+ data_next_per_head_list , mask_next_per_head_list = [], []
424+ for head_shard in range (head_shards ):
425+ head_start = head_shard * heads_per_shard
426+ mask_head_slice = slice (head_start , head_start + heads_per_shard )
427+
428+ # Collect mask_info shards along the q_sequence dimension, concatenate them
429+ # after the loop.
430+ data_next_sequence_slices , mask_next_sequence_slices = [], []
431+ for q_seq_len_shard in range (q_seq_shards ):
432+ q_seq_len_start = q_seq_len_shard * q_blocks_per_shard
433+ blocked_q_seq_len_slice = slice (
434+ q_seq_len_start , q_seq_len_start + q_blocks_per_shard
435+ )
436+ local_block_mask = block_mask [mask_head_slice , blocked_q_seq_len_slice ]
437+
438+ mask_next_slice = jnp .arange (
439+ math .prod (mask_info_slice_shape ), dtype = np .int32
440+ ).reshape (mask_info_slice_shape )
441+ mask_next_slice = jnp .where (local_block_mask == 1 , mask_next_slice , 0 )
442+
443+ # data_next stores the index of the next non-empty data block in the sequence.
444+ # The indices of empty blocks are set to 0 to avoid copying extra data when
 445+ # pipelining.
446+ if is_dkv :
447+ data_next_slice = jnp .arange (q_blocks_per_shard , dtype = np .int32 )[
448+ None , :, None
449+ ]
450+ else :
451+ data_next_slice = jnp .arange (kv_blocks_count , dtype = np .int32 )[
452+ None , None , :
453+ ]
454+ data_next_slice = jnp .broadcast_to (data_next_slice , mask_info_slice_shape )
455+ data_next_slice = jnp .where (local_block_mask == 0 , 0 , data_next_slice )
456+
457+ data_next_sequence_slices .append (data_next_slice )
458+ mask_next_sequence_slices .append (mask_next_slice )
459+
460+ # Concatenate the sequence shards.
461+ data_next_per_head = jnp .concatenate (
462+ data_next_sequence_slices , axis = q_sequence_axis
463+ )
464+ data_next_per_head_list .append (data_next_per_head )
465+ mask_next_per_head = jnp .concatenate (
466+ mask_next_sequence_slices , axis = q_sequence_axis
467+ )
468+ mask_next_per_head_list .append (mask_next_per_head )
469+
470+ # Concatenate (or broadcast) the head shards.
471+ data_next = jnp .concatenate (data_next_per_head_list , axis = head_axis )
472+ mask_next = jnp .concatenate (mask_next_per_head_list , axis = head_axis )
419473
420- partial_mask_blocks = partial_mask_blocks .reshape (- 1 , * block_shape )
421474 if is_dkv :
422475 partial_mask_blocks = partial_mask_blocks .swapaxes (- 1 , - 2 )
423476
@@ -438,9 +491,11 @@ def _downcast(array: jax.Array, max_value: int) -> jax.Array:
438491 if downcast_smem_data :
439492 block_mask = block_mask .astype (np .int8 ) # values are in the range [0, 1, 2]
440493 data_next = _downcast (
441- data_next , q_blocks_count if is_dkv else kv_blocks_count
494+ data_next , q_blocks_per_shard if is_dkv else kv_blocks_count
495+ )
496+ mask_next = _downcast (
497+ mask_next , heads_per_shard * q_blocks_per_shard * kv_blocks_count
442498 )
443- mask_next = _downcast (mask_next , math .prod (block_mask_shape ))
444499
445500 return (
446501 MaskInfo (
@@ -449,6 +504,7 @@ def _downcast(array: jax.Array, max_value: int) -> jax.Array:
449504 block_mask = block_mask ,
450505 partial_mask_blocks = partial_mask_blocks ,
451506 q_sequence = None ,
507+ is_dynamic_mask = True ,
452508 ),
453509 None ,
454510 )
0 commit comments