diff --git a/ci/setup_python.env b/ci/setup_python.env index ffd9c49ac2..1cba321988 100644 --- a/ci/setup_python.env +++ b/ci/setup_python.env @@ -15,3 +15,6 @@ # Uncomment to override TVM-FFI version: # TVM_FFI_REF= + +# Uncomment to override nvidia-cutlass-dsl version: +CUTLASS_DSL_VERSION=4.4.2 diff --git a/flashinfer/cute_dsl/gemm_allreduce_two_shot.py b/flashinfer/cute_dsl/gemm_allreduce_two_shot.py index baf55468a4..25d8ddff2a 100644 --- a/flashinfer/cute_dsl/gemm_allreduce_two_shot.py +++ b/flashinfer/cute_dsl/gemm_allreduce_two_shot.py @@ -31,7 +31,6 @@ def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None: # HACK https://github.com/NVIDIA/cutlass/issues/2845 from cutlass._mlir.dialects import nvvm -from cutlass.cutlass_dsl import T from cutlass._mlir.dialects.nvvm import ( MemOrderKind, MemScopeKind, @@ -56,7 +55,6 @@ def spin_lock_atom_cas_acquire_wait( result = 0 while result != expected_val: result = nvvm.atomicrmw( - T.i32(), AtomicOpKind.CAS, lock_ptr.llvm_ptr, Int32(reset_val).ir_value(loc=loc, ip=ip), @@ -70,7 +68,6 @@ def spin_lock_atom_cas_acquire_wait( result = 0 while result != expected_val: result = nvvm.atomicrmw( - T.i32(), AtomicOpKind.CAS, lock_ptr.llvm_ptr, Int32(reset_val).ir_value(loc=loc, ip=ip), @@ -92,7 +89,7 @@ def sm_wise_inter_gpu_multimem_barrier( bdimx, bdimy, _ = cute.arch.grid_dim() pid = bidx + bidy * bdimx + bidz * bdimx * bdimy distributed.multimem_red_release_sys_add1(barrier_mc + pid, loc=loc, ip=ip) - cute.arch.fence_proxy(cute.arch.ProxyKind.alias) + cute.arch.fence_proxy("alias") # v4.3.1 does not have mem_order="acquire" variant in `distributed` module # filed issue https://github.com/NVIDIA/cutlass/issues/2845 @@ -1251,8 +1248,8 @@ def kernel( ) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) epilog_threads = 32 * len(self.epilog_warp_id) cute.arch.barrier( @@ -1312,7 +1309,7 @@ def kernel( flag = barrier_flag_mc.iterator + tile_id cute.arch.fence_acq_rel_gpu() spin_lock_multimem_arrive(flag) - cute.arch.fence_proxy(cute.arch.ProxyKind.alias) + cute.arch.fence_proxy("alias") # # Advance to next tile diff --git a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py index 10a1f7f822..f8c50c624f 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py @@ -1512,8 +1512,8 @@ def kernel( sInfo[(4, tile_info_producer_state.index)] = mn_limit # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1548,8 +1548,8 @@ def kernel( sInfo[(4, tile_info_producer_state.index)] = mn_limit # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1569,8 +1569,8 @@ def kernel( sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0) sInfo[(4, tile_info_producer_state.index)] = -1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() tile_info_pipeline.producer_commit(tile_info_producer_state) @@ -1669,8 +1669,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1844,8 +1844,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1886,8 +1886,8 @@ def kernel( valid_tile_info[0] = sInfo[(3, tile_info_consumer_state.index)] is_valid_tile = valid_tile_info[0] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1927,8 +1927,8 @@ def kernel( valid_tile_info[0] = sInfo[(3, tile_info_consumer_state.index)] is_valid_tile = valid_tile_info[0] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1968,8 +1968,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2051,8 +2051,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2152,8 +2152,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2368,8 +2368,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2480,8 +2480,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2811,8 +2811,8 @@ def kernel( ) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.epilog_sync_barrier.arrive_and_wait() # @@ -2845,8 +2845,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() diff --git a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py index ce4fb6269b..e07fab4eb6 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py @@ -1380,8 +1380,8 @@ def kernel( sInfo[(4, tile_info_producer_state.index)] = mn_limit # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1416,8 +1416,8 @@ def kernel( sInfo[(4, tile_info_producer_state.index)] = mn_limit # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1438,8 +1438,8 @@ def kernel( sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0) sInfo[(4, tile_info_producer_state.index)] = cutlass.Int32(0) cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() tile_info_pipeline.producer_commit(tile_info_producer_state) @@ -1467,8 +1467,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1573,8 +1573,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1659,8 +1659,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1818,8 +1818,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1886,8 +1886,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2023,8 +2023,8 @@ def kernel( if cutlass.const_expr(self.use_blkred): cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) # # Async arrive accumulator buffer empty @@ -2037,8 +2037,8 @@ def kernel( if cutlass.const_expr(self.use_blkred): cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) if is_valid_row: coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[1] @@ -2073,8 +2073,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() diff --git a/flashinfer/fused_moe/cute_dsl/blackwell/utils.py b/flashinfer/fused_moe/cute_dsl/blackwell/utils.py index b1c5349de1..4bc3b960c4 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell/utils.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell/utils.py @@ -208,7 +208,6 @@ def fmin( ) -> cutlass.Float32: return cutlass.Float32( nvvm.fmin( - T.f32(), cutlass.Float32(a).ir_value(loc=loc, ip=ip), cutlass.Float32(b).ir_value(loc=loc, ip=ip), nan=nan, diff --git a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py index 1475850e46..5710f97fac 100644 --- a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py +++ b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py @@ -1469,8 +1469,8 @@ def kernel( ) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) epilog_threads = 32 * len(self.epilog_warp_id) cute.arch.barrier( diff --git a/requirements.txt b/requirements.txt index 7dd93c67f7..7eb97a4ab9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ einops ninja numpy nvidia-cudnn-frontend>=1.13.0 -nvidia-cutlass-dsl>=4.3.4 +nvidia-cutlass-dsl>=4.4.2 nvidia-ml-py packaging>=24.2 requests diff --git a/scripts/setup_test_env.sh b/scripts/setup_test_env.sh index 83480cbd6a..5cd61330f1 100755 --- a/scripts/setup_test_env.sh +++ b/scripts/setup_test_env.sh @@ -23,3 +23,22 @@ if [ -n "${TVM_FFI_REF:-}" ]; then echo "TVM-FFI override complete." echo "" fi + +# Override nvidia-cutlass-dsl if specified +if [ -n "${CUTLASS_DSL_VERSION:-}" ]; then + # Detect CUDA major version: only CUDA 13+ needs [cu13] extra + CUDA_MAJOR=$(python -c "import torch; print(torch.version.cuda.split('.')[0])" 2>/dev/null || echo "12") + if [ "$CUDA_MAJOR" = "13" ]; then + CUTLASS_DSL_PKG="nvidia-cutlass-dsl[cu13]==${CUTLASS_DSL_VERSION}" + else + CUTLASS_DSL_PKG="nvidia-cutlass-dsl==${CUTLASS_DSL_VERSION}" + fi + echo "========================================" + echo "Overriding nvidia-cutlass-dsl with: ${CUTLASS_DSL_PKG}" + echo "========================================" + # Clean uninstall old packages first (recommended by NVIDIA docs) + pip uninstall nvidia-cutlass-dsl nvidia-cutlass-dsl-libs-base nvidia-cutlass-dsl-libs-cu12 nvidia-cutlass-dsl-libs-cu13 -y 2>/dev/null || true + pip install "${CUTLASS_DSL_PKG}" + echo "nvidia-cutlass-dsl override complete." + echo "" +fi