Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
FetchContent_MakeAvailable(cutlass)

# CUTLASS v4.x ships Sm120BlockwiseScaleConfig, which the SM120 kernels need.
# Only the source tree is downloaded: FetchContent_MakeAvailable is
# intentionally NOT called for this checkout, so none of its CMake targets are
# created and they cannot clash with the target names of the CUTLASS that was
# already made available.
FetchContent_Declare(
  cutlass_v4
  GIT_REPOSITORY https://github.com/nvidia/cutlass.git
  GIT_TAG v4.0.0
  GIT_PROGRESS TRUE
  GIT_SHALLOW TRUE
)

# Populate at most once per configure run.
FetchContent_GetProperties(cutlass_v4)
if(NOT cutlass_v4_POPULATED)
  FetchContent_Populate(cutlass_v4)
endif()

# Header locations inside the v4 checkout, consumed by the SM120 sources.
set(CUTLASS_V4_INCLUDE_DIR "${cutlass_v4_SOURCE_DIR}/include")
set(CUTLASS_V4_TOOLS_UTIL_INCLUDE_DIR "${cutlass_v4_SOURCE_DIR}/tools/util/include")

list(APPEND VLLM_EXT_SRC
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
Expand Down Expand Up @@ -421,8 +437,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()

# The cutlass_scaled_mm kernels for Blackwell (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.8 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
# CUDA 12.8 or later (SM100/SM101 only, SM120 handled separately with CUTLASS v4.x)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
Expand All @@ -449,6 +465,41 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()

# SM120 scaled_mm kernels are built against CUTLASS v4.x (which provides
# Sm120BlockwiseScaleConfig) and are only enabled when the CUDA compiler is
# 12.8 or newer and "12.0a" is among the requested CUDA architectures.
cuda_archs_loose_intersection(SCALED_MM_SM120_ARCHS "12.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_SM120_ARCHS)
  set(SM120_SRCS
    "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
    "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
    "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu"
  )

  # Attach the CUTLASS v4.x header directories to these sources only.
  # CMAKE_CURRENT_SOURCE_DIR (rather than CMAKE_SOURCE_DIR) keeps the csrc
  # path pointing at this project's tree even when it is added as a
  # subdirectory of a larger build.
  set_source_files_properties(
    ${SM120_SRCS}
    PROPERTIES
    INCLUDE_DIRECTORIES "${CUTLASS_V4_INCLUDE_DIR};${CUTLASS_V4_TOOLS_UTIL_INCLUDE_DIR};${CMAKE_CURRENT_SOURCE_DIR}/csrc"
  )

  set_gencode_flags_for_srcs(
    SRCS "${SM120_SRCS}"
    CUDA_ARCHS "${SCALED_MM_SM120_ARCHS}")
  list(APPEND VLLM_EXT_SRC "${SM120_SRCS}")
  list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
  # Record these archs as handled by the 3.x kernels.
  list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_SM120_ARCHS}")
  message(STATUS "Building scaled_mm_c3x_sm120 with CUTLASS v4.x for archs: ${SCALED_MM_SM120_ARCHS}")
elseif(SCALED_MM_SM120_ARCHS)
  # An SM120 arch was requested, so the only reason to land here is that the
  # CUDA compiler is older than 12.8.
  message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
                 "not >= 12.8, we recommend upgrading to CUDA 12.8 or "
                 "later if you intend on running FP8 quantized models on "
                 "SM120 (RTX 50xx).")
else()
  message(STATUS "Not building scaled_mm_c3x_sm120 as no compatible archs found "
                 "in CUDA target architectures")
endif()

#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
Expand Down
4 changes: 3 additions & 1 deletion Dockerfile.quick
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM ghcr.io/gonka-ai/vllm:v0.9.1
FROM ghcr.io/gonka-ai/vllm:v0.9.1-poc-v2-post3-blackwell-sm120-fp8-alpha1

ENV VLLM_USE_V1=0

Expand All @@ -7,6 +7,8 @@ COPY ./vllm /tmp/vllm-src

