Skip to content

Commit 2e96016

Browse files
committed
wip
1 parent a782f8b commit 2e96016

File tree

2 files changed

+89
-15
lines changed

2 files changed

+89
-15
lines changed

flashinfer/cute_dsl/blockscaled_gemm.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2625,6 +2625,60 @@ def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape):
26252625

26262626
return is_valid
26272627

2628+
@staticmethod
def compute_barrier_flag_size(
    m: int,
    n: int,
    l: int,
    mma_tiler_mn: Tuple[int, int],
    cluster_shape_mn: Tuple[int, int],
    sm_count: int,
) -> int:
    """
    Compute the required size for barrier flag tensors used in all-reduce synchronization.

    The barrier flags are used for:
    1. Per-tile synchronization during the all-reduce phase
    2. Final inter-GPU synchronization barrier

    The result is ``ceil(m / cta_m) * ceil(n / cta_n) * l * cluster_size + sm_count``,
    where ``cta_m`` is half of ``mma_tiler_mn[0]`` when the tiler M extent is 256
    (2-CTA instructions) and ``cluster_size`` is the product of ``cluster_shape_mn``.

    :param m: Number of rows in the output matrix
    :type m: int
    :param n: Number of columns in the output matrix
    :type n: int
    :param l: Batch size
    :type l: int
    :param mma_tiler_mn: Shape of the MMA tiler (M, N)
    :type mma_tiler_mn: Tuple[int, int]
    :param cluster_shape_mn: Cluster dimensions (M, N)
    :type cluster_shape_mn: Tuple[int, int]
    :param sm_count: Number of SMs available
    :type sm_count: int

    :return: Total number of barrier flags needed
    :rtype: int
    :raises ValueError: If any extent of ``mma_tiler_mn`` or ``cluster_shape_mn``
        is not strictly positive
    """
    # Reject degenerate shapes up front: a zero tile extent would otherwise
    # surface as an opaque ZeroDivisionError, and a zero/negative cluster
    # extent would silently yield a flag buffer that is too small.
    if mma_tiler_mn[0] <= 0 or mma_tiler_mn[1] <= 0:
        raise ValueError(
            f"mma_tiler_mn extents must be positive, got {tuple(mma_tiler_mn)}"
        )
    if cluster_shape_mn[0] <= 0 or cluster_shape_mn[1] <= 0:
        raise ValueError(
            f"cluster_shape_mn extents must be positive, got {tuple(cluster_shape_mn)}"
        )

    # Calculate CTA tile shape accounting for 2-CTA instructions: with a
    # 256-row MMA tiler the M extent is split across two CTAs.
    use_2cta_instrs = mma_tiler_mn[0] == 256
    cta_tile_shape_m = mma_tiler_mn[0] // (2 if use_2cta_instrs else 1)
    cta_tile_shape_n = mma_tiler_mn[1]

    # Number of CTA tiles per batch; ceil-division so that partial edge
    # tiles are counted as well.
    num_tiles_m = (m + cta_tile_shape_m - 1) // cta_tile_shape_m
    num_tiles_n = (n + cta_tile_shape_n - 1) // cta_tile_shape_n
    num_tiles_per_batch = num_tiles_m * num_tiles_n

    # Number of CTAs in one cluster; each tile is scaled by this factor.
    cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]

    # Per-tile flags across all batches, scaled by the cluster size.
    num_tile_flags = num_tiles_per_batch * l * cluster_size

    # Extra space for the final barrier (one flag per SM).
    return num_tile_flags + sm_count
