Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ci/setup_python.env
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@

# Uncomment to override TVM-FFI version:
# TVM_FFI_REF=

# Uncomment to override nvidia-cutlass-dsl version:
CUTLASS_DSL_VERSION=4.4.2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove it after the CI passes.

11 changes: 4 additions & 7 deletions flashinfer/cute_dsl/gemm_allreduce_two_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None:

# HACK https://github.com/NVIDIA/cutlass/issues/2845
from cutlass._mlir.dialects import nvvm
from cutlass.cutlass_dsl import T
from cutlass._mlir.dialects.nvvm import (
MemOrderKind,
MemScopeKind,
Expand All @@ -56,7 +55,6 @@ def spin_lock_atom_cas_acquire_wait(
result = 0
while result != expected_val:
result = nvvm.atomicrmw(
T.i32(),
AtomicOpKind.CAS,
lock_ptr.llvm_ptr,
Int32(reset_val).ir_value(loc=loc, ip=ip),
Expand All @@ -70,7 +68,6 @@ def spin_lock_atom_cas_acquire_wait(
result = 0
while result != expected_val:
result = nvvm.atomicrmw(
T.i32(),
AtomicOpKind.CAS,
lock_ptr.llvm_ptr,
Int32(reset_val).ir_value(loc=loc, ip=ip),
Expand All @@ -92,7 +89,7 @@ def sm_wise_inter_gpu_multimem_barrier(
bdimx, bdimy, _ = cute.arch.grid_dim()
pid = bidx + bidy * bdimx + bidz * bdimx * bdimy
distributed.multimem_red_release_sys_add1(barrier_mc + pid, loc=loc, ip=ip)
cute.arch.fence_proxy(cute.arch.ProxyKind.alias)
cute.arch.fence_proxy("alias")

# v4.3.1 does not have mem_order="acquire" variant in `distributed` module
# filed issue https://github.com/NVIDIA/cutlass/issues/2845
Expand Down Expand Up @@ -1251,8 +1248,8 @@ def kernel(
)
# Fence and barrier to make sure shared memory store is visible to TMA store
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
epilog_threads = 32 * len(self.epilog_warp_id)
cute.arch.barrier(
Expand Down Expand Up @@ -1312,7 +1309,7 @@ def kernel(
flag = barrier_flag_mc.iterator + tile_id
cute.arch.fence_acq_rel_gpu()
spin_lock_multimem_arrive(flag)
cute.arch.fence_proxy(cute.arch.ProxyKind.alias)
cute.arch.fence_proxy("alias")

#
# Advance to next tile
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1512,8 +1512,8 @@ def kernel(
sInfo[(4, tile_info_producer_state.index)] = mn_limit
# fence view async shared
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)

self.sched_sync_barrier.arrive_and_wait()
Expand Down Expand Up @@ -1548,8 +1548,8 @@ def kernel(
sInfo[(4, tile_info_producer_state.index)] = mn_limit
# fence view async shared
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)

self.sched_sync_barrier.arrive_and_wait()
Expand All @@ -1569,8 +1569,8 @@ def kernel(
sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0)
sInfo[(4, tile_info_producer_state.index)] = -1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
self.sched_sync_barrier.arrive_and_wait()
tile_info_pipeline.producer_commit(tile_info_producer_state)
Expand Down Expand Up @@ -1669,8 +1669,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1844,8 +1844,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1886,8 +1886,8 @@ def kernel(
valid_tile_info[0] = sInfo[(3, tile_info_consumer_state.index)]
is_valid_tile = valid_tile_info[0] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1927,8 +1927,8 @@ def kernel(
valid_tile_info[0] = sInfo[(3, tile_info_consumer_state.index)]
is_valid_tile = valid_tile_info[0] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1968,8 +1968,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -2051,8 +2051,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -2152,8 +2152,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -2368,8 +2368,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -2480,8 +2480,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -2811,8 +2811,8 @@ def kernel(
)
# Fence and barrier to make sure shared memory store is visible to TMA store
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
self.epilog_sync_barrier.arrive_and_wait()
#
Expand Down Expand Up @@ -2845,8 +2845,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1380,8 +1380,8 @@ def kernel(
sInfo[(4, tile_info_producer_state.index)] = mn_limit
# fence view async shared
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)

self.sched_sync_barrier.arrive_and_wait()
Expand Down Expand Up @@ -1416,8 +1416,8 @@ def kernel(
sInfo[(4, tile_info_producer_state.index)] = mn_limit
# fence view async shared
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)

self.sched_sync_barrier.arrive_and_wait()
Expand All @@ -1438,8 +1438,8 @@ def kernel(
sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0)
sInfo[(4, tile_info_producer_state.index)] = cutlass.Int32(0)
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
self.sched_sync_barrier.arrive_and_wait()
tile_info_pipeline.producer_commit(tile_info_producer_state)
Expand Down Expand Up @@ -1467,8 +1467,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1573,8 +1573,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1659,8 +1659,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1818,8 +1818,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1886,8 +1886,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -2023,8 +2023,8 @@ def kernel(

if cutlass.const_expr(self.use_blkred):
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
#
# Async arrive accumulator buffer empty
Expand All @@ -2037,8 +2037,8 @@ def kernel(

if cutlass.const_expr(self.use_blkred):
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
if is_valid_row:
coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[1]
Expand Down Expand Up @@ -2073,8 +2073,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down
1 change: 0 additions & 1 deletion flashinfer/fused_moe/cute_dsl/blackwell/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def fmin(
) -> cutlass.Float32:
return cutlass.Float32(
nvvm.fmin(
T.f32(),
cutlass.Float32(a).ir_value(loc=loc, ip=ip),
cutlass.Float32(b).ir_value(loc=loc, ip=ip),
nan=nan,
Expand Down
4 changes: 2 additions & 2 deletions flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -1469,8 +1469,8 @@ def kernel(
)
# Fence and barrier to make sure shared memory store is visible to TMA store
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
epilog_threads = 32 * len(self.epilog_warp_id)
cute.arch.barrier(
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ einops
ninja
numpy
nvidia-cudnn-frontend>=1.13.0
nvidia-cutlass-dsl>=4.3.4
nvidia-cutlass-dsl>=4.4.2
nvidia-ml-py
packaging>=24.2
requests
Expand Down
19 changes: 19 additions & 0 deletions scripts/setup_test_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,22 @@ if [ -n "${TVM_FFI_REF:-}" ]; then
echo "TVM-FFI override complete."
echo ""
fi

# Override nvidia-cutlass-dsl if specified
if [ -n "${CUTLASS_DSL_VERSION:-}" ]; then
# Detect CUDA major version: only CUDA 13+ needs [cu13] extra
CUDA_MAJOR=$(python -c "import torch; print(torch.version.cuda.split('.')[0])" 2>/dev/null || echo "12")
if [ "$CUDA_MAJOR" = "13" ]; then
CUTLASS_DSL_PKG="nvidia-cutlass-dsl[cu13]==${CUTLASS_DSL_VERSION}"
else
CUTLASS_DSL_PKG="nvidia-cutlass-dsl==${CUTLASS_DSL_VERSION}"
fi
echo "========================================"
echo "Overriding nvidia-cutlass-dsl with: ${CUTLASS_DSL_PKG}"
echo "========================================"
# Clean uninstall old packages first (recommended by NVIDIA docs)
pip uninstall nvidia-cutlass-dsl nvidia-cutlass-dsl-libs-base nvidia-cutlass-dsl-libs-cu12 nvidia-cutlass-dsl-libs-cu13 -y 2>/dev/null || true
pip install "${CUTLASS_DSL_PKG}"
echo "nvidia-cutlass-dsl override complete."
echo ""
fi
Loading