|
18 | 18 | import collections |
19 | 19 | from collections.abc import Callable |
20 | 20 | import functools |
| 21 | +import math |
21 | 22 | from typing import NamedTuple |
| 23 | + |
| 24 | +import jax |
22 | 25 | from jax import util as jax_util |
23 | 26 | from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib |
| 27 | +import jax.numpy as jnp |
24 | 28 | import numpy as np |
25 | 29 |
|
26 | 30 | # mypy: ignore-errors |
@@ -65,10 +69,10 @@ class MaskInfo(NamedTuple): |
65 | 69 | causal this is just np.arange(q_sequence_length). |
66 | 70 | """ |
67 | 71 |
|
68 | | - data_next: np.ndarray | None |
69 | | - mask_next: np.ndarray | None |
70 | | - block_mask: np.ndarray | None |
71 | | - partial_mask_blocks: np.ndarray | None |
| 72 | + data_next: np.ndarray | jax.Array | None |
| 73 | + mask_next: np.ndarray | jax.Array | None |
| 74 | + block_mask: np.ndarray | jax.Array | None |
| 75 | + partial_mask_blocks: np.ndarray | jax.Array | None |
72 | 76 | q_sequence: np.ndarray | None |
73 | 77 |
|
74 | 78 |
|
@@ -245,7 +249,7 @@ def _get_mask_info_for_shard( |
245 | 249 | mask_next = np.zeros(output_shape, dtype=np.int32) |
246 | 250 | data_next = np.zeros(output_shape, dtype=np.int32) |
247 | 251 |
|
248 | | - # If the mask is completelly zero'd out return freshly initialized outputs. |
| 252 | + # If the mask is completely zero'd out return freshly initialized outputs. |
249 | 253 | if not data_coords: |
250 | 254 | return data_next, mask_next |
251 | 255 |
|
@@ -304,6 +308,152 @@ def _get_mask_info_for_shard( |
304 | 308 | return data_next, mask_next |
305 | 309 |
|
306 | 310 |
|
| 311 | +def _process_dynamic_mask( |
| 312 | + mask: jax.Array, |
| 313 | + block_shape: tuple[int, int], |
| 314 | + is_dkv: bool, |
| 315 | + *, |
| 316 | + downcast_smem_data: bool = True, |
| 317 | + head_shards: int = 1, |
| 318 | + q_seq_shards: int = 1, |
| 319 | + shrink_grid: bool = True, |
| 320 | +) -> tuple[MaskInfo, None]: |
| 321 | + """Similar to `_process_mask` but the mask must be a dynamic array. |
| 322 | +
|
| 323 | + Since the mask is dynamic, we can't know the exact number of partial mask |
| 324 | + blocks at trace time. Therefore, the entire mask is materialized in |
| 325 | + `partial_mask_blocks`. |
| 326 | +
|
| 327 | + Note that we can still populate MaskInfo to skip fully-masked blocks. |
| 328 | +
|
| 329 | + Args: |
| 330 | + mask: A [head_count, q_seq_len, kv_seq_len] jax.Array representing the dense |
| 331 | + mask to process. |
| 332 | + block_shape: A Tuple[int, int] representing the shape of the Pallas grid |
| 333 | + block. |
| 334 | + is_dkv: True if we are processing the dKV mask |
| 335 | + downcast_smem_data: If True, downcast the scalar-memory data of MaskInfo to |
| 336 | + a data type smaller than np.int32 (if possible). |
| 337 | + head_shards: Number of head shards of the mesh in which the kernel is |
| 338 | + launched. |
| 339 | + q_seq_shards: Number of Q sequence shards of the mesh in which the kernel is |
| 340 | + launched. |
| 341 | + shrink_grid: Whether or not we should apply the grid shrinking optimization. This is currently ignored. |
| 342 | +
|
| 343 | + Returns: |
| 344 | + `MaskInfo`, a sparse representation of the dense mask. |
| 345 | +
|
| 346 | + Raises: |
| 347 | + ValueError: if the input mask is invalid or the block sizes are not |
| 348 | + compatible with the mask sizes. |
| 349 | + """ |
| 350 | + |
| 351 | + del shrink_grid |
| 352 | + |
| 353 | + # TODO(pobudzey): Properly support sharding. |
| 354 | + if head_shards != 1 or q_seq_shards != 1: |
| 355 | + raise ValueError('Dynamic mask processing does not support sharding.') |
| 356 | + |
| 357 | + if len(mask.shape) != 3: |
| 358 | + raise ValueError(f'Expected a 3-dim mask, instead got: {mask.shape}.') |
| 359 | + |
| 360 | + if mask.dtype != jnp.bool: |
| 361 | + raise ValueError(f'Expected a bool mask, instead got: {mask.dtype}.') |
| 362 | + |
| 363 | + head_count, q_seq_len, kv_seq_len = mask.shape |
| 364 | + q_block_size, kv_block_size = block_shape |
| 365 | + q_blocks_count, q_mod = divmod(q_seq_len, q_block_size) |
| 366 | + kv_blocks_count, kv_mod = divmod(kv_seq_len, kv_block_size) |
| 367 | + |
| 368 | + if q_mod != 0: |
| 369 | + raise ValueError(f'{q_block_size=} should divide {q_seq_len=}.') |
| 370 | + if kv_mod != 0: |
| 371 | + raise ValueError(f'{kv_block_size=} should divide {kv_seq_len=}.') |
| 372 | + |
| 373 | + block_mask_shape = ( |
| 374 | + head_count, |
| 375 | + q_blocks_count, |
| 376 | + kv_blocks_count, |
| 377 | + ) |
| 378 | + |
| 379 | + # Tile the last 2 dimensions of the mask into 2D tiles of size `block_shape`. |
| 380 | + partial_mask_blocks = ( |
| 381 | + mask.reshape( |
| 382 | + head_count, |
| 383 | + q_blocks_count, |
| 384 | + q_block_size, |
| 385 | + kv_blocks_count, |
| 386 | + kv_block_size, |
| 387 | + ) |
| 388 | + .swapaxes(-2, -3) |
| 389 | + .astype(np.bool_) |
| 390 | + ) |
| 391 | + |
| 392 | + # The block mask is 2 for all blocks with all entries set to True and 1 for |
| 393 | + # blocks with a mix of True and False entries. |
| 394 | + is_full_mask = jnp.all(partial_mask_blocks, axis=(-1, -2)) |
| 395 | + is_empty_mask = jnp.logical_not(jnp.any(partial_mask_blocks, axis=(-1, -2))) |
| 396 | + |
| 397 | + block_mask = jnp.ones(block_mask_shape, dtype=np.int32) |
| 398 | + block_mask = jnp.where(is_full_mask, 2, block_mask) |
| 399 | + block_mask = jnp.where(is_empty_mask, 0, block_mask) |
| 400 | + |
| 401 | + # TODO(pobudzey): Return the next valid mask index instead of 0 for a more efficient pipeline. |
| 402 | + mask_next = jnp.where( |
| 403 | + jnp.logical_or(is_empty_mask, is_full_mask), |
| 404 | + 0, |
| 405 | + jnp.arange(math.prod(block_mask_shape), dtype=np.int32).reshape( |
| 406 | + block_mask_shape |
| 407 | + ), |
| 408 | + ) |
| 409 | + |
| 410 | + # data_next stores the index of the next non-empty data block in the sequence. |
| 411 | + # The indices of empty blocks are set to 0 to avoid copying extra data when |
| 412 | + # pipeling. |
| 413 | + if is_dkv: |
| 414 | + data_next = jnp.arange(q_blocks_count, dtype=np.int32)[None, :, None] |
| 415 | + else: |
| 416 | + data_next = jnp.arange(kv_blocks_count, dtype=np.int32)[None, None, :] |
| 417 | + data_next = jnp.broadcast_to(data_next, block_mask_shape) |
| 418 | + data_next = jnp.where(is_empty_mask, 0, data_next) |
| 419 | + |
| 420 | + partial_mask_blocks = partial_mask_blocks.reshape(-1, *block_shape) |
| 421 | + if is_dkv: |
| 422 | + partial_mask_blocks = partial_mask_blocks.swapaxes(-1, -2) |
| 423 | + |
| 424 | + def _downcast(array: jax.Array, max_value: int) -> jax.Array: |
| 425 | + if array.size == 0: |
| 426 | + return array |
| 427 | + |
| 428 | + if array.dtype != np.int32: |
| 429 | + raise ValueError(f'Expected int32 input, but got {array.dtype}.') |
| 430 | + |
| 431 | + if max_value <= np.iinfo(np.int8).max: |
| 432 | + return array.astype(np.int8) |
| 433 | + elif max_value <= np.iinfo(np.int16).max: |
| 434 | + return array.astype(np.int16) |
| 435 | + else: |
| 436 | + return array.astype(np.int32) |
| 437 | + |
| 438 | + if downcast_smem_data: |
| 439 | + block_mask = block_mask.astype(np.int8) # values are in the range [0, 1, 2] |
| 440 | + data_next = _downcast( |
| 441 | + data_next, q_blocks_count if is_dkv else kv_blocks_count |
| 442 | + ) |
| 443 | + mask_next = _downcast(mask_next, math.prod(block_mask_shape)) |
| 444 | + |
| 445 | + return ( |
| 446 | + MaskInfo( |
| 447 | + data_next=data_next, |
| 448 | + mask_next=mask_next, |
| 449 | + block_mask=block_mask, |
| 450 | + partial_mask_blocks=partial_mask_blocks, |
| 451 | + q_sequence=None, |
| 452 | + ), |
| 453 | + None, |
| 454 | + ) |
| 455 | + |
| 456 | + |
307 | 457 | # When used in a transformer network with multiple layers, the SplashAttention |
308 | 458 | # kernel is created several times with the same mask. Cache MaskInfo to avoid |
309 | 459 | # blowing up compile times. Ideally the size of the cache should be determined |
@@ -410,7 +560,7 @@ def assign_unique_ids(objects): |
410 | 560 | mask_id_to_heads[mask_id].append(head) |
411 | 561 | mask_id_to_head_shards[mask_id].add(head_shard) |
412 | 562 |
|
413 | | - # If we have at most one unique mask per each head shard, then we can brodcast |
| 563 | + # If we have at most one unique mask per each head shard, then we can broadcast |
414 | 564 | # the mask to all the heads in the shard. This is the common case. |
415 | 565 | # If we have more than one mask in each head shard, then the optimization |
416 | 566 | # cannot kick in and we use one mask for each head. |
@@ -699,9 +849,7 @@ def set_block_mask(mask_id: int, q_index: int, kv_index: int, value: int): |
699 | 849 | current_block_mask, |
700 | 850 | current_data_next, |
701 | 851 | current_mask_next, |
702 | | - ) in zip( |
703 | | - block_mask_shards, data_next_shards, mask_next_shards |
704 | | - ): |
| 852 | + ) in zip(block_mask_shards, data_next_shards, mask_next_shards): |
705 | 853 | # For dKV shrinking happens along axis Q (the rows of MaskInfo), for |
706 | 854 | # fwd and dQ shrinking happens along axis KV (the columns of MaskInfo). |
707 | 855 | if is_dkv: |
@@ -924,3 +1072,6 @@ def _slice_mask_info( |
924 | 1072 |
|
925 | 1073 | process_mask = functools.partial(_process_mask, is_dkv=False) |
926 | 1074 | process_mask_dkv = functools.partial(_process_mask, is_dkv=True) |
| 1075 | + |
| 1076 | +process_dynamic_mask = functools.partial(_process_dynamic_mask, is_dkv=False) |
| 1077 | +process_dynamic_mask_dkv = functools.partial(_process_dynamic_mask, is_dkv=True) |
0 commit comments