1+ import os
12from typing import Any , Optional , Tuple
23
34import numpy as np
@@ -23,6 +24,24 @@ def wait(self):
2324 self .remain_ops ()
2425
2526
27+ def process_group_is_intranode (pg ):
28+ if pg is None :
29+ from torch .distributed .distributed_c10d import _get_default_group
30+
31+ pg = _get_default_group ()
32+
33+ local_world_size = None
34+ for var in ["LOCAL_WORLD_SIZE" , "OMPI_COMM_WORLD_LOCAL_SIZE" , "SLURM_TASKS_PER_NODE" ]:
35+ if var in os .environ :
36+ local_world_size = int (os .environ ["LOCAL_WORLD_SIZE" ])
37+ if local_world_size is None :
38+ local_world_size = torch .cuda .device_count ()
39+
40+ group_ranks = dist .get_process_group_ranks (pg )
41+ group_ranks_node_ids = [rank // local_world_size for rank in group_ranks ]
42+ return min (group_ranks_node_ids ) == max (group_ranks_node_ids )
43+
44+
2645def cast_to_fp8 (
2746 inp : torch .Tensor , fp8_format = "e4m3" , per_channel_scale = False , out = None
2847) -> Tuple [torch .Tensor , torch .Tensor ]:
@@ -92,7 +111,7 @@ def cast_from_fp8(
92111 return ret .to (ret_type )
93112
94113
95- def all_reduce_fp8 (
114+ def _all_reduce_fp8 (
96115 tensor : torch .Tensor , fp8_format = "e4m3" , op = ReduceOp .SUM , group = None , async_op : bool = False
97116) -> Optional [Handle ]:
98117 r"""
@@ -159,7 +178,15 @@ def cat_op():
159178 cat_op ()
160179
161180
162- def all_to_all_single_fp8 (
181+ def all_reduce_fp8 (
182+ tensor : torch .Tensor , fp8_format = "e4m3" , op = ReduceOp .SUM , group = None , async_op : bool = False
183+ ) -> Optional [Handle ]:
184+ # fall back to default op due to performance issue
185+ return dist .all_reduce (tensor , op = op , group = group , async_op = async_op )
186+
187+
188+ @torch .compile (mode = "max-autotune-no-cudagraphs" , dynamic = False )
189+ def _all_to_all_single_fp8 (
163190 output , input , output_split_sizes = None , input_split_sizes = None , fp8_format = "e5m2" , group = None , async_op = False
164191) -> Optional [Handle ]:
165192 r"""
@@ -222,6 +249,33 @@ def cast_op():
222249 cast_op ()
223250
224251
252+ def all_to_all_single_fp8 (
253+ output , input , output_split_sizes = None , input_split_sizes = None , fp8_format = "e5m2" , group = None , async_op = False
254+ ) -> Optional [Handle ]:
255+ r"""
256+ This is wrapper for _all_to_all_single_fp8.
257+ """
258+ if process_group_is_intranode (group ):
259+ return dist .all_to_all_single (
260+ output ,
261+ input ,
262+ output_split_sizes = output_split_sizes ,
263+ input_split_sizes = input_split_sizes ,
264+ group = group ,
265+ async_op = async_op ,
266+ )
267+ else :
268+ return _all_to_all_single_fp8 (
269+ output ,
270+ input ,
271+ fp8_format = fp8_format ,
272+ output_split_sizes = output_split_sizes ,
273+ input_split_sizes = input_split_sizes ,
274+ group = group ,
275+ async_op = async_op ,
276+ )
277+
278+
225279def cast_to_fp8_pipeline (inp : Any ) -> None :
226280 """
227281 Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
@@ -293,7 +347,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
293347 del inp ["dtype" ]
294348
295349
296- def reduce_scatter_fp8 (
350+ def _reduce_scatter_fp8 (
297351 output : torch .Tensor , input_list , group , fp8_format = "e5m2" , async_op : bool = False
298352) -> Optional [Handle ]:
299353 r"""
@@ -338,6 +392,13 @@ def cast_op():
338392 cast_op ()
339393
340394
395+ def reduce_scatter_fp8 (
396+ output : torch .Tensor , input_list , group , fp8_format = "e5m2" , async_op : bool = False
397+ ) -> Optional [Handle ]:
398+ # fall back to default op due to performance issue
399+ return dist .reduce_scatter (output , input_list , group = group , async_op = async_op )
400+
401+
341402def fp8_compress_ddp_grad_comm_hook_async (
342403 process_group : dist .ProcessGroup ,
343404 bucket : dist .GradBucket ,
@@ -617,10 +678,9 @@ def cast_op():
617678 cast_op ()
618679
619680
620- def all_to_all_fp8 ( output_list , input_list , group = None , fp8_format = "e5m2 " , async_op = False ):
621-
681+ @ torch . compile ( mode = "max-autotune-no-cudagraphs " , dynamic = False )
682+ def _all_to_all_fp8 ( output_list , input_list , group = None , fp8_format = "e5m2" , async_op = False ):
622683 world_size = dist .get_world_size (group )
623-
624684 input_type = input_list [0 ].dtype
625685 fp8_type = torch .float8_e4m3fn if fp8_format == "e4m3" else torch .float8_e5m2
626686 scale_list = []
@@ -651,6 +711,13 @@ def cast_op():
651711 cast_op ()
652712
653713
714+ def all_to_all_fp8 (output_list , input_list , group = None , fp8_format = "e5m2" , async_op = False ):
715+ if process_group_is_intranode (group ):
716+ return dist .all_to_all (output_list , input_list , group = group , async_op = async_op )
717+ else :
718+ return _all_to_all_fp8 (output_list , input_list , group = group , fp8_format = fp8_format , async_op = async_op )
719+
720+
654721def gather_fp8 (output_list , input_ , group = None , fp8_format = "e5m2" , async_op : bool = False ) -> Optional [Handle ]:
655722
656723 world_size = dist .get_world_size (group )
0 commit comments