RUN python3 -m pip uninstall nvidia-nccl-cu12 -y && python3 -m pip install nvidia-nccl-cu12==2.26.2.post1

# Note: We use PyTorch MoE fallback instead of upgrading Triton (which breaks inductor)

# Merge only .py files into existing vllm package (preserves .so files)
RUN cd /tmp/vllm-src \
&& find . -name "*.py" | tar -cf - -T - | tar -xf - -C /usr/local/lib/python3.12/dist-packages/vllm/ \
Expand Down
10 changes: 10 additions & 0 deletions csrc/cutlass_extensions/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,13 @@ struct enable_sm100_only : Kernel {
#endif
}
};

// Kernel wrapper that compiles the wrapped kernel's body only when
// __CUDA_ARCH__ is defined and equal to 1200. In every other compilation pass
// operator() has an empty body.
template <typename Kernel>
struct enable_sm120_only : Kernel {
  template <typename... Ts>
  CUTLASS_DEVICE void operator()(Ts&&... params) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1200)
    // Perfect-forward everything to the wrapped kernel's call operator.
    Kernel::operator()(std::forward<Ts>(params)...);
#endif
  }
};
63 changes: 63 additions & 0 deletions csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,67 @@ struct cutlass_3x_gemm_sm100 {
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
};

// Type bundle describing a CUTLASS 3.x GEMM built for cutlass::arch::Sm120.
//
// Template parameters:
//   ElementAB_       element type shared by the A and B operands
//   ElementD_        element type of the output D
//   Epilogue_        epilogue template, instantiated with
//                    <ElementAcc, ElementD, TileShape>; must expose EVTCompute
//   TileShape        tile shape handed to both collective builders
//   ClusterShape     cluster shape handed to both collective builders
//   KernelSchedule   mainloop schedule tag
//   EpilogueSchedule epilogue schedule tag
//
// The resulting kernel type is exposed as the nested `GemmKernel`.
template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm_sm120 {
  // ---- Element types -------------------------------------------------------
  using ElementAB = ElementAB_;
  using ElementC = void;  // no C source operand
  using ElementD = ElementD_;

  // ---- Layouts -------------------------------------------------------------
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::RowMajor;
  using LayoutD = cutlass::layout::RowMajor;

  // ---- Alignments, in elements: 128 bits divided by the element width ------
  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentC =
      128 / cutlass::sizeof_bits<ElementD_>::value;
  static constexpr int AlignmentD = AlignmentC;

  // Accumulator type handed to the epilogue template: int32_t for int8_t
  // inputs, float otherwise.
  // NOTE(review): the collective builders below use ElementAccumulator
  // (always float), not ElementAcc -- confirm this is intended for int8.
  using ElementAcc =
      std::conditional_t<std::is_same_v<ElementAB, int8_t>, int32_t, float>;
  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;

  // MMA type
  using ElementAccumulator = float;

  // Epilogue types
  using ElementBias = cutlass::half_t;
  using ElementCompute = float;
  using ElementAux = ElementD;
  using LayoutAux = LayoutD;
  using ElementAmax = float;

  using EVTCompute = typename Epilogue::EVTCompute;

  // Epilogue collective, fused with the EVT supplied by Epilogue_.
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
          TileShape, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAccumulator, ElementCompute,
          ElementC, LayoutC, AlignmentC,
          ElementD, LayoutD, AlignmentD,
          EpilogueSchedule, EVTCompute>::CollectiveOp;

  // Mainloop collective. The stage count is chosen automatically after
  // carving out sizeof(CollectiveEpilogue::SharedStorage) bytes.
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
          ElementAB, LayoutA, AlignmentA,
          ElementAB, LayoutB, AlignmentB,
          ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  // Universal GEMM kernel, restricted to SM120 via enable_sm120_only.
  using KernelType = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      void>>;

  struct GemmKernel : public KernelType {};
};

} // namespace vllm
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

