
Commit 6ad4d90

precommit
1 parent 415da39 commit 6ad4d90

File tree

flashinfer/cute_dsl/blockscaled_gemm.py
tests/test_cute_dsl_blockscaled_gemm_allreduce_two_shot.py

2 files changed: +53 -44 lines


flashinfer/cute_dsl/blockscaled_gemm.py

Lines changed: 47 additions & 36 deletions
@@ -26,7 +26,8 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from typing import Optional, Tuple, Type, Union
+import functools
+from typing import Callable, Optional, Tuple, Type, Union, List
 
 import cuda.bindings.driver as cuda
 import cutlass
@@ -37,26 +38,21 @@
 import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
 import cutlass.utils.distributed_helpers as distributed_helpers
-import torch
-import functools
 from cutlass._mlir import ir
 from cutlass.cute.nvgpu import cpasync, tcgen05
 from cutlass.cute.runtime import from_dlpack
-
 from cutlass.cutlass_dsl import (
-    Int32,
-    Int64,
-    Uint8,
-    Uint64,
     T,
     Integer,
     dsl_user_op,
     extract_mlir_values,
     new_from_mlir_values,
 )
-
 from cutlass.cute.typing import (
     Int32,
+    Int64,
+    Uint8,
+    Uint64,
     Float16,
     BFloat16,
     Float32,
@@ -65,10 +61,11 @@
     Tensor,
 )
 from cutlass._mlir.dialects import llvm
-from flashinfer.utils import get_compute_capability
 from cutlass.utils.static_persistent_tile_scheduler import WorkTileInfo
+import torch
+
+from flashinfer.utils import get_compute_capability
 from .utils import get_cutlass_dtype, cutlass_to_torch_dtype, get_num_sm, make_ptr
-from typing import Callable, List
 
 
 sizeof_i32 = 4
@@ -1865,7 +1862,6 @@ def kernel(
             # Allreduce
             #
             if cutlass.const_expr(self.all_reduce == "two_shot"):
-
                 tile_id = Int32(
                     tile_sched._current_work_linear_idx
                     * cute.size(self.cluster_shape_mn)
@@ -2950,13 +2946,15 @@ def __call__(
        current_stream: cuda.CUstream,
    ):
        if cutlass.const_expr(self._all_reduce != "none"):
-            barrier_flag_size = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
-                self._m,
-                self._n,
-                self._l,
-                self._mma_tiler_mn,
-                self._cluster_shape_mn,
-                self._max_active_clusters,
+            barrier_flag_size = (
+                Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
+                    self._m,
+                    self._n,
+                    self._l,
+                    self._mma_tiler_mn,
+                    self._cluster_shape_mn,
+                    self._max_active_clusters,
+                )
            )
        else:
            barrier_flag_size = 1  # Dummy size when not used
@@ -2982,21 +2980,33 @@ def __call__(
                order=(0, 1, 2) if self._c_major == "m" else (1, 0, 2),
            ),
        )
-        c_mc_tensor = cute.make_tensor(
-            c_mc_ptr,
-            layout=cute.make_ordered_layout(
-                (self._m, self._n, self._l),
-                order=(0, 1, 2) if self._c_major == "m" else (1, 0, 2),
-            ),
-        ) if c_mc_ptr is not None else None
-        barrier_flag_tensor = cute.make_tensor(
-            barrier_flag_ptr,
-            layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)),
-        ) if barrier_flag_ptr is not None else None
-        barrier_flag_mc_tensor = cute.make_tensor(
-            barrier_flag_mc_ptr,
-            layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)),
-        ) if barrier_flag_mc_ptr is not None else None
+        c_mc_tensor = (
+            cute.make_tensor(
+                c_mc_ptr,
+                layout=cute.make_ordered_layout(
+                    (self._m, self._n, self._l),
+                    order=(0, 1, 2) if self._c_major == "m" else (1, 0, 2),
+                ),
+            )
+            if c_mc_ptr is not None
+            else None
+        )
+        barrier_flag_tensor = (
+            cute.make_tensor(
+                barrier_flag_ptr,
+                layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)),
+            )
+            if barrier_flag_ptr is not None
+            else None
+        )
+        barrier_flag_mc_tensor = (
+            cute.make_tensor(
+                barrier_flag_mc_ptr,
+                layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)),
+            )
+            if barrier_flag_mc_ptr is not None
+            else None
+        )
 
        # calculate sf_tensor shape and order
        def ceil_div(a, b):
@@ -3154,7 +3164,6 @@ def get_cute_pointers(
            c_mc_data_ptr,
            barrier_flag_data_ptr,
            barrier_flag_mc_data_ptr,
-
        ) = (
            a_tensor_gpu.data_ptr(),
            b_tensor_gpu.data_ptr(),
@@ -3168,7 +3177,9 @@ def get_cute_pointers(
            alpha_tensor_gpu.data_ptr() if alpha_tensor_gpu is not None else None,
            c_mc_gpu.data_ptr() if c_mc_gpu is not None else None,
            barrier_flag_gpu.data_ptr() if barrier_flag_gpu is not None else None,
-            barrier_flag_mc_gpu.data_ptr() if barrier_flag_mc_gpu is not None else None,
+            barrier_flag_mc_gpu.data_ptr()
+            if barrier_flag_mc_gpu is not None
+            else None,
        )
 
        a_ptr = make_ptr(
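
Most of the hunks in this file are the same mechanical rewrite: a conditional expression that spans several lines is wrapped in parentheses so the call, the if clause, and the else clause each sit on their own line. Below is a minimal, self-contained sketch of the before/after pattern; the names (build, ptr, value) are hypothetical stand-ins, not code from this repository.

def build(ptr):
    # Stand-in for a multi-line constructor such as cute.make_tensor(...).
    return ("tensor", ptr)


ptr = None

# Style removed by this commit: the conditional trails the closing parenthesis.
value = build(
    ptr,
) if ptr is not None else None

# Style introduced by this commit: the whole conditional expression is
# parenthesized, one clause per line.
value = (
    build(
        ptr,
    )
    if ptr is not None
    else None
)

assert value is None  # ptr is None, so both forms evaluate to None

Both forms are equivalent at runtime; the parenthesized layout is presumably what the repository's pre-commit formatter emits, which matches the commit message "precommit".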

tests/test_cute_dsl_blockscaled_gemm_allreduce_two_shot.py

Lines changed: 6 additions & 8 deletions
@@ -1,16 +1,12 @@
-import logging
 import multiprocessing as mp
 import pytest
 import socket
-from typing import Any, Tuple, Type
+from typing import Any, Tuple
 
-from cuda import cuda
 import cutlass
 import cutlass.cute as cute
-import cutlass.cute.testing as testing
 from cutlass.cute.runtime import from_dlpack
 import cutlass.torch as cutlass_torch
-import cutlass.utils as utils
 
 import torch
 import torch.distributed as dist
@@ -130,6 +126,7 @@ def run_blockscaled_gemm_all_reduce_python_interface(
    enable_dst_signals: int,
    all_reduce: str,
    rank: int,
+    world_size: int,
 ):
    torch.manual_seed(42)
    device = torch.device("cuda", rank)
@@ -395,6 +392,7 @@ def _run_correctness_worker(
            enable_dst_signals=enable_dst_signals,
            all_reduce=all_reduce,
            rank=rank,
+            world_size=world_size,
        )
    except Exception as e:
        print(f"Rank {rank_id}: Exception during test: {e}")
@@ -430,9 +428,9 @@ def multi_process_parallel(
 
    for i in range(world_size):
        procs[i].join()
-        assert (
-            procs[i].exitcode == 0
-        ), f"Process {i} failed with exit code {procs[i].exitcode}"
+        assert procs[i].exitcode == 0, (
+            f"Process {i} failed with exit code {procs[i].exitcode}"
+        )
 
 
 @pytest.mark.skipif(
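
The test-side changes thread world_size through to the worker and reformat the exit-code assertion. For reference, the multi_process_parallel pattern touched in the last hunk follows the usual torch/multiprocessing recipe: spawn one process per rank, pass rank and world_size to the worker, then join each process and assert on its exit code. Below is a minimal runnable sketch under that assumption; the worker body is a hypothetical stand-in for the real GEMM + all-reduce test, not the test's actual implementation.

import multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    # Stand-in for the real worker; the actual test forwards rank and
    # world_size into run_blockscaled_gemm_all_reduce_python_interface.
    assert 0 <= rank < world_size


def multi_process_parallel_sketch(world_size: int) -> None:
    ctx = mp.get_context("spawn")
    procs = [
        ctx.Process(target=worker, args=(rank, world_size))
        for rank in range(world_size)
    ]
    for p in procs:
        p.start()
    for i, p in enumerate(procs):
        p.join()
        # Same assertion style the commit reformats to: condition first,
        # parenthesized message on its own line.
        assert p.exitcode == 0, (
            f"Process {i} failed with exit code {p.exitcode}"
        )


if __name__ == "__main__":
    multi_process_parallel_sketch(world_size=2)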
