26
26
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
27
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
28
29
- from typing import Optional , Tuple , Type , Union
29
+ import functools
30
+ from typing import Callable , Optional , Tuple , Type , Union , List
30
31
31
32
import cuda .bindings .driver as cuda
32
33
import cutlass
37
38
import cutlass .utils .blackwell_helpers as sm100_utils
38
39
import cutlass .utils .blockscaled_layout as blockscaled_utils
39
40
import cutlass .utils .distributed_helpers as distributed_helpers
40
- import torch
41
- import functools
42
41
from cutlass ._mlir import ir
43
42
from cutlass .cute .nvgpu import cpasync , tcgen05
44
43
from cutlass .cute .runtime import from_dlpack
45
-
46
44
from cutlass .cutlass_dsl import (
47
- Int32 ,
48
- Int64 ,
49
- Uint8 ,
50
- Uint64 ,
51
45
T ,
52
46
Integer ,
53
47
dsl_user_op ,
54
48
extract_mlir_values ,
55
49
new_from_mlir_values ,
56
50
)
57
-
58
51
from cutlass .cute .typing import (
59
52
Int32 ,
53
+ Int64 ,
54
+ Uint8 ,
55
+ Uint64 ,
60
56
Float16 ,
61
57
BFloat16 ,
62
58
Float32 ,
65
61
Tensor ,
66
62
)
67
63
from cutlass ._mlir .dialects import llvm
68
- from flashinfer .utils import get_compute_capability
69
64
from cutlass .utils .static_persistent_tile_scheduler import WorkTileInfo
65
+ import torch
66
+
67
+ from flashinfer .utils import get_compute_capability
70
68
from .utils import get_cutlass_dtype , cutlass_to_torch_dtype , get_num_sm , make_ptr
71
- from typing import Callable , List
72
69
73
70
74
71
sizeof_i32 = 4
@@ -1865,7 +1862,6 @@ def kernel(
1865
1862
# Allreduce
1866
1863
#
1867
1864
if cutlass .const_expr (self .all_reduce == "two_shot" ):
1868
-
1869
1865
tile_id = Int32 (
1870
1866
tile_sched ._current_work_linear_idx
1871
1867
* cute .size (self .cluster_shape_mn )
@@ -2950,13 +2946,15 @@ def __call__(
2950
2946
current_stream : cuda .CUstream ,
2951
2947
):
2952
2948
if cutlass .const_expr (self ._all_reduce != "none" ):
2953
- barrier_flag_size = Sm100BlockScaledPersistentDenseGemmKernel .compute_barrier_flag_size (
2954
- self ._m ,
2955
- self ._n ,
2956
- self ._l ,
2957
- self ._mma_tiler_mn ,
2958
- self ._cluster_shape_mn ,
2959
- self ._max_active_clusters ,
2949
+ barrier_flag_size = (
2950
+ Sm100BlockScaledPersistentDenseGemmKernel .compute_barrier_flag_size (
2951
+ self ._m ,
2952
+ self ._n ,
2953
+ self ._l ,
2954
+ self ._mma_tiler_mn ,
2955
+ self ._cluster_shape_mn ,
2956
+ self ._max_active_clusters ,
2957
+ )
2960
2958
)
2961
2959
else :
2962
2960
barrier_flag_size = 1 # Dummy size when not used
@@ -2982,21 +2980,33 @@ def __call__(
2982
2980
order = (0 , 1 , 2 ) if self ._c_major == "m" else (1 , 0 , 2 ),
2983
2981
),
2984
2982
)
2985
- c_mc_tensor = cute .make_tensor (
2986
- c_mc_ptr ,
2987
- layout = cute .make_ordered_layout (
2988
- (self ._m , self ._n , self ._l ),
2989
- order = (0 , 1 , 2 ) if self ._c_major == "m" else (1 , 0 , 2 ),
2990
- ),
2991
- ) if c_mc_ptr is not None else None
2992
- barrier_flag_tensor = cute .make_tensor (
2993
- barrier_flag_ptr ,
2994
- layout = cute .make_ordered_layout ((barrier_flag_size ,), order = (0 ,)),
2995
- ) if barrier_flag_ptr is not None else None
2996
- barrier_flag_mc_tensor = cute .make_tensor (
2997
- barrier_flag_mc_ptr ,
2998
- layout = cute .make_ordered_layout ((barrier_flag_size ,), order = (0 ,)),
2999
- ) if barrier_flag_mc_ptr is not None else None
2983
+ c_mc_tensor = (
2984
+ cute .make_tensor (
2985
+ c_mc_ptr ,
2986
+ layout = cute .make_ordered_layout (
2987
+ (self ._m , self ._n , self ._l ),
2988
+ order = (0 , 1 , 2 ) if self ._c_major == "m" else (1 , 0 , 2 ),
2989
+ ),
2990
+ )
2991
+ if c_mc_ptr is not None
2992
+ else None
2993
+ )
2994
+ barrier_flag_tensor = (
2995
+ cute .make_tensor (
2996
+ barrier_flag_ptr ,
2997
+ layout = cute .make_ordered_layout ((barrier_flag_size ,), order = (0 ,)),
2998
+ )
2999
+ if barrier_flag_ptr is not None
3000
+ else None
3001
+ )
3002
+ barrier_flag_mc_tensor = (
3003
+ cute .make_tensor (
3004
+ barrier_flag_mc_ptr ,
3005
+ layout = cute .make_ordered_layout ((barrier_flag_size ,), order = (0 ,)),
3006
+ )
3007
+ if barrier_flag_mc_ptr is not None
3008
+ else None
3009
+ )
3000
3010
3001
3011
# calculate sf_tensor shape and order
3002
3012
def ceil_div (a , b ):
@@ -3154,7 +3164,6 @@ def get_cute_pointers(
3154
3164
c_mc_data_ptr ,
3155
3165
barrier_flag_data_ptr ,
3156
3166
barrier_flag_mc_data_ptr ,
3157
-
3158
3167
) = (
3159
3168
a_tensor_gpu .data_ptr (),
3160
3169
b_tensor_gpu .data_ptr (),
@@ -3168,7 +3177,9 @@ def get_cute_pointers(
3168
3177
alpha_tensor_gpu .data_ptr () if alpha_tensor_gpu is not None else None ,
3169
3178
c_mc_gpu .data_ptr () if c_mc_gpu is not None else None ,
3170
3179
barrier_flag_gpu .data_ptr () if barrier_flag_gpu is not None else None ,
3171
- barrier_flag_mc_gpu .data_ptr () if barrier_flag_mc_gpu is not None else None ,
3180
+ barrier_flag_mc_gpu .data_ptr ()
3181
+ if barrier_flag_mc_gpu is not None
3182
+ else None ,
3172
3183
)
3173
3184
3174
3185
a_ptr = make_ptr (
0 commit comments