@@ -55,7 +55,7 @@ class _ContextParallelOptions:
     # errors. It is likely this is always True but we currently keep this variable
     # for the experimental purpose.
     convert_to_f32: bool = True
-    enable_load_balance = True
+    enable_load_balance: bool = True
     rotate_method: _RotateMethod = _RotateMethod.ALL_GATHER
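For reference, a minimal sketch of how these experimental knobs get flipped. It assumes the module-level `_cp_options = _ContextParallelOptions()` instance defined elsewhere in this file, and the private import path shown below may change between releases.

# Hedged sketch only: toggling the experimental CP options shown above.
# Assumes this file's module-level `_cp_options` instance; the private import
# path may differ across torch releases.
from torch.distributed.tensor.experimental._attention import _cp_options

_cp_options.convert_to_f32 = True        # keep the float32 upcast inside CP attention
_cp_options.enable_load_balance = False  # opt out of sequence-chunk load balancing
# rotate_method is normally configured via the public set_rotate_method() helper
# (see the last hunk of this diff) rather than assigned directly.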
@@ -924,18 +924,10 @@ def _distribute_function(
     output_fn: Optional[Callable] = None,
 ) -> None:
     """
-    ``distribute_function`` is an experimental API that allows users to "distribute"
-    the inputs and outputs of a function. Similar to ``distribute_module``, this API
-    installs hooks to the ``fn`` to convert the inputs and outputs. There are two
-    major differences between ``distribute_function`` and ``distribute_module``.
-    First, a function does not have parameters and buffers, as a result,
-    ``distribute_function`` itself won't convert any parameters/buffers but simply
-    install the input and output hooks. The tensor conversion will happen in the hooks.
-    Another difference is an nn.Module subclass can have several instances and each
-    instance be fed into ``distribute_module`` independently with affecting other
-    instance. On the other hand, function is a singleton object. So if a function
-    is distributed by ``distribute_function`` all subsequent calls to the function
-    will invoke the installed hooks.
+    A helper function to replace a function with a distributed version by
+    using the monkey patching approach.
+
+    This function is for CP internal usage only.

     Args:
         fn (Callable): the function to be distributed.
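Since the new docstring compresses the original explanation down to "monkey patching", a minimal sketch of that pattern may help; `_wrap_with_hooks` and the hook bodies below are illustrative stand-ins, not the real `_distribute_function` (which additionally takes the device mesh and converts SDPA inputs and outputs to and from DTensor).

# Illustrative sketch only: the general shape of "distributing" a module-level
# function by monkey patching. The real _distribute_function() installs
# CP-specific input/output conversions; the hook bodies here are placeholders.
import types
from typing import Any, Callable


def _wrap_with_hooks(
    fn: Callable,
    fn_module: types.ModuleType,
    input_fn: Callable,
    output_fn: Callable,
) -> None:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        args, kwargs = input_fn(args, kwargs)  # e.g. shard / convert the inputs
        out = fn(*args, **kwargs)              # call the original function
        return output_fn(out)                  # e.g. convert the outputs back

    # A function, unlike an nn.Module instance, is a singleton in its module,
    # so every subsequent call through the module now goes through the hooks.
    setattr(fn_module, fn.__name__, wrapper)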
@@ -986,7 +978,7 @@ def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None:


 @contextlib.contextmanager
-def _enable_cp_dispatcher() -> Generator[None, None, None]:
+def _enable_cp_dtensor_dispatcher() -> Generator[None, None, None]:
     """Enables DTensor dispatcher to dispatch SDPA to CP."""
     old_handlers = DTensor._op_dispatcher._custom_op_handlers
     DTensor._op_dispatcher._custom_op_handlers = {**old_handlers, **customized_ops}
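The renamed context manager follows a save/override/restore pattern around the private `DTensor._op_dispatcher._custom_op_handlers` table. Below is a generic sketch of that pattern, with the dispatcher passed in explicitly for illustration; the actual code above uses the module-level `DTensor` and `customized_ops`, and restores after `yield` without a try/finally.

import contextlib
from typing import Any, Generator


@contextlib.contextmanager
def _override_handlers(dispatcher: Any, extra_handlers: dict) -> Generator[None, None, None]:
    # Save the current handler table, merge in the CP-specific handlers,
    # and put the original table back when the context exits.
    old = dispatcher._custom_op_handlers
    dispatcher._custom_op_handlers = {**old, **extra_handlers}
    try:
        yield
    finally:
        dispatcher._custom_op_handlers = old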
@@ -996,94 +988,10 @@ def _enable_cp_dispatcher() -> Generator[None, None, None]:
     DTensor._op_dispatcher._custom_op_handlers = old_handlers


-def create_cp_block_mask(
-    mask_mod: _mask_mod_signature,
-    B: int,
-    H: int,
-    Q_LEN: int,
-    KV_LEN: int,
-    device_mesh: DeviceMesh,
-) -> BlockMask:
-    """
-    This API creates a special BlockMask for Context Parallel FlexAttention:
-    1. This BlockMask is masking on the attention of Q shard and KV global views, by
-       mapping the local q_idx to the global q_idx before sending to mask_mod.
-    2. The kv_seq_length (i.e. seq_lengths[1]) of this BlockMask is tailored to match
-       the sequence length of KV shard instead of KV global. This is to pass the shape check
-       in flex_attention(). The correct value (i.e. the sequence length of KV global) will be
-       used in flex_attention once the shape check passes.
-
-    Args:
-        mask_mod (Callable): Function to modify the mask over the global attention result.
-        B (int): Batch size.
-        H (int): Number of query heads.
-        Q_LEN (int): Sequence length of query (global view).
-        KV_LEN (int): Sequence length of key/value (global view).
-        device_mesh (:class:`DeviceMesh`): The device mesh for the context parallelism.
-
-    Return:
-        :class:`BlockMask`: the block_mask to be used in flex_attention() within the
-        context_parallel() context.
-
-    .. warning::
-        This function cannot generate correct block_mask if the BLOCK_SIZE is not
-        ``_DEFAULT_SPARSE_BLOCK_SIZE``, which usually happens when the attention
-        size is smaller than 128. Please do not use context_parallel() when the
-        FlexAttention size is small.
-    """
-    from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE
-
-    compiled_create_block_mask = torch.compile(
-        create_block_mask, dynamic=False, fullgraph=True
-    )
-
-    def _rewrite_mask_mod(
-        mask_mod: _mask_mod_signature,
-        rank: int,
-        world_size: int,
-        block_size: int,
-        local_q_size: int,
-    ) -> _mask_mod_signature:
-        def local_q_idx_to_q_idx(local_q_idx: torch.Tensor) -> torch.Tensor:
-            # calculate local block_idx and block_offset
-            local_blk_idx, local_blk_offset = (
-                local_q_idx // block_size,
-                local_q_idx % block_size,
-            )
-            # NOTE: load balancing is not used
-            local_num_blocks = local_q_size // block_size
-            blk_idx = local_num_blocks * rank + local_blk_idx
-            return blk_idx * block_size + local_blk_offset
-
-        return lambda b, h, q_idx, kv_idx: mask_mod(
-            b,
-            h,
-            local_q_idx_to_q_idx(q_idx),
-            kv_idx,
-        )
-
-    cp_rank = device_mesh.get_local_rank()
-    cp_group_size = device_mesh.size()
-    Q_SHARD_LEN = Q_LEN // cp_group_size
-    block_size = _DEFAULT_SPARSE_BLOCK_SIZE
-    block_mask = compiled_create_block_mask(
-        _rewrite_mask_mod(mask_mod, cp_rank, cp_group_size, block_size, Q_SHARD_LEN),
-        B,
-        H,
-        Q_SHARD_LEN,
-        KV_LEN,
-        device=device_mesh.device_type,
-        BLOCK_SIZE=(block_size, block_size),
-    )
-    # flex_attention function checks the following shape so we need to rewrite:
-    # key.size(-2) == block_mask.seq_lengths[1]
-    seq_lengths = block_mask.seq_lengths
-    block_mask.seq_lengths = (seq_lengths[0], seq_lengths[1] // cp_group_size)
-    return block_mask
-
-
 @contextlib.contextmanager
-def _context_parallel(seq_dim: int, mesh: DeviceMesh) -> Generator[None, None, None]:
+def _context_parallel_dispatcher(
+    seq_dim: int, mesh: DeviceMesh
+) -> Generator[None, None, None]:
     """Replace SDPA with the CP-wrapped version and enable DTensor CP dispatcher."""

     def attention_input_fn(
@@ -1185,7 +1093,7 @@ def __torch_function__(
             attention_input_fn,
             attention_output_fn,
         )
-        with _enable_cp_dispatcher():
+        with _enable_cp_dtensor_dispatcher():
             yield
         _restore_function(F.scaled_dot_product_attention, F)
     elif _dispatch_mode == _DispatchMode.TORCH_FUNCTION:
@@ -1200,7 +1108,7 @@ def __torch_function__(
         _cp_global_vars.torch_function_mode = tf_mode

         with tf_mode:
-            with _enable_cp_dispatcher():
+            with _enable_cp_dtensor_dispatcher():
                 yield
     else:
         raise NotImplementedError("torch dispatch mode is not supported yet.")
@@ -1270,6 +1178,9 @@ def _context_parallel_buffers(
     return new_buffers


+#####################################################
+# Current public APIs, but are also subject to change
+#####################################################
 @contextlib.contextmanager
 @torch.no_grad()
 def context_parallel(
@@ -1343,7 +1254,7 @@ def context_parallel(
         buffer.resize_(shard.shape)
         buffer.copy_(shard)

-    with _context_parallel(seq_dim=2, mesh=mesh):
+    with _context_parallel_dispatcher(seq_dim=2, mesh=mesh):
         yield

     for buffer, original_buffer in zip(buffers, original_buffers):
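For orientation, a hedged end-to-end sketch of the public `context_parallel()` context manager this hunk touches. It assumes an 8-rank job (e.g. launched with torchrun), q/k/v laid out as [B, H, S, D] so the sequence dimension is 2 (matching the `seq_dim=2` passed to the dispatcher), and the keyword names used by recent releases.

import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel

# Assumes torch.distributed is already initialized across 8 ranks.
mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("cp",))
q = torch.randn(2, 16, 4096, 64, device="cuda", requires_grad=True)
k = torch.randn(2, 16, 4096, 64, device="cuda", requires_grad=True)
v = torch.randn(2, 16, 4096, 64, device="cuda", requires_grad=True)

with context_parallel(
    mesh,
    buffers=(q, k, v),          # buffers to shard in place along the sequence dim
    buffer_seq_dims=(2, 2, 2),  # sequence dim of each buffer
):
    # Inside the context, F.scaled_dot_product_attention is dispatched to the
    # CP-wrapped version; buffers are restored to their global shape on exit.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)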
@@ -1421,3 +1332,89 @@ def set_rotate_method(rotate_method: str) -> None:
             "Context Parallel does not support "
             f"using {rotate_method} for kv shards rotation"
         )
+
+
+def create_cp_block_mask(
+    mask_mod: _mask_mod_signature,
+    B: int,
+    H: int,
+    Q_LEN: int,
+    KV_LEN: int,
+    device_mesh: DeviceMesh,
+) -> BlockMask:
+    """
+    This API creates a special BlockMask for Context Parallel FlexAttention:
+    1. This BlockMask is masking on the attention of Q shard and KV global views, by
+       mapping the local q_idx to the global q_idx before sending to mask_mod.
+    2. The kv_seq_length (i.e. seq_lengths[1]) of this BlockMask is tailored to match
+       the sequence length of KV shard instead of KV global. This is to pass the shape check
+       in flex_attention(). The correct value (i.e. the sequence length of KV global) will be
+       used in flex_attention once the shape check passes.
+
+    Args:
+        mask_mod (Callable): Function to modify the mask over the global attention result.
+        B (int): Batch size.
+        H (int): Number of query heads.
+        Q_LEN (int): Sequence length of query (global view).
+        KV_LEN (int): Sequence length of key/value (global view).
+        device_mesh (:class:`DeviceMesh`): The device mesh for the context parallelism.
+
+    Return:
+        :class:`BlockMask`: the block_mask to be used in flex_attention() within the
+        context_parallel() context.
+
+    .. warning::
+        This function cannot generate correct block_mask if the BLOCK_SIZE is not
+        ``_DEFAULT_SPARSE_BLOCK_SIZE``, which usually happens when the attention
+        size is smaller than 128. Please do not use context_parallel() when the
+        FlexAttention size is small.
+    """
+    from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE
+
+    compiled_create_block_mask = torch.compile(
+        create_block_mask, dynamic=False, fullgraph=True
+    )
+
+    def _rewrite_mask_mod(
+        mask_mod: _mask_mod_signature,
+        rank: int,
+        world_size: int,
+        block_size: int,
+        local_q_size: int,
+    ) -> _mask_mod_signature:
+        def local_q_idx_to_q_idx(local_q_idx: torch.Tensor) -> torch.Tensor:
+            # calculate local block_idx and block_offset
+            local_blk_idx, local_blk_offset = (
+                local_q_idx // block_size,
+                local_q_idx % block_size,
+            )
+            # NOTE: load balancing is not used
+            local_num_blocks = local_q_size // block_size
+            blk_idx = local_num_blocks * rank + local_blk_idx
+            return blk_idx * block_size + local_blk_offset
+
+        return lambda b, h, q_idx, kv_idx: mask_mod(
+            b,
+            h,
+            local_q_idx_to_q_idx(q_idx),
+            kv_idx,
+        )
+
+    cp_rank = device_mesh.get_local_rank()
+    cp_group_size = device_mesh.size()
+    Q_SHARD_LEN = Q_LEN // cp_group_size
+    block_size = _DEFAULT_SPARSE_BLOCK_SIZE
+    block_mask = compiled_create_block_mask(
+        _rewrite_mask_mod(mask_mod, cp_rank, cp_group_size, block_size, Q_SHARD_LEN),
+        B,
+        H,
+        Q_SHARD_LEN,
+        KV_LEN,
+        device=device_mesh.device_type,
+        BLOCK_SIZE=(block_size, block_size),
+    )
+    # flex_attention function checks the following shape so we need to rewrite:
+    # key.size(-2) == block_mask.seq_lengths[1]
+    seq_lengths = block_mask.seq_lengths
+    block_mask.seq_lengths = (seq_lengths[0], seq_lengths[1] // cp_group_size)
+    return block_mask
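Finally, a hedged sketch of how the newly added `create_cp_block_mask()` is meant to be driven together with `flex_attention()` inside `context_parallel()`, per its docstring. The private import path mirrors the file this diff edits and may be re-exported elsewhere; the 8-rank mesh and the shapes are illustrative.

import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import create_cp_block_mask


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("cp",))
B, H, S, D = 2, 16, 8192, 64
q = torch.randn(B, H, S, D, device="cuda")
k = torch.randn(B, H, S, D, device="cuda")
v = torch.randn(B, H, S, D, device="cuda")

# The mask is created against the *global* Q/KV lengths; internally the local
# q_idx of this rank's Q shard is remapped to a global index before mask_mod
# is called. When local_q_size is a multiple of block_size, that remapping
# reduces to q_idx_global = rank * local_q_size + q_idx_local.
block_mask = create_cp_block_mask(
    causal_mask, B=B, H=H, Q_LEN=S, KV_LEN=S, device_mesh=mesh
)

with context_parallel(mesh, buffers=(q, k, v), buffer_seq_dims=(2, 2, 2)):
    out = flex_attention(q, k, v, block_mask=block_mask)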