
def create_mc_tensor(torch_tensor_cpu, dtype, is_dynamic_layout=True):
    m, n, l = torch_tensor_cpu.shape
-
+
    # Create flat symm_mem buffer
    total_elements = m * n * l
    torch_symm_flat = symm_mem.empty(
        (total_elements,), device="cuda", dtype=torch_tensor_cpu.dtype
    )
-
+
    # Reshape to match input's stride pattern using as_strided
    torch_symm_tensor = torch_symm_flat.as_strided(
-        size=torch_tensor_cpu.shape,
-        stride=torch_tensor_cpu.stride()
+        size=torch_tensor_cpu.shape, stride=torch_tensor_cpu.stride()
    )
    torch_symm_tensor.copy_(torch_tensor_cpu)
-
+
    symm = symm_mem.rendezvous(torch_symm_flat, group=dist.group.WORLD.group_name)
    mc_ptr = symm.multicast_ptr
-
+
    # Create MC tensor with same stride
-    torch_tensor_mc_flat = cutlass_torch.as_tensor(mc_ptr, (total_elements,), torch_tensor_cpu.dtype)
+    torch_tensor_mc_flat = cutlass_torch.as_tensor(
+        mc_ptr, (total_elements,), torch_tensor_cpu.dtype
+    )
    torch_tensor_mc = torch_tensor_mc_flat.as_strided(
-        size=torch_tensor_cpu.shape,
-        stride=torch_tensor_cpu.stride()
+        size=torch_tensor_cpu.shape, stride=torch_tensor_cpu.stride()
    )
-
+
    cute_tensor_mc = from_dlpack(torch_tensor_mc, assumed_align=16)
-
+
    if is_dynamic_layout:
        for i, stride in enumerate(torch_tensor_mc.stride()):
            if stride == 1:
                leading_dim = i
                break
        cute_tensor_mc = cute_tensor_mc.mark_layout_dynamic(leading_dim=leading_dim)
-
+
    torch_tensor_gpu = torch_symm_tensor
    cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16)
    cute_tensor.element_type = dtype
-
+
    if is_dynamic_layout:
        for i, stride in enumerate(torch_tensor_gpu.stride()):
            if stride == 1:
                leading_dim = i
                break
        cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim)
-
+
    cute_tensor = cutlass_torch.convert_cute_tensor(
        torch_tensor_gpu,
        cute_tensor,
@@ -81,44 +81,49 @@ def create_mc_tensor(torch_tensor_cpu, dtype, is_dynamic_layout=True):
    )
    return cute_tensor, cute_tensor_mc, torch_tensor_gpu, torch_tensor_mc

+
def create_barrier_flags(m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count):
-    barrier_size = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
-        m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count
-    )
-    print("LOOK HERE",(barrier_size,))
-    # NOTE: use_2cta_instrs from blockedscaled_gemm logic
-
-    # use_2cta_instrs = mma_tiler_mn[0] == 256
-    # cta_tile_shape_mn = (
-    #     mma_tiler_mn[0] // (2 if use_2cta_instrs else 1),
-    #     mma_tiler_mn[1],
-    # )
-    # problem_shape_ntile_mn = (m // cta_tile_shape_mn[0], n // cta_tile_shape_mn[1])
-    # num_tiles_per_batch = problem_shape_ntile_mn[0] * problem_shape_ntile_mn[1]
-    # num_tiles = num_tiles_per_batch * l
-    # num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
-    # +num_sms for final barrier
-    #num_tiles + num_sms
-
-    barrier_flag = symm_mem.empty(
-        (barrier_size,), device="cuda", dtype=torch.int32
-    )
-
-    barrier_flag.fill_(0)
-    symm = symm_mem.rendezvous(barrier_flag, group=dist.group.WORLD.group_name)
-    barrier_flag_mc_ptr = symm.multicast_ptr
-
-    barrier_flag_memref = from_dlpack(barrier_flag)
-    barrier_flag_memref = barrier_flag_memref.mark_layout_dynamic()
-    barrier_flag_mc_torch = cutlass_torch.as_tensor(
-        barrier_flag_mc_ptr, barrier_flag.shape, barrier_flag.dtype
-    )
-    barrier_flag_mc_memref = from_dlpack(
-        barrier_flag_mc_torch,
-    )
-    barrier_flag_mc_memref = barrier_flag_mc_memref.mark_layout_dynamic()
-    barrier_flag_torch = barrier_flag
-    return barrier_flag_memref, barrier_flag_mc_memref, barrier_flag_torch, barrier_flag_mc_torch
+    barrier_size = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
+        m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count
+    )
+    print("LOOK HERE", (barrier_size,))
+    # NOTE: use_2cta_instrs from blockedscaled_gemm logic
+
+    # use_2cta_instrs = mma_tiler_mn[0] == 256
+    # cta_tile_shape_mn = (
+    #     mma_tiler_mn[0] // (2 if use_2cta_instrs else 1),
+    #     mma_tiler_mn[1],
+    # )
+    # problem_shape_ntile_mn = (m // cta_tile_shape_mn[0], n // cta_tile_shape_mn[1])
+    # num_tiles_per_batch = problem_shape_ntile_mn[0] * problem_shape_ntile_mn[1]
+    # num_tiles = num_tiles_per_batch * l
+    # num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
+    # +num_sms for final barrier
+    # num_tiles + num_sms
+
+    barrier_flag = symm_mem.empty((barrier_size,), device="cuda", dtype=torch.int32)
+
+    barrier_flag.fill_(0)
+    symm = symm_mem.rendezvous(barrier_flag, group=dist.group.WORLD.group_name)
+    barrier_flag_mc_ptr = symm.multicast_ptr
+
+    barrier_flag_memref = from_dlpack(barrier_flag)
+    barrier_flag_memref = barrier_flag_memref.mark_layout_dynamic()
+    barrier_flag_mc_torch = cutlass_torch.as_tensor(
+        barrier_flag_mc_ptr, barrier_flag.shape, barrier_flag.dtype
+    )
+    barrier_flag_mc_memref = from_dlpack(
+        barrier_flag_mc_torch,
+    )
+    barrier_flag_mc_memref = barrier_flag_mc_memref.mark_layout_dynamic()
+    barrier_flag_torch = barrier_flag
+    return (
+        barrier_flag_memref,
+        barrier_flag_mc_memref,
+        barrier_flag_torch,
+        barrier_flag_mc_torch,
+    )
+