2681+
26282682
@staticmethod
26292683
def can_implement(
26302684
ab_dtype: Type[cutlass.Numeric],
@@ -2898,6 +2952,18 @@ def __call__(
28982952
barrier_flag_mc_ptr: Optional[cute.Pointer],
28992953
current_stream: cuda.CUstream,
29002954
):
2955+
if cutlass.const_expr(self._all_reduce != "none"):
2956+
barrier_flag_size = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
2957+
self._m,
2958+
self._n,
2959+
self._l,
2960+
self._mma_tiler_mn,
2961+
self._cluster_shape_mn,
2962+
self._max_active_clusters,
2963+
)
2964+
else:
2965+
barrier_flag_size = 1 # Dummy size when not used
2966+
29012967
a_tensor = cute.make_tensor(
29022968
a_ptr,
29032969
layout=cute.make_ordered_layout(
@@ -2931,11 +2997,11 @@ def __call__(
29312997
# on
29322998
barrier_flag_tensor = cute.make_tensor(
29332999
barrier_flag_ptr,
2934-
layout=cute.make_ordered_layout((404,), order=(0,)),
3000+
layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)),
29353001
) if barrier_flag_ptr is not None else None
29363002
barrier_flag_mc_tensor = cute.make_tensor(
29373003
barrier_flag_mc_ptr,
2938-
layout=cute.make_ordered_layout((404,), order=(0,)),
3004+
layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)),
29393005
) if barrier_flag_mc_ptr is not None else None
29403006

29413007
# calculate sf_tensor shape and order

tests/test_cute_dsl_blockscaled_gemm_allreduce_two_shot.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,23 +81,29 @@ def create_mc_tensor(torch_tensor_cpu, dtype, is_dynamic_layout=True):
8181
)
8282
return cute_tensor, cute_tensor_mc, torch_tensor_gpu, torch_tensor_mc
8383

84-
def create_barrier_flags(m, n, l, mma_tiler_mn):
85-
# NOTE: use_2cta_instrs from blockedscaled_gemm logic
86-
use_2cta_instrs = mma_tiler_mn[0] == 256
87-
cta_tile_shape_mn = (
88-
mma_tiler_mn[0] // (2 if use_2cta_instrs else 1),
89-
mma_tiler_mn[1],
84+
def create_barrier_flags(m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count):
85+
barrier_size = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
86+
m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count
9087
)
91-
problem_shape_ntile_mn = (m // cta_tile_shape_mn[0], n // cta_tile_shape_mn[1])
92-
num_tiles_per_batch = problem_shape_ntile_mn[0] * problem_shape_ntile_mn[1]
93-
num_tiles = num_tiles_per_batch * l
94-
num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
95-
88+
print("LOOK HERE",(barrier_size,))
89+
# NOTE: use_2cta_instrs from blockedscaled_gemm logic
90+
91+
# use_2cta_instrs = mma_tiler_mn[0] == 256
92+
# cta_tile_shape_mn = (
93+
# mma_tiler_mn[0] // (2 if use_2cta_instrs else 1),
94+
# mma_tiler_mn[1],
95+
# )
96+
# problem_shape_ntile_mn = (m // cta_tile_shape_mn[0], n // cta_tile_shape_mn[1])
97+
# num_tiles_per_batch = problem_shape_ntile_mn[0] * problem_shape_ntile_mn[1]
98+
# num_tiles = num_tiles_per_batch * l
99+
# num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
96100
# +num_sms for final barrier
101+
#num_tiles + num_sms
102+
97103
barrier_flag = symm_mem.empty(
98-
(num_tiles + num_sms,), device="cuda", dtype=torch.int32
104+
(barrier_size,), device="cuda", dtype=torch.int32
99105
)
100-
print("LOOK HERE",(num_tiles + num_sms,))
106+
101107
barrier_flag.fill_(0)
102108
symm = symm_mem.rendezvous(barrier_flag, group=dist.group.WORLD.group_name)
103109
barrier_flag_mc_ptr = symm.multicast_ptr
@@ -219,6 +225,8 @@ def run_blockscaled_gemm_all_reduce_python_interface(
219225
n,
220226
l,
221227
mma_tiler_mn,
228+
cluster_shape_mn,
229+
sm_count,
222230
)
223231
# for deepgemm-like python interface
224232
if ab_dtype == "float4_e2m1fn":

0 commit comments

Comments
 (0)