// Entry point for the SM120 blockwise-scaled FP8 GEMM. Picks the CUTLASS
// output element type from out.dtype(): BFloat16 maps to cutlass::bfloat16_t;
// anything else must be Float16 (enforced by TORCH_CHECK) and maps to
// cutlass::half_t. All tensors are forwarded unchanged to the dispatcher.
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           torch::Tensor const& a_scales,
                                           torch::Tensor const& b_scales) {
  // Non-bf16 path first: validate that the output is fp16, then dispatch.
  if (!(out.dtype() == torch::kBFloat16)) {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
        out, a, b, a_scales, b_scales);
    return;
  }

  cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
      out, a, b, a_scales, b_scales);
}

} // namespace vllm
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
#pragma once

#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"

#include "cutlass_gemm_caller.cuh"

namespace vllm {

using namespace cute;

// clang-format off
// Type bundle for a blockwise-scaled FP8 (float_e4m3_t) GEMM built with the
// CUTLASS collective builders for cutlass::arch::Sm120.
//
// Template parameters:
//   OutType               element type of the output D
//   ScaleGranularityM/N/K granularities forwarded to Sm120BlockwiseScaleConfig
//   MmaTileShape          tile shape passed to both collective builders
//   ClusterShape          cluster shape passed to both collective builders
//   EpilogueScheduler     epilogue schedule tag
//   MainloopScheduler     mainloop schedule tag
//
// Exposes the final kernel type as the nested `GemmKernel`, plus the scale
// layout types (LayoutSFA / LayoutSFB) and ScaleConfig used by the caller.
template <class OutType, int ScaleGranularityM,
int ScaleGranularityN, int ScaleGranularityK,
class MmaTileShape, class ClusterShape,
class EpilogueScheduler, class MainloopScheduler>
struct cutlass_3x_gemm_fp8_blockwise_sm120 {
// A and B share the same FP8 element type.
using ElementAB = cutlass::float_e4m3_t;

// Operand A: row-major, alignment = 128 bits expressed in elements.
using ElementA = ElementAB;
using LayoutA = cutlass::layout::RowMajor;
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

// Operand B: alignment = 128 bits expressed in elements.
using ElementB = ElementAB;
// ColumnMajor is used for B to match the CUTLASS convention.
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

// Output D: row-major, element type chosen by the caller.
using ElementD = OutType;
using LayoutD = cutlass::layout::RowMajor;
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

// No C source operand (void); its layout and alignment mirror D.
using ElementC = void; // TODO: support bias
using LayoutC = LayoutD;
using LayoutC_Transpose = LayoutD_Transpose;
static constexpr int AlignmentC = AlignmentD;

// Accumulation, epilogue math and block scales are all float.
using ElementAccumulator = float;
using ElementCompute = float;
using ElementBlockScale = float;

// Using CUTLASS v4.x which has Sm120BlockwiseScaleConfig
// NOTE(review): the two Major arguments presumably select the majorness of
// the A and B scale-factor tensors respectively -- confirm against CUTLASS.
using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::UMMA::Major::MN, cute::UMMA::Major::K>;

// layout_SFA and layout_SFB cannot be swapped since they are deduced.
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());

using ArchTag = cutlass::arch::Sm120;
using OperatorClass = cutlass::arch::OpClassTensorOp;

// Epilogue: plain LinearCombination with round-to-nearest.
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementScalar = float;
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
MmaTileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
AlignmentC,
ElementD,
LayoutD,
AlignmentD,
EpilogueScheduler,
DefaultOperation
>::CollectiveOp;

// NOTE(review): StageCountType is declared but the mainloop builder below
// uses StageCountAutoCarveout directly instead.
using StageCountType = cutlass::gemm::collective::StageCountAuto;
// Mainloop: each operand layout is paired with its scale-factor layout in a
// cute::tuple; stage count is automatic after carving out the size of the
// epilogue's SharedStorage.
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp;

// Universal GEMM kernel wrapped in enable_sm120_only.
using KernelType = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;

struct GemmKernel : public KernelType {};
};

