43 | 43 |
44 | 44 |
45 | 45 | from MaxText import max_utils |
46 | | -from MaxText.sharding import maybe_shard_with_name |
| 46 | +from MaxText.sharding import maybe_shard_with_name, logical_to_mesh_axes |
47 | 47 | from MaxText.common_types import ( |
48 | 48 | DEFAULT_MASK_VALUE, |
49 | 49 | BATCH, |
@@ -530,6 +530,9 @@ def maybe_create_nnx(einsum, *args): |
530 | 530 | self.AqtEinsum_2 = jnp.einsum |
531 | 531 | self.AqtEinsum_3 = jnp.einsum |
532 | 532 |
| 533 | + def _logical_to_mesh_axes(self, logical_name): |
| 534 | + return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules) |
| 535 | + |
533 | 536 | def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Array | KVTensor) -> None: |
534 | 537 | """Check attention inputs.""" |
535 | 538 |
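
For reference, a minimal sketch of the logical-to-mesh translation the new helper wraps, written against the Flax `nn.logical_to_mesh_axes` call it replaces. The rule names below are illustrative placeholders, not MaxText's actual `logical_axis_rules`, and `MaxText.sharding.logical_to_mesh_axes` may differ in details beyond taking the mesh and rules explicitly:

```python
from flax import linen as nn

# Illustrative rules only; MaxText supplies config.logical_axis_rules.
rules = (("activation_batch", "data"), ("activation_heads", "tensor"))

# Passing rules explicitly avoids depending on an enclosing
# nn.logical_axis_rules(...) context, which appears to be what the new
# helper achieves by forwarding mesh and rules from self.
spec = nn.logical_to_mesh_axes(("activation_batch", "activation_heads"), rules=rules)
# spec -> PartitionSpec('data', 'tensor')
```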
@@ -950,10 +953,10 @@ def gpu_ragged_attention(self, q: Array, k: Array | KVTensor, v: Array | KVTenso |
950 | 953 | q_for_gqa = q.squeeze(axis=1) |
951 | 954 |
952 | 955 | # Define logical axis names - clearer and avoids repeated calls. |
953 | | - b = nn.logical_to_mesh_axes(self.ragged_lengths_names) |
954 | | - bsnd = nn.logical_to_mesh_axes(self.cache_logical_axis_names) |
955 | | - bnd = nn.logical_to_mesh_axes((CACHE_BATCH, CACHE_HEADS, CACHE_KV)) |
956 | | - bn = nn.logical_to_mesh_axes((CACHE_BATCH, CACHE_HEADS)) |
| 956 | + b = self._logical_to_mesh_axes(self.ragged_lengths_names) |
| 957 | + bsnd = self._logical_to_mesh_axes(self.cache_logical_axis_names) |
| 958 | + bnd = self._logical_to_mesh_axes((CACHE_BATCH, CACHE_HEADS, CACHE_KV)) |
| 959 | + bn = self._logical_to_mesh_axes((CACHE_BATCH, CACHE_HEADS)) |
957 | 960 |
958 | 961 | @functools.partial( |
959 | 962 | jax.shard_map, |
@@ -1006,8 +1009,8 @@ def tpu_ragged_attention( |
1006 | 1009 | """Ragged Attention.""" |
1007 | 1010 | if isinstance(query, KVTensor): |
1008 | 1011 | raise TypeError("Ragged attention does not currently support quantized tensors.") |
1009 | | - b = nn.logical_to_mesh_axes(self.ragged_lengths_names) |
1010 | | - bsnd = nn.logical_to_mesh_axes(self.cache_logical_axis_names) |
| 1012 | + b = self._logical_to_mesh_axes(self.ragged_lengths_names) |
| 1013 | + bsnd = self._logical_to_mesh_axes(self.cache_logical_axis_names) |
1011 | 1014 |
1012 | 1015 | @functools.partial( |
1013 | 1016 | jax.shard_map, |
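
A rough, self-contained sketch of the pattern these hunks rely on: the `PartitionSpec`s produced by `_logical_to_mesh_axes` feed `in_specs`/`out_specs` of `jax.shard_map`. The single-device mesh, the placeholder specs, and `toy_kernel` are assumptions for illustration, not MaxText's real ragged-attention kernel:

```python
import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# Toy mesh and placeholder specs standing in for what
# self._logical_to_mesh_axes(...) would return for b / bsnd.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("data",))
b = P("data")                       # per-sequence lengths: [batch]
bsnd = P("data", None, None, None)  # cache: [batch, seq, heads, kv]

@functools.partial(jax.shard_map, mesh=mesh, in_specs=(b, bsnd), out_specs=bsnd)
def toy_kernel(lengths, cache):
  # Real code would run the ragged-attention kernel on each shard.
  del lengths
  return cache

out = toy_kernel(jnp.zeros((2,), jnp.int32), jnp.zeros((2, 8, 4, 16)))
```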
@@ -1050,23 +1053,23 @@ def tpu_flash_attention( |
1050 | 1053 | value = jnp.transpose(value, axes=(0, 2, 1, 3)) |
1051 | 1054 | segment_axis_names_q = None |
1052 | 1055 | segment_axis_names_kv = None |
1053 | | - sink_axis_names = nn.logical_to_mesh_axes((HEAD,)) |
| 1056 | + sink_axis_names = self._logical_to_mesh_axes((HEAD,)) |
1054 | 1057 | if decoder_segment_ids is not None: |
1055 | 1058 | if self.config.expert_shard_attention_option == EP_AS_CONTEXT: |
1056 | | - segment_axis_names_q = nn.logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH)) |
1057 | | - segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH_NO_EXP, KV_LENGTH)) |
| 1059 | + segment_axis_names_q = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH)) |
| 1060 | + segment_axis_names_kv = self._logical_to_mesh_axes((BATCH_NO_EXP, KV_LENGTH)) |
1058 | 1061 | else: |
1059 | | - segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP)) |
1060 | | - segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, KV_LENGTH)) |
| 1062 | + segment_axis_names_q = self._logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP)) |
| 1063 | + segment_axis_names_kv = self._logical_to_mesh_axes((BATCH, KV_LENGTH)) |
1061 | 1064 |
1062 | 1065 | if self.config.expert_shard_attention_option == EP_AS_CONTEXT: |
1063 | | - axis_names_splash_kernel = nn.logical_to_mesh_axes(self.flash_axis_names_splash_kernel_ep) |
1064 | | - axis_names_q = nn.logical_to_mesh_axes(self.flash_axis_names_q_ep) |
1065 | | - axis_names_kv = nn.logical_to_mesh_axes(self.flash_axis_names_kv_ep) |
| 1066 | + axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel_ep) |
| 1067 | + axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q_ep) |
| 1068 | + axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv_ep) |
1066 | 1069 | else: |
1067 | | - axis_names_splash_kernel = nn.logical_to_mesh_axes(self.flash_axis_names_splash_kernel) |
1068 | | - axis_names_q = nn.logical_to_mesh_axes(self.flash_axis_names_q) |
1069 | | - axis_names_kv = nn.logical_to_mesh_axes(self.flash_axis_names_kv) |
| 1070 | + axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel) |
| 1071 | + axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q) |
| 1072 | + axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv) |
1070 | 1073 |
1071 | 1074 | global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv |
1072 | 1075 | global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel |
@@ -1197,9 +1200,9 @@ def wrap_splash_kernel(single_head_mask, shard_head_size=1): |
1197 | 1200 | shard_head_size = np.prod(logical_axis_rules_head) |
1198 | 1201 | splash_kernel = wrap_splash_kernel(single_head_mask, int(shard_head_size)) |
1199 | 1202 | if self.config.expert_shard_attention_option == EP_AS_CONTEXT: |
1200 | | - segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH,)) |
| 1203 | + segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,)) |
1201 | 1204 | else: |
1202 | | - segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH_NO_EXP,)) |
| 1205 | + segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH_NO_EXP,)) |
1203 | 1206 | else: |
1204 | 1207 | # Create multi-head mask |
1205 | 1208 | multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) |