def run_blockscaled_gemm_all_reduce_python_interface(
    lm: Tuple[int, int],
@@ -139,7 +144,7 @@ def run_blockscaled_gemm_all_reduce_python_interface(
    iterations: int,
    enable_dst_signals: int,
    all_reduce: str,
-    rank:int,
+    rank: int,
):
    torch.manual_seed(42)
    device = torch.device("cuda", rank)
@@ -187,7 +192,12 @@ def run_blockscaled_gemm_all_reduce_python_interface(
        l, n, k, b_major == "n", cutlass.Float32, device=device
    )
    c_ref = cutlass_torch.matrix(
-        l, m, n, c_major == "m", cutlass.Float32, device=device,
+        l,
+        m,
+        n,
+        c_major == "m",
+        cutlass.Float32,
+        device=device,
        init_type=cutlass_torch.TensorInitType.SCALAR,
        init_config=cutlass_torch.ScalarInitConfig(value=0.0),
    )
@@ -213,14 +223,21 @@ def run_blockscaled_gemm_all_reduce_python_interface(
    c_tensor, c_tensor_mc, c_torch, c_torch_mc = create_mc_tensor(
        c_ref,
        get_cutlass_dtype(c_dtype),
-        #(1 if c_major == "n" else 0),
+        # (1 if c_major == "n" else 0),
        is_dynamic_layout=True,
    )
-    print(f"Rank {rank}: c_torch INITIAL shape={c_torch.shape}, stride={c_torch.stride()}")
+    print(
+        f"Rank {rank}: c_torch INITIAL shape={c_torch.shape}, stride={c_torch.stride()}"
+    )
    alpha_tensor = (
        torch.randn(l, dtype=torch.float32, device=device) if fuse_alpha else None
    )
-    barrier_flag_memref, barrier_flag_mc_memref, barrier_flag_torch, barrier_flag_mc_torch = create_barrier_flags(
+    (
+        barrier_flag_memref,
+        barrier_flag_mc_memref,
+        barrier_flag_torch,
+        barrier_flag_mc_torch,
+    ) = create_barrier_flags(
        m,
        n,
        l,
@@ -254,15 +271,15 @@ def run_blockscaled_gemm_all_reduce_python_interface(
    sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor(
        l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
    )
-    #masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
-    # if rank == 0:
-    #     masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
-    # else:
-    #     masked_m_tensor = torch.empty((l,), dtype=torch.int32, device=device)
-    # torch.distributed.broadcast(masked_m_tensor, src=0)
+    # masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
+    if rank == 0:
+        masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
+    else:
+        masked_m_tensor = torch.empty((l,), dtype=torch.int32, device=device)
+    torch.distributed.broadcast(masked_m_tensor, src=0)
    # to hack and test:
-    masked_m_tensor = torch.full((l,), m, dtype=torch.int32, device=device)
-    print(f"Rank {rank}: masked_m = {masked_m_tensor}")
+    # masked_m_tensor = torch.full((l,), m, dtype=torch.int32, device=device)
+    print(f"Rank {rank}: masked_m = {masked_m_tensor}")
    for _ in range(iterations):
        dst_signals = (
            torch.zeros((l,), dtype=torch.uint32, device="cuda")
@@ -306,7 +323,9 @@ def run_blockscaled_gemm_all_reduce_python_interface(
        ref = torch.einsum("mkl,nkl->mnl", res_a, res_b)
        ref = torch.einsum("mnl,l->mnl", ref, alpha_tensor)
        ref = ref.contiguous()
-        torch.distributed.all_reduce(ref, op=torch.distributed.ReduceOp.SUM, group=dist.group.WORLD)
+        torch.distributed.all_reduce(
+            ref, op=torch.distributed.ReduceOp.SUM, group=dist.group.WORLD
+        )
        # Convert c back to f32 for comparison.
        ref = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0)
        print(f"Rank {rank}: c_ref shape={c_ref.shape}, stride={c_ref.stride()}")
@@ -354,9 +373,10 @@ def run_blockscaled_gemm_all_reduce_python_interface(
            rtol=1e-02,
        )

+
def _run_correctness_worker(
-    world_size,
-    rank,
+    world_size,
+    rank,
    distributed_init_port,
    lm,
    kn,
@@ -447,9 +467,48 @@ def multi_process_parallel(

    for i in range(world_size):
        procs[i].join()
-        assert procs[i].exitcode == 0, (
-            f"Process {i} failed with exit code {procs[i].exitcode}"
-        )
+        assert (
+            procs[i].exitcode == 0
+        ), f"Process {i} failed with exit code {procs[i].exitcode}"
+
+
+# @pytest.mark.parametrize("lm", [(1, 1024), (2, 512), (4, 256)])
+# @pytest.mark.parametrize("kn", [(7168, 4096), (2048, 7168)])
+# @pytest.mark.parametrize(
+#     "ab_dtype,sf_dtype,c_dtype,sf_vec_size",
+#     [
+#         ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
+#         ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
+#         ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
+#         ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
+#         ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
+#         ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
+#         ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
+#         ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
+#         ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
+#         ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
+#         ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
+#         ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
+#         ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
+#         ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
+#         ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
+#         ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
+#     ],
+# )
+# @pytest.mark.parametrize("a_major", ["k"])
+# @pytest.mark.parametrize("b_major", ["k"])
+# @pytest.mark.parametrize("c_major", ["n"])
+# @pytest.mark.parametrize("fuse_alpha", [False, True])
+# @pytest.mark.parametrize("alpha_dtype", ["float32"])
+# @pytest.mark.parametrize("mma_tiler_mn", [(128, 128)])
+# @pytest.mark.parametrize("cluster_shape_mn", [(1, 1)])
+# @pytest.mark.parametrize("sm_count", [132, None])
+# @pytest.mark.parametrize("tolerance", [1e-01])
+# @pytest.mark.parametrize("iterations", [3])
+# @pytest.mark.parametrize("enable_dst_signals", [False, True])
+
+# ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32)
+

@pytest.mark.skipif(
    not is_cute_dsl_available(), reason="Please `pip install nvidia-cutlass-dsl`"
@@ -460,21 +519,36 @@ def multi_process_parallel(
@pytest.mark.parametrize(
    "ab_dtype,sf_dtype,c_dtype,sf_vec_size",
    [
-        ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
-        # Add more combinations as needed
+        ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32)
+        # ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
+        # ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
+        # ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
+        # ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
+        # ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
+        # ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
+        # ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
+        # ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
+        # ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
+        # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
+        # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
+        # ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
+        # ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
+        # ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
+        # ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
+        # ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
    ],
)
@pytest.mark.parametrize("a_major", ["k"])
@pytest.mark.parametrize("b_major", ["k"])
@pytest.mark.parametrize("c_major", ["n"])
-@pytest.mark.parametrize("fuse_alpha", [False])
+@pytest.mark.parametrize("fuse_alpha", [False, True])
@pytest.mark.parametrize("alpha_dtype", ["float32"])
@pytest.mark.parametrize("mma_tiler_mn", [(128, 128)])
@pytest.mark.parametrize("cluster_shape_mn", [(1, 1)])
@pytest.mark.parametrize("sm_count", [148])
@pytest.mark.parametrize("tolerance", [1e-01])
@pytest.mark.parametrize("iterations", [1])
-@pytest.mark.parametrize("enable_dst_signals", [True])
+@pytest.mark.parametrize("enable_dst_signals", [False, True])
@pytest.mark.parametrize("all_reduce", ["two_shot"])
def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(
    world_size,
@@ -527,4 +601,4 @@ def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(
            all_reduce,
        ),
    )
-    print(f"cute_dsl_blockscaled_gemm_allreduce_two_shot on {world_size} GPUs: OK")
+    print(f"cute_dsl_blockscaled_gemm_allreduce_two_shot on {world_size} GPUs: OK")