Commit 1ab4f41

Add: DPX instructions on Hopper
1 parent 80e1d83 commit 1ab4f41

2 files changed: +107 −27 lines changed

less_slow.cpp

Lines changed: 17 additions & 0 deletions
@@ -2233,6 +2233,23 @@ BENCHMARK_CAPTURE(
     64, 256, 8, 90, 128, tensor_core_scale_t::warpgroup_k)
     ->MinTime(10);
 
+extern __global__ void tops_u16u32_sm90dpx_16x16x32_loop128_floyd_warshall_cuda_kernel();
+extern __global__ void tops_i16i32_sm90dpx_16x16x32_loop128_needleman_wunsch_cuda_kernel();
+extern __global__ void tops_i32i32_sm90dpx_16x16x16_loop128_smith_waterman_cuda_kernel();
+
+BENCHMARK_CAPTURE(                                                                                          //
+    theoretic_tops_cuda, u16u32_sm90dpx, tops_u16u32_sm90dpx_16x16x32_loop128_floyd_warshall_cuda_kernel,   //
+    16, 16, 32, 90, 128, tensor_core_scale_t::single_k)
+    ->MinTime(10);
+BENCHMARK_CAPTURE(                                                                                          //
+    theoretic_tops_cuda, i16i32_sm90dpx, tops_i16i32_sm90dpx_16x16x32_loop128_needleman_wunsch_cuda_kernel, //
+    16, 16, 32, 90, 128, tensor_core_scale_t::single_k)
+    ->MinTime(10);
+BENCHMARK_CAPTURE(                                                                                          //
+    theoretic_tops_cuda, i32i32_sm90dpx, tops_i32i32_sm90dpx_16x16x16_loop128_smith_waterman_cuda_kernel,   //
+    16, 16, 16, 90, 128, tensor_core_scale_t::single_k)
+    ->MinTime(10);
+
 #include <filesystem> // `std::filesystem::absolute` to locate PTX IR file
 
 /**

less_slow.cu

Lines changed: 90 additions & 27 deletions
@@ -75,26 +75,8 @@
 #include <cuda_fp8.h> // `__nv_fp8*` types
 #endif
 
-template <typename scalar_type_, std::size_t side_>
-struct small_square_matrix {
-    scalar_type_ scalars[side_][side_];
-};
-
-/**
- * @brief A CUDA kernel that computes the product of two small square matrices.
- *        Doesn't use any block/warp-level communication and optimizations.
- */
-template <typename scalar_type_, std::size_t side_>
-small_square_matrix<scalar_type_, side_> small_matmul_kernel_cuda( //
-    small_square_matrix<scalar_type_, side_> const &a,             //
-    small_square_matrix<scalar_type_, side_> const &b) {
-
-    small_square_matrix<scalar_type_, side_> c;
-    for (std::size_t i = 0; i != side_; ++i)
-        for (std::size_t j = 0; j != side_; ++j)
-            for (std::size_t k = 0; k != side_; ++k) c.scalars[i][j] += a.scalars[i][k] * b.scalars[k][j];
-    return c;
-}
+#pragma region - Basics
+#pragma region Parallelism and Computational Complexity
 
 #include <thrust/sort.h> // `thrust::sort`
 
@@ -133,7 +115,11 @@ void reverse_and_sort_with_cub(std::uint32_t *device_pointer, std::size_t array_
     );
 }
 
-#pragma region Numerics
+#pragma endregion // Parallelism and Computational Complexity
+#pragma endregion // Basics
+
+#pragma region - Numerics
+#pragma region Scalar Operations
 
 /**
  * @brief On-device @b Fused-Multiply-Add operator, that for most numeric
@@ -178,7 +164,6 @@ __device__ void tops_fma_cuda_kernel() {
 }
 
 __global__ void tops_f32f32_sm60fma_16x16x16_loop128_cuda_kernel() { tops_fma_cuda_kernel<float, float, 16, 128>(); }
-
 __global__ void tops_f64f64_sm60fma_16x16x16_loop128_cuda_kernel() { tops_fma_cuda_kernel<double, double, 16, 128>(); }
 
 __global__ void tops_f16f16_sm70fma_16x16x16_loop128_cuda_kernel() {
@@ -217,6 +202,83 @@ __global__ void tops_i64i64_sm60fma_16x16x16_loop128_cuda_kernel() {
     tops_fma_cuda_kernel<std::int64_t, std::int64_t, 16, 128>();
 }
 
+/**
+ * Given the growing demand for such workloads, new Dynamic Programming
+ * eXtensions @b (DPX) have been added on Hopper for various combinations
+ * of { addition, min, max, ReLU } on 8-bit and 16-bit integer inputs.
+ *
+ * Thus, the @b Floyd-Warshall All-Pairs Shortest Path @b (APSP) algorithm
+ * can be reformulated as @b Tropical-semiring matrix multiplications in
+ * Algebraic Graph Theory.
+ *
+ * It works for both positive and negative edge weights, but not in the
+ * presence of negative cycles, so most people will realistically use the
+ * 16-bit unsigned edge weights with 32-bit unsigned accumulators.
+ *
+ * @see "Floyd–Warshall algorithm" on Wikipedia: https://en.wikipedia.org/wiki/Floyd%E2%80%93Warshall_algorithm
+ * @see "Boosting Dynamic Programming Performance Using NVIDIA Hopper GPU DPX
+ *      Instructions" by Nvidia:
+ *      https://developer.nvidia.com/blog/boosting-dynamic-programming-performance-using-nvidia-hopper-gpu-dpx-instructions/
+ */
+__global__ void tops_u16u32_sm90dpx_16x16x32_loop128_floyd_warshall_cuda_kernel() {
+    // Each pair of unsigned 16-bit inputs will be represented by a single `uint`.
+#if (__CUDA_ARCH__ >= 900)
+    struct floyd_warshall_semiring_t {
+        inline __device__ uint operator()(uint a, uint b, uint c) const noexcept { return __viaddmin_u16x2(a, b, c); }
+    };
+    tops_fma_cuda_kernel<uint, uint, 16, 128, floyd_warshall_semiring_t>();
+#endif
+}
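For intuition, here is an editorial scalar sketch, not part of the commit: `__viaddmin_u16x2(a, b, c)` computes `min(a + b, c)` independently on the two 16-bit lanes packed into each 32-bit word, which is exactly the Floyd-Warshall relaxation `d[i][j] = min(d[i][j], d[i][k] + d[k][j])` applied to two packed path lengths. The hardware instruction's overflow handling is elided here; the sketch assumes lane sums stay within 16 bits.

#include <algorithm> // `std::min`

// Editorial reference for the per-lane semantics of `__viaddmin_u16x2`:
inline unsigned int viaddmin_u16x2_reference(unsigned int a, unsigned int b, unsigned int c) {
    unsigned int lo = std::min((a & 0xFFFFu) + (b & 0xFFFFu), c & 0xFFFFu); // Lower lane: min(a + b, c)
    unsigned int hi = std::min((a >> 16) + (b >> 16), c >> 16);             // Upper lane: min(a + b, c)
    return (hi << 16) | lo; // Repack the two relaxed path lengths
}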
+
+/**
+ * Similarly, the @b Needleman-Wunsch algorithm in Bioinformatics is often
+ * used for @b global alignment of fairly short protein or DNA & RNA strings.
+ *
+ * @see "Needleman–Wunsch algorithm" on Wikipedia: https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm
+ */
+__global__ void tops_i16i32_sm90dpx_16x16x32_loop128_needleman_wunsch_cuda_kernel() {
+    // Each pair of signed 16-bit inputs will be represented by a single `uint`.
+#if (__CUDA_ARCH__ >= 900)
+    struct needleman_wunsch_semiring_t {
+        inline __device__ uint operator()(uint a, uint b, uint c) const noexcept { return __viaddmax_s16x2(a, b, c); }
+    };
+    tops_fma_cuda_kernel<uint, uint, 16, 128, needleman_wunsch_semiring_t>();
+#endif
+}
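Again as an editorial aside, not in the commit: `__viaddmax_s16x2(a, b, c)` computes a signed `max(a + b, c)` per 16-bit lane, which is the shape of the Needleman-Wunsch recurrence: keep the better of extending the alignment with a gap penalty or taking a substitution score. A scalar reference with explicit sign extension, assuming no overflow:

#include <algorithm> // `std::max`
#include <cstdint>   // `std::int16_t`, `std::uint16_t`

inline unsigned int viaddmax_s16x2_reference(unsigned int a, unsigned int b, unsigned int c) {
    auto lane = [](unsigned int x, int i) { return (int)(std::int16_t)(x >> (16 * i)); }; // Sign-extend lane `i`
    int lo = std::max(lane(a, 0) + lane(b, 0), lane(c, 0));
    int hi = std::max(lane(a, 1) + lane(b, 1), lane(c, 1));
    return ((unsigned int)(std::uint16_t)hi << 16) | (unsigned int)(std::uint16_t)lo;
}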
+
+/**
+ * Similarly, the @b Smith-Waterman algorithm in Bioinformatics is often
+ * used for @b local alignment of longer DNA & RNA strings. It also replaces
+ * multiplication with addition, and addition with maximum, but additionally
+ * applies the Rectified Linear Unit to cut off negative values.
+ *
+ * Assuming the strings can easily be over 64 KB long, we should use the
+ * larger 32-bit inputs for cost matrices.
+ *
+ * @see "Smith–Waterman algorithm" on Wikipedia: https://en.wikipedia.org/wiki/Smith%E2%80%93Waterman_algorithm
+ */
+__global__ void tops_i32i32_sm90dpx_16x16x16_loop128_smith_waterman_cuda_kernel() {
+#if (__CUDA_ARCH__ >= 900)
+    struct smith_waterman_operator_t {
+        inline __device__ int operator()(int a, int b, int c) const noexcept { return __viaddmax_s32_relu(a, b, c); }
+    };
+    tops_fma_cuda_kernel<int, int, 16, 128, smith_waterman_operator_t>();
+#endif
+}
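One more editorial aside, not in the commit: `__viaddmax_s32_relu(a, b, c)` evaluates `max(a + b, c, 0)` on full 32-bit integers. The trailing ReLU is precisely what distinguishes Smith-Waterman from Needleman-Wunsch, letting a local alignment restart whenever the running score would go negative. Assuming `a + b` doesn't overflow:

#include <algorithm> // `std::max`

inline int viaddmax_s32_relu_reference(int a, int b, int c) {
    return std::max(std::max(a + b, c), 0); // The local-alignment score never drops below zero
}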
+
+/**
+ * On H200, the following integer performance can be expected:
+ *
+ * - Naive FMA for `i32` and `i64` inputs: 2.3 P
+ * - Hopper DPX for Floyd-Warshall algorithm with `u16` and `u32`: 11 T
+ * - Hopper DPX for Needleman-Wunsch algorithm with `i16` and `i32`: 11 T
+ * - Hopper DPX for Smith-Waterman algorithm with `i32`: 27 T
+ */
+
+#pragma endregion // Scalar Operations
+
+#pragma region Tiled Matrix Multiplications
+
 /**
  * Starting with Nvidia Volta GPUs, specialized "Tensor Cores" @b (TC) are
  * added for faster matrix multiplications. These Tensor Cores are much faster
@@ -397,7 +459,7 @@ __global__ void tops_b1i32and_sm80wmma_8x8x128_loop128_cuda_kernel() {
 #endif
 }
 
-#pragma endregion
+#pragma endregion // Tiled Matrix Multiplications
 
 /**
  * MMA is not the only family of tensor core instructions:
@@ -479,9 +541,8 @@ __global__ void tops_b1i32and_sm80wmma_8x8x128_loop128_cuda_kernel() {
  *      .m64n168k8, .m64n176k8, .m64n184k8, .m64n192k8,
  *      .m64n200k8, .m64n208k8, .m64n216k8, .m64n224k8,
  *      .m64n232k8, .m64n240k8, .m64n248k8, .m64n256k8
-
  */
-#pragma region Hopper
+#pragma region Tiled Matrix Multiplications Across Warps
 
 /**
  * Ideally, both matrices A and B should be in shared memory. Both are
@@ -738,11 +799,13 @@ __global__ void tops_tf32f32_sm90wgmma_64x256x8_loop128_cuda_kernel() {
     if (threadIdx.x == 2147483647) *(std::uint32_t *)nullptr = c_registers[0];
 }
 
-#pragma endregion
+#pragma endregion // Tiled Matrix Multiplications Across Warps
 
 /**
  *
 * @see "Blackwell Cluster Launch Control" in CUTLASS docs:
 *      https://github.com/NVIDIA/cutlass/blob/main/media/docs/blackwell_cluster_launch_control.md
  *
- */
+ */
+
+#pragma endregion // Numerics
