
Commit edb75a0: "wip"
1 parent: aeb1815

1 file changed: +17 -18 lines

flashinfer/cute_dsl/blockscaled_gemm.py (17 additions, 18 deletions)
@@ -50,20 +50,26 @@
     Uint64,
     T,
     Integer,
-    dsl_user_op,
-    extract_mlir_values,
-    new_from_mlir_values,
-)
-# TODO(asamani): remove Int32 from above?
-from cutlass.cute.typing import (
-    Int32,
     Float16,
     BFloat16,
     Float32,
     Float8E4M3FN,
     Float8E5M2,
     Tensor,
+    dsl_user_op,
+    extract_mlir_values,
+    new_from_mlir_values,
 )
+
+# from cutlass.cute.typing import (
+# Int32,
+# Float16,
+# BFloat16,
+# Float32,
+# Float8E4M3FN,
+# Float8E5M2,
+# Tensor,
+# )
 from cutlass._mlir.dialects import llvm
 from flashinfer.utils import get_compute_capability
 from cutlass.utils.static_persistent_tile_scheduler import WorkTileInfo
@@ -1871,13 +1877,10 @@ def kernel(
                 * cute.size(self.cluster_shape_mn)
                 + cute.arch.block_idx_in_cluster()
             )
-            #cute.printf(tile_id)
             if warp_idx == self.epilog_warp_id[0]:
                 cute.arch.cp_async_bulk_wait_group(0, read=False)
                 # System barrier to make sure that data from each GPU is in memory before allreduce
                 with cute.arch.elect_one():
-                    # cute.printf("EPILOGUE: rank=%d warp=%d tile_id=%d num_executed=%d\n",
-                    # self.rank_id, warp_idx, tile_id, tile_sched.num_tiles_executed)
                     flag = barrier_flag_mc.iterator + tile_id
                     cute.arch.fence_acq_rel_gpu()
                     distributed_helpers.spin_lock_multimem_arrive(flag)
@@ -1996,8 +1999,6 @@ def kernel(
             # System barrier to make sure that data from each GPU is in memory before allreduce
             if warp_idx == self.all_reduce_warp_id[0]:
                 with cute.arch.elect_one():
-                    # cute.printf("ALLREDUCE: rank=%d warp=%d tile_id=%d num_executed=%d\n",
-                    # self.rank_id, warp_idx, tile_id, tile_sched.num_tiles_executed)
                     flag = barrier_flag.iterator + tile_id
                     # TODO: we may use LDG+STG for spin lock instead of ATOMIC_CAS for better performance.
                     distributed_helpers.spin_lock_wait(flag, num_ranks)
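
The two hunks above keep the cross-GPU system barrier while dropping the debug printf lines: the epilogue warp arrives on a per-tile flag through multimem (spin_lock_multimem_arrive), and the all-reduce warp spins on the flag for the same tile_id until all num_ranks GPUs have arrived (spin_lock_wait). Below is a conceptual host-side sketch of that arrive/wait protocol using plain Python threads and a lock in place of the device-side atomics; it illustrates the flag semantics only, and none of the names below come from this file.

# Conceptual host-side analogue of the per-tile barrier, not the device code:
# the epilogue warp "arrives" on flag[tile_id] and the all-reduce warp spins
# until the flag reaches num_ranks. All names here are illustrative.
import threading

num_ranks = 4
num_tiles = 8
flags = [0] * num_tiles      # one barrier flag per output tile
lock = threading.Lock()      # stands in for the device-side atomic update

def arrive(tile_id: int) -> None:
    # Epilogue-warp analogue: signal that this rank's tile is visible to all.
    with lock:
        flags[tile_id] += 1

def wait(tile_id: int, expected: int) -> None:
    # All-reduce-warp analogue: spin until every rank has arrived on the tile.
    while True:
        with lock:
            if flags[tile_id] >= expected:
                return

def rank_worker() -> None:
    for tile_id in range(num_tiles):
        arrive(tile_id)            # after the tile's data is written
        wait(tile_id, num_ranks)   # before reading other ranks' contributions

threads = [threading.Thread(target=rank_worker) for _ in range(num_ranks)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert all(f == num_ranks for f in flags)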
@@ -2695,6 +2696,7 @@ def can_implement(
         b_major: str,
         c_major: str,
         all_reduce: str = "none",
+        process_group: Optional[torch.distributed.ProcessGroup] = None,
     ) -> bool:
         """
         Check if the gemm can be implemented
@@ -2753,10 +2755,10 @@ def can_implement(
         ):
             can_implement = False
 
-        # check for all reduce constraints
-        # TODO(asamani): expand the logic for mnnvl support
+        # Check for all reduce constraints
         if all_reduce != "none":
-            if torch.distributed.get_world_size() not in [2, 4, 8]:
+            # TODO(asamani): expand the logic for mnnvl support
+            if torch.distributed.get_world_size(process_group) not in [2, 4, 8]:
                 can_implement = False
         return can_implement

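The can_implement change above threads an optional process_group into the world-size constraint, so torch.distributed.get_world_size(process_group) reports the size of the caller's sub-group rather than always using the default group; passing None preserves the old behavior. A hedged sketch of the same check written as a standalone helper follows; check_all_reduce_support is an illustrative name, not a function from this repository.

# Hedged sketch of the constraint added above; the helper name is illustrative.
from typing import Optional

import torch.distributed as dist

def check_all_reduce_support(
    all_reduce: str = "none",
    process_group: Optional[dist.ProcessGroup] = None,
) -> bool:
    # No all-reduce requested: nothing to constrain.
    if all_reduce == "none":
        return True
    # get_world_size(group) scopes the check to the caller's sub-group; passing
    # None falls back to the default (global) process group, matching the old
    # get_world_size() call with no argument.
    return dist.get_world_size(process_group) in (2, 4, 8)
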
@@ -2993,9 +2995,6 @@ def __call__(
                 order=(0, 1, 2) if self._c_major == "m" else (1, 0, 2),
             ),
         ) if c_mc_ptr is not None else None
-        #TODO(asamani): urgent fix this is just for dev
-        # this should be calculated based on how many total tiles we need to work
-        # on
         barrier_flag_tensor = cute.make_tensor(
             barrier_flag_ptr,
             layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)),
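
The removed TODO noted that barrier_flag_size should be derived from the total number of output tiles, which matches how the kernel indexes the buffer by tile_id. A minimal sketch of that sizing under the assumption of one flag per CTA output tile; the tile-shape parameter names are illustrative and not taken from this file.

# Hedged sketch of the sizing the removed TODO alludes to: one flag per tile.
import math

def barrier_flag_count(m: int, n: int, tile_m: int, tile_n: int) -> int:
    # Number of distinct tile_id values the persistent scheduler can hand out.
    return math.ceil(m / tile_m) * math.ceil(n / tile_n)

# Example: a 4096 x 8192 output with 128 x 256 tiles needs 32 * 32 = 1024 flags.
assert barrier_flag_count(4096, 8192, 128, 256) == 1024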
