Commit 5ce6dd7
[fp8] disable all_to_all_fp8 in intranode (#6045)
* enhance all_to_all_fp8 with internode comm control
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* disable some fp8 ops due to performance issue
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 26e5539 commit 5ce6dd7

File tree

1 file changed (+73, -6 lines)

colossalai/quantization/fp8.py

Lines changed: 73 additions & 6 deletions
@@ -1,3 +1,4 @@
+import os
 from typing import Any, Optional, Tuple
 
 import numpy as np
@@ -23,6 +24,24 @@ def wait(self):
         self.remain_ops()
 
 
+def process_group_is_intranode(pg):
+    if pg is None:
+        from torch.distributed.distributed_c10d import _get_default_group
+
+        pg = _get_default_group()
+
+    local_world_size = None
+    for var in ["LOCAL_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_SIZE", "SLURM_TASKS_PER_NODE"]:
+        if var in os.environ:
+            local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
+    if local_world_size is None:
+        local_world_size = torch.cuda.device_count()
+
+    group_ranks = dist.get_process_group_ranks(pg)
+    group_ranks_node_ids = [rank // local_world_size for rank in group_ranks]
+    return min(group_ranks_node_ids) == max(group_ranks_node_ids)
+
+
 def cast_to_fp8(
     inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -92,7 +111,7 @@ def cast_from_fp8(
     return ret.to(ret_type)
 
 
-def all_reduce_fp8(
+def _all_reduce_fp8(
     tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
 ) -> Optional[Handle]:
     r"""
@@ -159,7 +178,15 @@ def cat_op():
         cat_op()
 
 
-def all_to_all_single_fp8(
+def all_reduce_fp8(
+    tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
+) -> Optional[Handle]:
+    # fall back to default op due to performance issue
+    return dist.all_reduce(tensor, op=op, group=group, async_op=async_op)
+
+
+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+def _all_to_all_single_fp8(
     output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
 ) -> Optional[Handle]:
     r"""
@@ -222,6 +249,33 @@ def cast_op():
         cast_op()
 
 
+def all_to_all_single_fp8(
+    output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
+) -> Optional[Handle]:
+    r"""
+    This is wrapper for _all_to_all_single_fp8.
+    """
+    if process_group_is_intranode(group):
+        return dist.all_to_all_single(
+            output,
+            input,
+            output_split_sizes=output_split_sizes,
+            input_split_sizes=input_split_sizes,
+            group=group,
+            async_op=async_op,
+        )
+    else:
+        return _all_to_all_single_fp8(
+            output,
+            input,
+            fp8_format=fp8_format,
+            output_split_sizes=output_split_sizes,
+            input_split_sizes=input_split_sizes,
+            group=group,
+            async_op=async_op,
+        )
+
+
 def cast_to_fp8_pipeline(inp: Any) -> None:
     """
     Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
@@ -293,7 +347,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
         del inp["dtype"]
 
 
-def reduce_scatter_fp8(
+def _reduce_scatter_fp8(
     output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
 ) -> Optional[Handle]:
     r"""
@@ -338,6 +392,13 @@ def cast_op():
         cast_op()
 
 
+def reduce_scatter_fp8(
+    output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
+) -> Optional[Handle]:
+    # fall back to default op due to performance issue
+    return dist.reduce_scatter(output, input_list, group=group, async_op=async_op)
+
+
 def fp8_compress_ddp_grad_comm_hook_async(
     process_group: dist.ProcessGroup,
     bucket: dist.GradBucket,
@@ -617,10 +678,9 @@ def cast_op():
         cast_op()
 
 
-def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
-
+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+def _all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
     world_size = dist.get_world_size(group)
-
     input_type = input_list[0].dtype
     fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
     scale_list = []
@@ -651,6 +711,13 @@ def cast_op():
         cast_op()
 
 
+def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
+    if process_group_is_intranode(group):
+        return dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
+    else:
+        return _all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format, async_op=async_op)
+
+
 def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
 
     world_size = dist.get_world_size(group)
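
For reference, the sketch below shows how the changed entry points might be exercised after this commit. It is a minimal illustration under stated assumptions, not part of the diff: the script name, tensor shapes, and the torchrun launch line are hypothetical; only all_to_all_fp8 and process_group_is_intranode come from colossalai/quantization/fp8.py as modified above.

# demo_all_to_all_fp8.py -- illustrative sketch, not part of the commit.
# Assumes CUDA GPUs, the NCCL backend, and ColossalAI installed with the change above.
import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_to_all_fp8, process_group_is_intranode

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # One chunk per peer rank: all-to-all sends chunk i to rank i and receives rank i's chunk.
    input_list = [torch.full((4,), float(rank), device="cuda") for _ in range(world_size)]
    output_list = [torch.empty(4, device="cuda") for _ in range(world_size)]

    # After this commit, an intranode group takes the plain dist.all_to_all path;
    # only internode groups go through the compiled FP8 cast path (_all_to_all_fp8).
    print(f"rank {rank}: intranode group = {process_group_is_intranode(None)}")
    all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2")

    dist.destroy_process_group()

# Launch (illustrative): torchrun --nproc_per_node=2 demo_all_to_all_fp8.py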