// Fills in the mainloop and epilogue arguments for the blockwise-scaled FP8
// GEMM described by `Gemm` and launches it through c3x::cutlass_gemm_caller.
//
//   out       output tensor; its data pointer is used as both the C and the
//             D pointer of the epilogue
//   a, b      FP8 operands; m = a.size(0), k = a.size(1), n = b.size(1)
//   a_scales  block scales for A, read as ElementBlockScale
//   b_scales  block scales for B, read as ElementBlockScale
//
// No dtype or shape validation is done here.
// NOTE(review): presumably callers validate dtypes, shapes and contiguity
// before reaching this function -- confirm.
template <typename Gemm>
void cutlass_gemm_caller_blockwise_sm120(torch::Tensor& out, torch::Tensor const& a,
                                         torch::Tensor const& b,
                                         torch::Tensor const& a_scales,
                                         torch::Tensor const& b_scales) {
  using GemmKernel = typename Gemm::GemmKernel;
  using StrideA = typename GemmKernel::StrideA;
  using StrideB = typename GemmKernel::StrideB;
  using StrideD = typename GemmKernel::StrideD;
  using StrideC = typename GemmKernel::StrideC;
  using LayoutSFA = typename Gemm::LayoutSFA;
  using LayoutSFB = typename Gemm::LayoutSFB;
  using ScaleConfig = typename Gemm::ScaleConfig;

  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;
  using ElementBlockScale = typename Gemm::ElementBlockScale;

  // Problem extents, narrowed to 32-bit as the kernel's problem shape expects.
  int32_t m = a.size(0);
  int32_t k = a.size(1);
  int32_t n = b.size(1);
  auto prob_shape = cute::make_shape(m, n, k, 1);

  // Packed strides for a single batch.
  StrideA a_stride =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB b_stride =
      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC out_stride =
      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));

  // Scale-factor layouts derived from the full (m, n, k, batch) shape.
  LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(prob_shape);
  LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(prob_shape);

  typename GemmKernel::MainloopArguments mainloop_args{};
  mainloop_args.ptr_A = static_cast<ElementAB const*>(a.data_ptr());
  mainloop_args.dA = a_stride;
  mainloop_args.ptr_B = static_cast<ElementAB const*>(b.data_ptr());
  mainloop_args.dB = b_stride;
  mainloop_args.ptr_SFA =
      static_cast<ElementBlockScale const*>(a_scales.data_ptr());
  mainloop_args.layout_SFA = layout_SFA;
  mainloop_args.ptr_SFB =
      static_cast<ElementBlockScale const*>(b_scales.data_ptr());
  mainloop_args.layout_SFB = layout_SFB;

  // The same output pointer and stride are passed for both C and D.
  auto out_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      {}, out_ptr, out_stride, out_ptr, out_stride};

  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
                                       epilogue_args);
}

// Selects the kernel configuration for the SM120 blockwise FP8 GEMM and runs
// it. A single fixed configuration is used for every problem size: scale
// granularity (1, 128, 128), a 128x128x128 MMA tile, a 1x1x1 cluster, and
// automatic epilogue / mainloop schedules.
template <typename OutType>
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
                                               torch::Tensor const& a,
                                               torch::Tensor const& b,
                                               torch::Tensor const& a_scales,
                                               torch::Tensor const& b_scales) {
  // TODO: better heuristics
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_1, _1, _1>;
  using Gemm = cutlass_3x_gemm_fp8_blockwise_sm120<
      OutType, 1, 128, 128, TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueScheduleAuto,
      cutlass::gemm::collective::KernelScheduleAuto>;

  cutlass_gemm_caller_blockwise_sm120<Gemm>(out, a, b, a_scales, b_scales);
}

} // namespace vllm
12 changes: 12 additions & 0 deletions csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,16 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);

// Declaration only; the definition is not in this header.
// Takes the output tensor, operands a and b, their scales, and an optional
// bias tensor.
// NOTE(review): presumably the non-blockwise scaled FP8 GEMM for SM120 --
// confirm against the definition.
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);

// Blockwise-scaled FP8 GEMM entry point for SM120. The definition dispatches
// on out.dtype(): BFloat16 is handled directly, any other dtype must be
// Float16 (checked with TORCH_CHECK). Takes no bias argument.
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
} // namespace vllm
Loading
Loading