Skip to content

Commit 1401a3c

Browse files
authored
[None][feat] Add FP8 rowwise GEMMs for B200 (#8332)
Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
1 parent 9c4432f commit 1401a3c

File tree

4 files changed

+395
-3
lines changed

4 files changed

+395
-3
lines changed

cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ add_cuda_architectures(fpA_intB_gemm_src 89)
 add_instantiations(fpA_intB_gemm_src ${INSTANTIATION_GENERATION_DIR}/gemm)

 add_library(fb_gemm_src STATIC ${FBGEMM_SRC_CU} ${FBGEMM_CU_INSTANTIATIONS})
-set_cuda_architectures(fb_gemm_src 89 90 120f)
+set_cuda_architectures(fb_gemm_src 89 90 100f 120f)
 # add_instantiations(fb_gemm_src
 #                    ${INSTANTIATION_GENERATION_DIR}/fp8_rowwise_gemm)

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
20+
#pragma GCC diagnostic push
21+
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
22+
#endif // __GNUC__
23+
24+
#include "cute/tensor.hpp"
25+
#include "cutlass/conv/convolution.h"
26+
// Order matters here, packed_stride.hpp is missing cute and convolution includes
27+
#include "cutlass/util/packed_stride.hpp"
28+
29+
#include "cutlass/epilogue/collective/default_epilogue.hpp"
30+
#include "cutlass/epilogue/thread/linear_combination.h"
31+
#include "cutlass/gemm/collective/collective_builder.hpp"
32+
#include "cutlass/gemm/dispatch_policy.hpp"
33+
34+
#include "cutlass/epilogue/thread/activation.h"
35+
#include "cutlass/gemm/kernel/gemm_universal.hpp"
36+
37+
#include "cutlass/epilogue/collective/collective_builder.hpp"
38+
#include "cutlass/gemm/device/gemm_universal_adapter.h"
39+
40+
#include "tensorrt_llm/kernels/archCondition.h"
41+
42+
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
43+
#pragma GCC diagnostic pop
44+
#endif // __GNUC__
45+
46+
namespace tensorrt_llm::kernels::cutlass_kernels
{
using namespace cute;

/// Compile-time definition of an FP8 (e4m3) rowwise-scaled GEMM targeting SM100
/// (Blackwell / B200). This struct contains only type aliases: the CUTLASS
/// collective builders assemble the mainloop and epilogue at compile time, and
/// `Gemm` is the runnable device-level adapter callers instantiate.
///
/// The epilogue computes D = XScale * (WScale * Acc), i.e. the FP8 accumulator
/// is scaled by a per-row factor (XScale, a column-broadcast vector) and a
/// per-column factor (WScale, a row-broadcast vector) before being cast to
/// OutElementType.
///
/// Template parameters:
///   ElementType          - input element type for A and B; must be cutlass::float_e4m3_t.
///   OutElementType       - element type of the output matrix D.
///   AccumElementType     - accumulator element type.
///   CTAShape             - threadblock-level tile shape.
///   ClusterShape         - thread-block cluster shape.
///   MainloopScheduleType - NOTE(review): accepted but not referenced below; the
///                          mainloop schedule is hard-coded to KernelScheduleAuto —
///                          presumably kept for signature parity with the SM90
///                          variant; confirm.
///   EpilogueScheduleType - epilogue schedule passed to the epilogue builder.
///   TileSchedulerType    - tile scheduler tag (void = default).
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CTAShape,
    typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
    typename TileSchedulerType = void>
struct DeviceGemmFp8RowwiseSm100
{
    static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");

    // A matrix configuration
    using ElementA = ElementType;              // Element type for A matrix operand
    using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
    static constexpr int AlignmentA
        = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A
                                                       // matrix in units of elements (up to 16 bytes)

    // B matrix configuration
    using ElementB = ElementType;                 // Element type for B matrix operand
    using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
    static constexpr int AlignmentB
        = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B
                                                       // matrix in units of elements (up to 16 bytes)

    // C/D matrix configuration. ElementC is void: there is no source (C) matrix,
    // the epilogue visitor tree below produces D directly from the accumulator.
    using ElementC = void;                     // Element type for C matrix operands
    using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands
    static constexpr int AlignmentC
        = 128 / cutlass::sizeof_bits<OutElementType>::value; // Memory access granularity/alignment of C matrices in
                                                             // units of elements (up to 16 bytes)

    // Output matrix configuration
    using ElementOutput = OutElementType;           // Element type for output matrix operands
    using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands
    static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;

    // Auxiliary matrix configuration and other fusion types
    using ElementBias = float;

    // Multiply-accumulate blocking/pipelining details
    using ElementAccumulator = AccumElementType; // Element type for internal accumulation
    using ElementCompute = float;                // Element type for compute
    using ElementComputeEpilogue = float;
    using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
    using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
    using TileShape = CTAShape;                           // Threadblock-level tile size
    using TileScheduler = TileSchedulerType;

    // NOTE(review): the Sm90-prefixed fusion/visitor nodes below are used under an
    // Sm100 ArchTag — presumably CUTLASS reuses these EVT node types on newer
    // architectures; confirm against the CUTLASS version in use.
    using Multiply = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementComputeEpilogue,
        ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
    using Add = cutlass::epilogue::fusion::Sm90Compute<cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue,
        cutlass::FloatRoundStyle::round_to_nearest>;
    using Cast = cutlass::epilogue::fusion::Sm90Compute<cutlass::epilogue::thread::Identity, OutElementType,
        ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;

    // Implement rowwise scaling epilogue.
    // XScale: per-row scale, broadcast across columns (stride <1,0,0>).
    using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue,
        ElementComputeEpilogue, cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

    // WScale: per-column scale, broadcast across rows (stride <0,1,0>).
    using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue,
        ElementComputeEpilogue, cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

    // Optional per-column bias vector (float), broadcast across rows.
    using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementBias, ElementBias,
        cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

    // Fetches the raw accumulator tile as the leaf of the visitor tree.
    using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

    // First stage: WScale * Acc (element-wise, float).
    using Compute0 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies,
        ElementComputeEpilogue, // First stage output type.
        ElementComputeEpilogue, // First stage input types.
        cutlass::FloatRoundStyle::round_to_nearest>;

    using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;

    // Second stage: XScale * (WScale * Acc), converted to the output type.
    using Compute1 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementOutput,
        ElementComputeEpilogue, // Second stage input types.
        cutlass::FloatRoundStyle::round_to_nearest>;

    using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;

    // Optional final stage: add the bias vector to the scaled result.
    using ComputeBias = cutlass::epilogue::fusion::Sm90Compute<cutlass::plus,
        ElementOutput, // Final (optional) stage output type.
        ElementBias,   // Final stage input types.
        cutlass::FloatRoundStyle::round_to_nearest>;

    using EVTComputeBias = cutlass::epilogue::fusion::Sm90EVT<ComputeBias, Bias, EVTCompute1>;

    // NOTE(review): EVTComputeBias is defined but not selected here — the active
    // epilogue is the bias-free EVTCompute1. Presumably the bias path is kept for
    // future use; confirm before removing.
    using EpilogueEVT = EVTCompute1;

    using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<ArchTag, OperatorClass,
        TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
        ElementComputeEpilogue, ElementC, LayoutC, AlignmentC, ElementOutput, LayoutOutput, AlignmentOutput,
        EpilogueScheduleType, EpilogueEVT>::CollectiveOp;

    // NOTE(review): hard-coded to auto scheduling; the MainloopScheduleType
    // template parameter is not used here.
    using MainLoopSchedule = cutlass::gemm::collective::KernelScheduleAuto;

    // Mainloop stage count is auto-derived, carving out the shared memory the
    // epilogue's SharedStorage needs.
    using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<ArchTag, OperatorClass, ElementA,
        LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
        cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
            sizeof(typename CollectiveEpilogue::SharedStorage))>,
        MainLoopSchedule>::CollectiveOp;

    // Device-side guard: runs the wrapped kernel only when the compiled-for
    // architecture matches SM100; otherwise thread 0 prints a diagnostic and the
    // kernel traps instead of executing.
    template <typename Base>
    struct Sm100Only : Base
    {
        using typename Base::Params;

        CUTLASS_DEVICE
        void operator()(Params const& params, char* smem_buf)
        {
            if constexpr (tensorrt_llm::kernels::arch::is_match_v<100>)
            {
                this->Base::operator()(params, smem_buf);
            }
            else
            {
                if (cute::thread0())
                {
                    printf("%s : This kernel shall only run on SM100 devices.\n", __PRETTY_FUNCTION__);
                    __trap();
                }
            }
        }
    };

    using GemmKernel
        = Sm100Only<typename cutlass::gemm::kernel::GemmUniversal<cute::Shape<int, int, int, int>, // Indicates
                                                                                                   // ProblemShape
            CollectiveMainloop, CollectiveEpilogue, TileScheduler>>;

    // Host-side handle callers use to configure and launch the kernel.
    using Gemm = typename cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};

} // namespace tensorrt_llm::kernels::cutlass_kernels

0 commit comments

Comments
 (0)