Skip to content

Commit 66ddb23

Browse files
committed
wip
1 parent edb75a0 commit 66ddb23

File tree

2 files changed

+24
-30
lines changed

2 files changed

+24
-30
lines changed

flashinfer/cute_dsl/blockscaled_gemm.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,26 +50,20 @@
5050
Uint64,
5151
T,
5252
Integer,
53+
dsl_user_op,
54+
extract_mlir_values,
55+
new_from_mlir_values,
56+
)
57+
58+
from cutlass.cute.typing import (
59+
Int32,
5360
Float16,
5461
BFloat16,
5562
Float32,
5663
Float8E4M3FN,
5764
Float8E5M2,
5865
Tensor,
59-
dsl_user_op,
60-
extract_mlir_values,
61-
new_from_mlir_values,
6266
)
63-
64-
# from cutlass.cute.typing import (
65-
# Int32,
66-
# Float16,
67-
# BFloat16,
68-
# Float32,
69-
# Float8E4M3FN,
70-
# Float8E5M2,
71-
# Tensor,
72-
# )
7367
from cutlass._mlir.dialects import llvm
7468
from flashinfer.utils import get_compute_capability
7569
from cutlass.utils.static_persistent_tile_scheduler import WorkTileInfo

tests/test_cute_dsl_blockscaled_gemm_allreduce_two_shot.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def create_barrier_flags(m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count):
8686
barrier_size = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
8787
m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count
8888
)
89-
print("LOOK HERE", (barrier_size,))
89+
#print("LOOK HERE", (barrier_size,))
9090
# NOTE: use_2cta_instrs from blockedscaled_gemm logic
9191

9292
# use_2cta_instrs = mma_tiler_mn[0] == 256
@@ -481,23 +481,23 @@ def multi_process_parallel(
481481
@pytest.mark.parametrize(
482482
"ab_dtype,sf_dtype,c_dtype,sf_vec_size",
483483
[
484-
("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32)
485-
("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
486-
("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
487-
("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
488-
("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
489-
("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
490-
("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
491-
("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
492-
("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
493-
("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
484+
# ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32)
485+
# ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
486+
# ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
487+
# ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
488+
# ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
489+
# ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
490+
# ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
491+
# ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
492+
# ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
493+
# ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
494494
("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
495-
("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
496-
("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
497-
("float8_e5m2", "float8_e8m0fnu", "float16", 32),
498-
("float8_e5m2", "float8_e8m0fnu", "float32", 32),
499-
("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
500-
("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
495+
# ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
496+
# ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
497+
# ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
498+
# ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
499+
# ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
500+
# ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
501501
],
502502
)
503503
@pytest.mark.parametrize("a_major", ["k"])

0 commit comments

Comments
 (0)