
Commit 83aa65c

fegin authored and mansiag05 committed
[CP][BE] Cosmetic refactors for CP code base (pytorch#163115)
Summary: This PR is extracted from pytorch#162542, to make the original PR easier to review. This PR only contains cosmetic changes.

Pull Request resolved: pytorch#163115
Approved by: https://github.com/tianyu-l
ghstack dependencies: pytorch#162539, pytorch#162540, pytorch#162541
1 parent c65bc42 commit 83aa65c

File tree

1 file changed: +101, -104 lines


torch/distributed/tensor/experimental/_attention.py

Lines changed: 101 additions & 104 deletions
@@ -55,7 +55,7 @@ class _ContextParallelOptions:
     # errors. It is likely this is always True but we currently keep this variable
     # for the experimental purpose.
     convert_to_f32: bool = True
-    enable_load_balance = True
+    enable_load_balance: bool = True
     rotate_method: _RotateMethod = _RotateMethod.ALL_GATHER


@@ -924,18 +924,10 @@ def _distribute_function(
     output_fn: Optional[Callable] = None,
 ) -> None:
     """
-    ``distribute_function`` is an experimental API that allows users to "distribute"
-    the inputs and outputs of a function. Similar to ``distribute_module``, this API
-    installs hooks to the ``fn`` to convert the inputs and outputs. There are two
-    major differences between ``distribute_function`` and ``distribute_module``.
-    First, a function does not have parameters and buffers, as a result,
-    ``distribute_function`` itself won't convert any parameters/buffers but simply
-    install the input and output hooks. The tensor conversion will happen in the hooks.
-    Another difference is an nn.Module subclass can have several instances and each
-    instance be fed into ``distribute_module`` independently with affecting other
-    instance. On the other hand, function is a singleton object. So if a function
-    is distributed by ``distribute_function`` all subsequent calls to the function
-    will invoke the installed hooks.
+    A helper function to replace a function with a distributed version by
+    using the monkey patching approach.
+
+    This function is for the CP internal usage only.

     Args:
         fn (Callable): the function to be distributed.
@@ -986,7 +978,7 @@ def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None:


 @contextlib.contextmanager
-def _enable_cp_dispatcher() -> Generator[None, None, None]:
+def _enable_cp_dtensor_dispatcher() -> Generator[None, None, None]:
     """Enables DTensor dispatcher to dispatch SDPA to CP."""
     old_handlers = DTensor._op_dispatcher._custom_op_handlers
     DTensor._op_dispatcher._custom_op_handlers = {**old_handlers, **customized_ops}
@@ -996,94 +988,10 @@ def _enable_cp_dispatcher() -> Generator[None, None, None]:
     DTensor._op_dispatcher._custom_op_handlers = old_handlers


-def create_cp_block_mask(
-    mask_mod: _mask_mod_signature,
-    B: int,
-    H: int,
-    Q_LEN: int,
-    KV_LEN: int,
-    device_mesh: DeviceMesh,
-) -> BlockMask:
-    """
-    This API creates a special BlockMask for Context Parallel FlexAttention:
-    1. This BlockMask is masking on the attention of Q shard and KV global views, by
-    mapping the local q_idx to the global q_idx before sending to mask_mod.
-    2. The kv_seq_length (i.e. seq_lengths[1]) of this blockMask is tailored to match
-    the sequence length of KV shard instead of KV global. This is to pass the shape check
-    in flex_atttention(). The correct value (i.e. the sequence length of KV global) will be
-    used in flex_attention once the shape check passes.
-
-    Args:
-        mask_mod (Callable): Function to modify the mask over the global attention result.
-        B (int): Batch size.
-        H (int): Number of query heads.
-        Q_LEN (int): Sequence length of query (global view).
-        KV_LEN (int): Sequence length of key/value (global view).
-        device_mesh (:class:`DeviceMesh`): The device mesh for the context parallelism.
-
-    Return:
-        :class:`BlockMask`: the block_mask to be used in flex_attention() within the
-        context_parallel() context.
-
-    .. warning::
-        This function cannot generate correct block_mask if the BLOCK_SIZE is not
-        ``_DEFAULT_SPARSE_BLOCK_SIZE`` which usually happens when the attention
-        size is smaller than 128. Please do not use context_parallel() when the
-        FlexAttention size is small.
-    """
-    from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE
-
-    compiled_create_block_mask = torch.compile(
-        create_block_mask, dynamic=False, fullgraph=True
-    )
-
-    def _rewrite_mask_mod(
-        mask_mod: _mask_mod_signature,
-        rank: int,
-        world_size: int,
-        block_size: int,
-        local_q_size: int,
-    ) -> _mask_mod_signature:
-        def local_q_idx_to_q_idx(local_q_idx: torch.Tensor) -> torch.Tensor:
-            # calculate local block_idx and block_offset
-            local_blk_idx, local_blk_offset = (
-                local_q_idx // block_size,
-                local_q_idx % block_size,
-            )
-            # NOTE: load balancing is not used
-            local_num_blocks = local_q_size // block_size
-            blk_idx = local_num_blocks * rank + local_blk_idx
-            return blk_idx * block_size + local_blk_offset
-
-        return lambda b, h, q_idx, kv_idx: mask_mod(
-            b,
-            h,
-            local_q_idx_to_q_idx(q_idx),
-            kv_idx,
-        )
-
-    cp_rank = device_mesh.get_local_rank()
-    cp_group_size = device_mesh.size()
-    Q_SHARD_LEN = Q_LEN // cp_group_size
-    block_size = _DEFAULT_SPARSE_BLOCK_SIZE
-    block_mask = compiled_create_block_mask(
-        _rewrite_mask_mod(mask_mod, cp_rank, cp_group_size, block_size, Q_SHARD_LEN),
-        B,
-        H,
-        Q_SHARD_LEN,
-        KV_LEN,
-        device=device_mesh.device_type,
-        BLOCK_SIZE=(block_size, block_size),
-    )
-    # flex_attention function checks the following shape so we need to rewrite:
-    # key.size(-2) == block_mask.seq_lengths[1]
-    seq_lengths = block_mask.seq_lengths
-    block_mask.seq_lengths = (seq_lengths[0], seq_lengths[1] // cp_group_size)
-    return block_mask
-
-
 @contextlib.contextmanager
-def _context_parallel(seq_dim: int, mesh: DeviceMesh) -> Generator[None, None, None]:
+def _context_parallel_dispatcher(
+    seq_dim: int, mesh: DeviceMesh
+) -> Generator[None, None, None]:
     """Replace SDPA with the CP-wrapped version and enable DTensor CP dispatcher."""

     def attention_input_fn(
@@ -1185,7 +1093,7 @@ def __torch_function__(
             attention_input_fn,
             attention_output_fn,
         )
-        with _enable_cp_dispatcher():
+        with _enable_cp_dtensor_dispatcher():
             yield
         _restore_function(F.scaled_dot_product_attention, F)
     elif _dispatch_mode == _DispatchMode.TORCH_FUNCTION:
@@ -1200,7 +1108,7 @@ def __torch_function__(
         _cp_global_vars.torch_function_mode = tf_mode

         with tf_mode:
-            with _enable_cp_dispatcher():
+            with _enable_cp_dtensor_dispatcher():
                 yield
     else:
         raise NotImplementedError("torch dispatch mode is not supported yet.")
@@ -1270,6 +1178,9 @@ def _context_parallel_buffers(
     return new_buffers


+#####################################################
+# Current public APIs, but are also subject to change
+#####################################################
 @contextlib.contextmanager
 @torch.no_grad()
 def context_parallel(
@@ -1343,7 +1254,7 @@ def context_parallel(
         buffer.resize_(shard.shape)
         buffer.copy_(shard)

-    with _context_parallel(seq_dim=2, mesh=mesh):
+    with _context_parallel_dispatcher(seq_dim=2, mesh=mesh):
         yield

     for buffer, original_buffer in zip(buffers, original_buffers):
@@ -1421,3 +1332,89 @@ def set_rotate_method(rotate_method: str) -> None:
             "Context Parallel does not support "
             f"using {rotate_method} for kv shards rotation"
         )
+
+
+def create_cp_block_mask(
+    mask_mod: _mask_mod_signature,
+    B: int,
+    H: int,
+    Q_LEN: int,
+    KV_LEN: int,
+    device_mesh: DeviceMesh,
+) -> BlockMask:
+    """
+    This API creates a special BlockMask for Context Parallel FlexAttention:
+    1. This BlockMask is masking on the attention of Q shard and KV global views, by
+    mapping the local q_idx to the global q_idx before sending to mask_mod.
+    2. The kv_seq_length (i.e. seq_lengths[1]) of this blockMask is tailored to match
+    the sequence length of KV shard instead of KV global. This is to pass the shape check
+    in flex_atttention(). The correct value (i.e. the sequence length of KV global) will be
+    used in flex_attention once the shape check passes.
+
+    Args:
+        mask_mod (Callable): Function to modify the mask over the global attention result.
+        B (int): Batch size.
+        H (int): Number of query heads.
+        Q_LEN (int): Sequence length of query (global view).
+        KV_LEN (int): Sequence length of key/value (global view).
+        device_mesh (:class:`DeviceMesh`): The device mesh for the context parallelism.
+
+    Return:
+        :class:`BlockMask`: the block_mask to be used in flex_attention() within the
+        context_parallel() context.
+
+    .. warning::
+        This function cannot generate correct block_mask if the BLOCK_SIZE is not
+        ``_DEFAULT_SPARSE_BLOCK_SIZE`` which usually happens when the attention
+        size is smaller than 128. Please do not use context_parallel() when the
+        FlexAttention size is small.
+    """
+    from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE
+
+    compiled_create_block_mask = torch.compile(
+        create_block_mask, dynamic=False, fullgraph=True
+    )
+
+    def _rewrite_mask_mod(
+        mask_mod: _mask_mod_signature,
+        rank: int,
+        world_size: int,
+        block_size: int,
+        local_q_size: int,
+    ) -> _mask_mod_signature:
+        def local_q_idx_to_q_idx(local_q_idx: torch.Tensor) -> torch.Tensor:
+            # calculate local block_idx and block_offset
+            local_blk_idx, local_blk_offset = (
+                local_q_idx // block_size,
+                local_q_idx % block_size,
+            )
+            # NOTE: load balancing is not used
+            local_num_blocks = local_q_size // block_size
+            blk_idx = local_num_blocks * rank + local_blk_idx
+            return blk_idx * block_size + local_blk_offset
+
+        return lambda b, h, q_idx, kv_idx: mask_mod(
+            b,
+            h,
+            local_q_idx_to_q_idx(q_idx),
+            kv_idx,
+        )
+
+    cp_rank = device_mesh.get_local_rank()
+    cp_group_size = device_mesh.size()
+    Q_SHARD_LEN = Q_LEN // cp_group_size
+    block_size = _DEFAULT_SPARSE_BLOCK_SIZE
+    block_mask = compiled_create_block_mask(
+        _rewrite_mask_mod(mask_mod, cp_rank, cp_group_size, block_size, Q_SHARD_LEN),
+        B,
+        H,
+        Q_SHARD_LEN,
+        KV_LEN,
+        device=device_mesh.device_type,
+        BLOCK_SIZE=(block_size, block_size),
+    )
+    # flex_attention function checks the following shape so we need to rewrite:
+    # key.size(-2) == block_mask.seq_lengths[1]
+    seq_lengths = block_mask.seq_lengths
+    block_mask.seq_lengths = (seq_lengths[0], seq_lengths[1] // cp_group_size)
+    return block_mask
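
For reference, a minimal usage sketch (not part of this commit) of how the relocated create_cp_block_mask might be combined with context_parallel() and flex_attention. The device-mesh setup, tensor shapes, and the helper run_cp_flex_attention are illustrative assumptions, and importing create_cp_block_mask from the private _attention module mirrors the file touched by this diff rather than a documented public path.

# Hypothetical sketch only: assumes torch.distributed is already initialized
# with one CUDA device per rank.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import create_cp_block_mask
from torch.nn.attention.flex_attention import flex_attention


def causal_mask(b, h, q_idx, kv_idx):
    # Standard flex_attention mask_mod: each query attends to itself and earlier positions.
    return q_idx >= kv_idx


def run_cp_flex_attention() -> torch.Tensor:
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))

    # Global (unsharded) attention inputs; shapes are illustrative.
    B, H, SEQ, D = 2, 8, 2048, 64
    q = torch.randn(B, H, SEQ, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, H, SEQ, D, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, H, SEQ, D, device="cuda", dtype=torch.bfloat16)

    # Block mask over the local Q shard and the global KV view, as described in
    # the create_cp_block_mask docstring above.
    block_mask = create_cp_block_mask(
        causal_mask, B=B, H=H, Q_LEN=SEQ, KV_LEN=SEQ, device_mesh=mesh
    )

    # context_parallel() shards the listed buffers along dim 2 (the sequence dim)
    # and dispatches attention to the CP implementation inside the context.
    with context_parallel(mesh, buffers=[q, k, v], buffer_seq_dims=[2, 2, 2]):
        out = flex_attention(q, k, v, block_mask=block_mask)
    return out

Inside the context the k/v buffers hold only the local shard, which is why create_cp_block_mask shrinks block_mask.seq_lengths[1] to the shard length: it lets flex_attention's shape check pass while the mask_mod still addresses global KV indices.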
