50 | 50 | Uint64,
51 | 51 | T,
52 | 52 | Integer,
53 |    | - dsl_user_op,
54 |    | - extract_mlir_values,
55 |    | - new_from_mlir_values,
56 |    | - )
57 |    | - # TODO(asamani): remove Int32 from above?
58 |    | - from cutlass.cute.typing import (
59 |    | - Int32,
60 | 53 | Float16,
61 | 54 | BFloat16,
62 | 55 | Float32,
63 | 56 | Float8E4M3FN,
64 | 57 | Float8E5M2,
65 | 58 | Tensor,
   | 59 | + dsl_user_op,
   | 60 | + extract_mlir_values,
   | 61 | + new_from_mlir_values,
66 | 62 | )
   | 63 | +
   | 64 | + # from cutlass.cute.typing import (
   | 65 | + # Int32,
   | 66 | + # Float16,
   | 67 | + # BFloat16,
   | 68 | + # Float32,
   | 69 | + # Float8E4M3FN,
   | 70 | + # Float8E5M2,
   | 71 | + # Tensor,
   | 72 | + # )
67 | 73 | from cutlass._mlir.dialects import llvm
68 | 74 | from flashinfer.utils import get_compute_capability
69 | 75 | from cutlass.utils.static_persistent_tile_scheduler import WorkTileInfo

@@ -1871,13 +1877,10 @@ def kernel(
1871 | 1877 | * cute.size(self.cluster_shape_mn)
1872 | 1878 | + cute.arch.block_idx_in_cluster()
1873 | 1879 | )
1874 |      | - #cute.printf(tile_id)
1875 | 1880 | if warp_idx == self.epilog_warp_id[0]:
1876 | 1881 | cute.arch.cp_async_bulk_wait_group(0, read=False)
1877 | 1882 | # System barrier to make sure that data from each GPU is in memory before allreduce
1878 | 1883 | with cute.arch.elect_one():
1879 |      | - # cute.printf("EPILOGUE: rank=%d warp=%d tile_id=%d num_executed=%d\n",
1880 |      | - # self.rank_id, warp_idx, tile_id, tile_sched.num_tiles_executed)
1881 | 1884 | flag = barrier_flag_mc.iterator + tile_id
1882 | 1885 | cute.arch.fence_acq_rel_gpu()
1883 | 1886 | distributed_helpers.spin_lock_multimem_arrive(flag)

@@ -1996,8 +1999,6 @@ def kernel(
1996 | 1999 | # System barrier to make sure that data from each GPU is in memory before allreduce
1997 | 2000 | if warp_idx == self.all_reduce_warp_id[0]:
1998 | 2001 | with cute.arch.elect_one():
1999 |      | - # cute.printf("ALLREDUCE: rank=%d warp=%d tile_id=%d num_executed=%d\n",
2000 |      | - # self.rank_id, warp_idx, tile_id, tile_sched.num_tiles_executed)
2001 | 2002 | flag = barrier_flag.iterator + tile_id
2002 | 2003 | # TODO: we may use LDG+STG for spin lock instead of ATOMIC_CAS for better performance.
2003 | 2004 | distributed_helpers.spin_lock_wait(flag, num_ranks)

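Taken together, the two hunks above form a per-tile system barrier: once its partial result is globally visible, the epilogue warp on each rank does a multimem arrive on the tile's flag, and the all-reduce warp then spins on that flag until all ranks have arrived. A rough single-process Python model of that handshake (hypothetical names, for illustration only; the real helpers live in distributed_helpers and operate on the barrier_flag tensor in device memory):

num_tiles = 8            # stand-in for the number of output tiles
num_ranks = 4            # stand-in for the number of GPUs in the allreduce
flags = [0] * num_tiles  # models the per-tile barrier_flag tensor

def arrive(tile_id: int) -> None:
    # Models spin_lock_multimem_arrive: each rank bumps the tile's flag
    # (through the multicast pointer) once its data is in memory.
    flags[tile_id] += 1

def wait(tile_id: int) -> None:
    # Models spin_lock_wait: poll until every rank has arrived for this tile.
    while flags[tile_id] < num_ranks:
        pass

for _ in range(num_ranks):   # in the kernel, each rank arrives exactly once per tile
    arrive(0)
wait(0)                      # returns once all ranks have arrived

The fence_acq_rel_gpu issued before the arrive is what orders the producer's stores ahead of the flag update; the model above glosses over that ordering.
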
@@ -2695,6 +2696,7 @@ def can_implement(
2695 | 2696 | b_major: str,
2696 | 2697 | c_major: str,
2697 | 2698 | all_reduce: str = "none",
     | 2699 | + process_group: Optional[torch.distributed.ProcessGroup] = None,
2698 | 2700 | ) -> bool:
2699 | 2701 | """
2700 | 2702 | Check if the gemm can be implemented

@@ -2753,10 +2755,10 @@ def can_implement(
2753 | 2755 | ):
2754 | 2756 | can_implement = False
2755 | 2757 |
2756 |      | - # check for all reduce constraints
2757 |      | - # TODO(asamani): expand the logic for mnnvl support
     | 2758 | + # Check for all reduce constraints
2758 | 2759 | if all_reduce != "none":
2759 |      | - if torch.distributed.get_world_size() not in [2, 4, 8]:
     | 2760 | + # TODO(asamani): expand the logic for mnnvl support
     | 2761 | + if torch.distributed.get_world_size(process_group) not in [2, 4, 8]:
2760 | 2762 | can_implement = False
2761 | 2763 | return can_implement
2762 | 2764 |

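With the new process_group argument, the world-size constraint is checked against the caller's group rather than the default global group; torch.distributed.get_world_size(group) returns the number of ranks in that group (or in the default group when None is passed), so the old behavior is preserved for existing callers. A minimal sketch of the intended call pattern, assuming torch.distributed has already been initialized (the helper name below is hypothetical):

import torch.distributed as dist

def world_size_supported(process_group=None) -> bool:
    # Mirrors the check above: only 2-, 4-, or 8-rank groups are accepted.
    return dist.get_world_size(process_group) in (2, 4, 8)

# e.g. restrict the allreduce to a sub-group covering half of the ranks:
# sub_group = dist.new_group(ranks=list(range(dist.get_world_size() // 2)))
# ok = world_size_supported(sub_group)
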
@@ -2993,9 +2995,6 @@ def __call__(
2993 | 2995 | order=(0, 1, 2) if self._c_major == "m" else (1, 0, 2),
2994 | 2996 | ),
2995 | 2997 | ) if c_mc_ptr is not None else None
2996 |      | - #TODO(asamani): urgent fix this is just for dev
2997 |      | - # this should be calculated based on how many total tiles we need to work
2998 |      | - # on
2999 | 2998 | barrier_flag_tensor = cute.make_tensor(
3000 | 2999 | barrier_flag_ptr,
3001 | 3000 | layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)),