Commit 152e59a

Improve: Shrink PTX loops
1 parent d1909f9 commit 152e59a

5 files changed: +114 additions, -110 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ build_release/
 .DS_Store
 
 # Temporary binaries
+/tmp/
 less_slow_from_ptx.cubin
 less_slow_from_cu.cubin
 less_slow_from_cu.ptx

less_slow.cpp

Lines changed: 67 additions & 64 deletions
@@ -2080,53 +2080,53 @@ static void theoretic_tops_cuda( //
     state.counters["TOP"] = benchmark::Counter(tops_per_gpu * state.iterations(), benchmark::Counter::kIsRate);
 }
 
-extern __global__ void tops_f16f16_sm70tc_16x16x16_1024unroll_cuda_kernel();
-extern __global__ void tops_f16f32_sm70tc_16x16x16_1024unroll_cuda_kernel();
+extern __global__ void tops_f16f16_sm70tc_16x16x16_loop128_cuda_kernel();
+extern __global__ void tops_f16f32_sm70tc_16x16x16_loop128_cuda_kernel();
 
-extern __global__ void tops_u8i32_sm75tc_16x16x16_1024unroll_cuda_kernel();
-extern __global__ void tops_u4i32_sm75tc_8x8x32_1024unroll_cuda_kernel();
-extern __global__ void tops_b1i32xor_sm75tc_8x8x128_1024unroll_cuda_kernel();
+extern __global__ void tops_u8i32_sm75tc_16x16x16_loop128_cuda_kernel();
+extern __global__ void tops_u4i32_sm75tc_8x8x32_loop128_cuda_kernel();
+extern __global__ void tops_b1i32xor_sm75tc_8x8x128_loop128_cuda_kernel();
 
-extern __global__ void tops_bf16f32_sm80tc_16x16x16_1024unroll_cuda_kernel();
-extern __global__ void tops_tf32f32_sm80tc_16x16x8_1024unroll_cuda_kernel();
-extern __global__ void tops_f64f64_sm80tc_8x8x4_1024unroll_cuda_kernel();
-extern __global__ void tops_b1i32and_sm80tc_8x8x128_1024unroll_cuda_kernel();
+extern __global__ void tops_bf16f32_sm80tc_16x16x16_loop128_cuda_kernel();
+extern __global__ void tops_tf32f32_sm80tc_16x16x8_loop128_cuda_kernel();
+extern __global__ void tops_f64f64_sm80tc_8x8x4_loop128_cuda_kernel();
+extern __global__ void tops_b1i32and_sm80tc_8x8x128_loop128_cuda_kernel();
 
-BENCHMARK_CAPTURE( //
-    theoretic_tops_cuda, f16f16_sm70tc, tops_f16f16_sm70tc_16x16x16_1024unroll_cuda_kernel, //
-    16, 16, 16, 1024, 70)
+BENCHMARK_CAPTURE( //
+    theoretic_tops_cuda, f16f16_sm70tc, tops_f16f16_sm70tc_16x16x16_loop128_cuda_kernel, //
+    16, 16, 16, 128, 70)
     ->MinTime(10);
-BENCHMARK_CAPTURE( //
-    theoretic_tops_cuda, f16f32_sm70tc, tops_f16f32_sm70tc_16x16x16_1024unroll_cuda_kernel, //
-    16, 16, 16, 1024, 70)
+BENCHMARK_CAPTURE( //
+    theoretic_tops_cuda, f16f32_sm70tc, tops_f16f32_sm70tc_16x16x16_loop128_cuda_kernel, //
+    16, 16, 16, 128, 70)
     ->MinTime(10);
-BENCHMARK_CAPTURE( //
-    theoretic_tops_cuda, u8i32_sm75tc, tops_u8i32_sm75tc_16x16x16_1024unroll_cuda_kernel, //
-    16, 16, 16, 1024, 75)
+BENCHMARK_CAPTURE( //
+    theoretic_tops_cuda, u8i32_sm75tc, tops_u8i32_sm75tc_16x16x16_loop128_cuda_kernel, //
+    16, 16, 16, 128, 75)
     ->MinTime(10);
-BENCHMARK_CAPTURE( //
-    theoretic_tops_cuda, u4i32_sm75tc, tops_u4i32_sm75tc_8x8x32_1024unroll_cuda_kernel, //
-    8, 8, 32, 1024, 75)
+BENCHMARK_CAPTURE( //
+    theoretic_tops_cuda, u4i32_sm75tc, tops_u4i32_sm75tc_8x8x32_loop128_cuda_kernel, //
+    8, 8, 32, 128, 75)
     ->MinTime(10);
-BENCHMARK_CAPTURE( //
-    theoretic_tops_cuda, b1i32xor_sm75tc, tops_b1i32xor_sm75tc_8x8x128_1024unroll_cuda_kernel, //
-    8, 8, 128, 1024, 75)
+BENCHMARK_CAPTURE( //
+    theoretic_tops_cuda, b1i32xor_sm75tc, tops_b1i32xor_sm75tc_8x8x128_loop128_cuda_kernel, //
+    8, 8, 128, 128, 75)
     ->MinTime(10);
-BENCHMARK_CAPTURE( //
-    theoretic_tops_cuda, bf16f32_sm80tc, tops_bf16f32_sm80tc_16x16x16_1024unroll_cuda_kernel, //
-    16, 16, 16, 1024, 80)
+BENCHMARK_CAPTURE( //
+    theoretic_tops_cuda, bf16f32_sm80tc, tops_bf16f32_sm80tc_16x16x16_loop128_cuda_kernel, //
+    16, 16, 16, 128, 80)
     ->MinTime(10);
-BENCHMARK_CAPTURE( //
-    theoretic_tops_cuda, tf32f32_sm80tc, tops_tf32f32_sm80tc_16x16x8_1024unroll_cuda_kernel, //
-    16, 16, 8, 1024, 80)
+BENCHMARK_CAPTURE( //
+    theoretic_tops_cuda, tf32f32_sm80tc, tops_tf32f32_sm80tc_16x16x8_loop128_cuda_kernel, //
+    16, 16, 8, 128, 80)
     ->MinTime(10);
-BENCHMARK_CAPTURE( //
-    theoretic_tops_cuda, f64f64_sm80tc, tops_f64f64_sm80tc_8x8x4_1024unroll_cuda_kernel, //
-    8, 8, 4, 1024, 80)
+BENCHMARK_CAPTURE( //
+    theoretic_tops_cuda, f64f64_sm80tc, tops_f64f64_sm80tc_8x8x4_loop128_cuda_kernel, //
+    8, 8, 4, 128, 80)
     ->MinTime(10);
-BENCHMARK_CAPTURE( //
-    theoretic_tops_cuda, b1i32and_sm80tc, tops_b1i32and_sm80tc_8x8x128_1024unroll_cuda_kernel, //
-    8, 8, 128, 1024, 80)
+BENCHMARK_CAPTURE( //
+    theoretic_tops_cuda, b1i32and_sm80tc, tops_b1i32and_sm80tc_8x8x128_loop128_cuda_kernel, //
+    8, 8, 128, 128, 80)
     ->MinTime(10);
 
 #include <filesystem>
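
Side note, not part of the commit: the trailing capture arguments are the tile shape (m, n, k), the loop length, and the minimum SM version. Under the usual 2·m·n·k operations-per-MMA convention, a hedged back-of-the-envelope for the shrunken 128-iteration loop looks like the sketch below; the exact accounting inside theoretic_tops_cuda may differ, and the variable names here are illustrative only.

#include <cstdio>

int main() {
    // One m x n x k matrix-multiply-accumulate is conventionally counted as 2*m*n*k ops.
    long long m = 16, n = 16, k = 16, repetitions = 128;
    long long ops_per_warp = 2 * m * n * k * repetitions; // 1,048,576 ops per kernel body
    std::printf("ops per warp per kernel body: %lld\n", ops_per_warp);
    // theoretic_tops_cuda presumably scales this by the number of resident warps
    // before reporting the counter as a rate via benchmark::Counter::kIsRate.
    return 0;
}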
@@ -2202,7 +2202,10 @@ static void theoretic_tops_ptx( //
         return;
     }
 
-    // Load the PTX file
+    // Load the PTX file and JIT it!
+    // If the compilation is taking long, consider using the `CUDA_CACHE_PATH`
+    // environment variable to cache already compiled modules:
+    // https://developer.nvidia.com/blog/cuda-pro-tip-understand-fat-binaries-jit-caching/
     result = cuModuleLoad(&module_, ptx_file.c_str());
     if (result != CUDA_SUCCESS) {
         state.SkipWithError("Failed to load PTX file: " + last_error_string());
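
For context on the hunk above, and not part of this commit: a minimal sketch of the CUDA Driver API flow that theoretic_tops_ptx relies on, i.e. JIT-compiling a standalone PTX file into a module and fetching one kernel by its .entry name. The file and kernel names below are placeholders, not the benchmark's.

#include <cuda.h> // CUDA Driver API: cuInit, cuModuleLoad, cuLaunchKernel, ...
#include <cstdio>

int main() {
    // Create a context on the first visible device.
    if (cuInit(0) != CUDA_SUCCESS) return 1;
    CUdevice device;
    CUcontext context;
    cuDeviceGet(&device, 0);
    cuCtxCreate(&context, 0, device);

    // cuModuleLoad JIT-compiles the PTX for the current GPU; the resulting
    // binaries can be cached on disk via the CUDA_CACHE_PATH variable.
    CUmodule module_;
    if (cuModuleLoad(&module_, "example.ptx") != CUDA_SUCCESS) return 1;

    // Look the kernel up by the .entry name declared in the PTX.
    CUfunction kernel;
    cuModuleGetFunction(&kernel, module_, "example_loop128_ptx_kernel");

    // Launch a single warp with no arguments and wait for completion.
    cuLaunchKernel(kernel, 1, 1, 1, 32, 1, 1, 0, nullptr, nullptr, nullptr);
    cuCtxSynchronize();

    cuModuleUnload(module_);
    cuCtxDestroy(context);
    std::puts("done");
    return 0;
}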
@@ -2261,46 +2264,46 @@ static void theoretic_tops_ptx( //
     cuCtxDestroy(context);
 }
 
+BENCHMARK_CAPTURE( //
+    theoretic_tops_ptx, f16f16_sm70tc, //
+    "less_slow_sm70.ptx", "tops_f16f16_sm70tc_16x16x16_loop128_ptx_kernel", //
+    16, 16, 16, 128, 70)
+    ->MinTime(10);
+
 BENCHMARK_CAPTURE( //
-    theoretic_tops_ptx, f16f16_sm70tc, //
-    "less_slow_sm70.ptx", "tops_f16f16_sm70tc_16x16x16_1024loop_ptx_kernel", //
-    16, 16, 16, 1024, 70)
+    theoretic_tops_ptx, f16f16_sm90tc, //
+    "less_slow_sm90a.ptx", "tops_f16f16_sm90tc_16x16x16_loop128_ptx_kernel", //
+    16, 16, 16, 128, 90)
     ->MinTime(10);
 
-BENCHMARK_CAPTURE( //
-    theoretic_tops_ptx, f16f16_sm90tc, //
-    "less_slow_sm90a.ptx", "tops_f16f16_sm90tc_16x16x16_1024loop_ptx_kernel", //
-    16, 16, 16, 1024, 90)
+BENCHMARK_CAPTURE( //
+    theoretic_tops_ptx, f64f64_sm90tc, //
+    "less_slow_sm90a.ptx", "tops_f64f64_sm90tc_8x8x4_loop128_ptx_kernel", //
+    8, 8, 4, 128, 90)
     ->MinTime(10);
 
-BENCHMARK_CAPTURE( //
-    theoretic_tops_ptx, f64f64_sm90tc, //
-    "less_slow_sm90a.ptx", "tops_f64f64_sm90tc_8x8x4_1024loop_ptx_kernel", //
-    8, 8, 4, 1024, 90)
+BENCHMARK_CAPTURE( //
+    theoretic_tops_ptx, tf32f32_sm90tc, //
+    "less_slow_sm90a.ptx", "tops_tf32f32_sm90tc_16x16x8_loop128_ptx_kernel", //
+    16, 16, 8, 128, 90)
     ->MinTime(10);
 
 BENCHMARK_CAPTURE( //
-    theoretic_tops_ptx, tf32f32_sm90tc, //
-    "less_slow_sm90a.ptx", "tops_tf32f32_sm90tc_16x16x8_1024loop_ptx_kernel", //
-    16, 16, 8, 1024, 90)
+    theoretic_tops_ptx, tf32f32_sm90tc_wgmma_smallest, //
+    "less_slow_sm90a.ptx", "tops_tf32f32_sm90tc_m64n16k8_loop128_ptx_kernel", //
+    64, 16, 8, 128, 90)
     ->MinTime(10);
 
 BENCHMARK_CAPTURE( //
-    theoretic_tops_ptx, tf32f32_sm90tc_wgmma_smallest, //
-    "less_slow_sm90a.ptx", "tops_tf32f32_sm90tc_m64n16k8_1024loop_ptx_kernel", //
-    64, 16, 8, 1024, 90)
-    ->MinTime(10);
-
-BENCHMARK_CAPTURE( //
-    theoretic_tops_ptx, tf32f32_sm90tc_wgmma_largest, //
-    "less_slow_sm90a.ptx", "tops_tf32f32_sm90tc_m64n256k8_1024loop_ptx_kernel", //
-    64, 256, 8, 1024, 90)
+    theoretic_tops_ptx, tf32f32_sm90tc_wgmma_largest, //
+    "less_slow_sm90a.ptx", "tops_tf32f32_sm90tc_m64n256k8_loop128_ptx_kernel", //
+    64, 256, 8, 128, 90)
     ->MinTime(10);
 
-BENCHMARK_CAPTURE( //
-    theoretic_tops_ptx, b1i32and_sm90tc_wgmma, //
-    "less_slow_sm90a.ptx", "tops_b1i32and_sm90tc_m64n256k256_1024loop_ptx_kernel", //
-    64, 256, 256, 1024, 90)
+BENCHMARK_CAPTURE( //
+    theoretic_tops_ptx, b1i32and_sm90tc_wgmma, //
+    "less_slow_sm90a.ptx", "tops_b1i32and_sm90tc_m64n256k256_loop128_ptx_kernel", //
+    64, 256, 256, 128, 90)
     ->MinTime(10);
 
 /**

less_slow.cu

Lines changed: 20 additions & 20 deletions
@@ -180,7 +180,7 @@ void reverse_and_sort_with_cub(std::uint32_t *device_pointer, std::size_t array_
  * @brief A CUDA kernel that @b repeatedly computes the product of two small
  * matrices of size MxN and NxK using Tensor Cores.
  */
-template <typename input_type_, typename output_type_, int m_, int n_, int k_, int repetitions_>
+template <typename input_type_, typename output_type_, int m_, int n_, int k_, int repetitions_ = 128>
 __device__ inline void tops_tc_cuda_kernel() {
     using namespace nvcuda;
     wmma::fragment<wmma::matrix_a, m_, n_, k_, input_type_, wmma::row_major> a_frag;
@@ -210,7 +210,7 @@ __device__ inline void tops_tc_cuda_kernel() {
  *
  * @see Docs: https://docs.nvidia.com/cuda/cuda-c-programming-guide/#sub-byte-operations
  */
-template <typename input_type_, typename output_type_, int m_, int n_, int k_, int repetitions_>
+template <typename input_type_, typename output_type_, int m_, int n_, int k_, int repetitions_ = 128>
 __device__ inline void binary_tops_tc_cuda_kernel( //
     nvcuda::wmma::experimental::bmmaBitOp bit_op, nvcuda::wmma::experimental::bmmaAccumulateOp acc_op) {
     using namespace nvcuda;
@@ -225,48 +225,48 @@ __device__ inline void binary_tops_tc_cuda_kernel( //
 
 #pragma region Volta
 
-__global__ void tops_f16f16_sm70tc_16x16x16_1024unroll_cuda_kernel() {
+__global__ void tops_f16f16_sm70tc_16x16x16_loop128_cuda_kernel() {
     //? On Volta: 8x8x4.
     //? On Turing: 8x8x4 / 16x8x8 / 16x8x16.
     //? On Ampere: 16x8x8 / 16x8x16.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
-    tops_tc_cuda_kernel<half, half, 16, 16, 16, 1024>();
+    tops_tc_cuda_kernel<half, half, 16, 16, 16>();
 #endif
 }
-__global__ void tops_f16f32_sm70tc_16x16x16_1024unroll_cuda_kernel() {
+__global__ void tops_f16f32_sm70tc_16x16x16_loop128_cuda_kernel() {
     //? On Volta: 8x8x4.
     //? On Turing: 8x8x4 / 16x8x8 / 16x8x16.
     //? On Ampere: 16x8x8 / 16x8x16.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
-    tops_tc_cuda_kernel<half, float, 16, 16, 16, 1024>();
+    tops_tc_cuda_kernel<half, float, 16, 16, 16>();
 #endif
 }
 
 #pragma endregion
 
 #pragma region Turing
 
-__global__ void tops_u8i32_sm75tc_16x16x16_1024unroll_cuda_kernel() {
+__global__ void tops_u8i32_sm75tc_16x16x16_loop128_cuda_kernel() {
     //? On Turing: 8x8x16.
     //? On Ampere: 8x8x16 / 16x8x16 / 16x8x32.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
-    tops_tc_cuda_kernel<std::uint8_t, int32_t, 16, 16, 16, 1024>();
+    tops_tc_cuda_kernel<std::uint8_t, int32_t, 16, 16, 16>();
 #endif
 }
-__global__ void tops_u4i32_sm75tc_8x8x32_1024unroll_cuda_kernel() {
+__global__ void tops_u4i32_sm75tc_8x8x32_loop128_cuda_kernel() {
     //! The 16x16x16 won't compile, 8x8x32 will.
     //? On Turing: 8x8x32.
     //? On Ampere: 8x8x32 / 16x8x32 / 16x8x64.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
-    tops_tc_cuda_kernel<nvcuda::wmma::experimental::precision::u4, int32_t, 8, 8, 32, 1024>();
+    tops_tc_cuda_kernel<nvcuda::wmma::experimental::precision::u4, int32_t, 8, 8, 32>();
 #endif
 }
-__global__ void tops_b1i32xor_sm75tc_8x8x128_1024unroll_cuda_kernel() {
+__global__ void tops_b1i32xor_sm75tc_8x8x128_loop128_cuda_kernel() {
     //! The 16x16x16 won't compile, 8x8x128 will.
     //? On Turing: 8x8x128.
     //? On Ampere: 8x8x128 / 16x8x128 / 16x8x256.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
-    binary_tops_tc_cuda_kernel<nvcuda::wmma::experimental::precision::b1, int32_t, 8, 8, 128, 1024>(
+    binary_tops_tc_cuda_kernel<nvcuda::wmma::experimental::precision::b1, int32_t, 8, 8, 128>(
         nvcuda::wmma::experimental::bmmaBitOp::bmmaBitOpXOR,
         nvcuda::wmma::experimental::bmmaAccumulateOp::bmmaAccumulateOpPOPC);
 #endif
@@ -276,32 +276,32 @@ __global__ void tops_b1i32xor_sm75tc_8x8x128_1024unroll_cuda_kernel() {
 
 #pragma region Ampere
 
-__global__ void tops_bf16f32_sm80tc_16x16x16_1024unroll_cuda_kernel() {
+__global__ void tops_bf16f32_sm80tc_16x16x16_loop128_cuda_kernel() {
     //? On Ampere: 16x8x8 / 16x8x16.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
-    tops_tc_cuda_kernel<__nv_bfloat16, float, 16, 16, 16, 1024>();
+    tops_tc_cuda_kernel<__nv_bfloat16, float, 16, 16, 16>();
 #endif
 }
-__global__ void tops_tf32f32_sm80tc_16x16x8_1024unroll_cuda_kernel() {
+__global__ void tops_tf32f32_sm80tc_16x16x8_loop128_cuda_kernel() {
     //! The 16x16x16 won't compile, 16x16x8 will.
     //? On Ampere: 16x8x4.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
-    tops_tc_cuda_kernel<nvcuda::wmma::precision::tf32, float, 16, 16, 8, 1024>();
+    tops_tc_cuda_kernel<nvcuda::wmma::precision::tf32, float, 16, 16, 8>();
 #endif
 }
-__global__ void tops_f64f64_sm80tc_8x8x4_1024unroll_cuda_kernel() {
+__global__ void tops_f64f64_sm80tc_8x8x4_loop128_cuda_kernel() {
     //! The 16x16x16 won't compile, 8x8x4 will.
     //? On Ampere: 8x8x4.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
-    tops_tc_cuda_kernel<double, double, 8, 8, 4, 1024>();
+    tops_tc_cuda_kernel<double, double, 8, 8, 4>();
 #endif
 }
 
-__global__ void tops_b1i32and_sm80tc_8x8x128_1024unroll_cuda_kernel() {
+__global__ void tops_b1i32and_sm80tc_8x8x128_loop128_cuda_kernel() {
     //! The 16x16x16 won't compile, 8x8x128 will.
     //? On Ampere: 8x8x128 / 16x8x128 / 16x8x256.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
-    binary_tops_tc_cuda_kernel<nvcuda::wmma::experimental::precision::b1, int32_t, 8, 8, 128, 1024>(
+    binary_tops_tc_cuda_kernel<nvcuda::wmma::experimental::precision::b1, int32_t, 8, 8, 128>(
         nvcuda::wmma::experimental::bmmaBitOp::bmmaBitOpAND,
         nvcuda::wmma::experimental::bmmaAccumulateOp::bmmaAccumulateOpPOPC);
 #endif
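
Not part of the diff, but for readers unfamiliar with the WMMA API used by tops_tc_cuda_kernel: a minimal, self-contained sketch of such a repeat-the-MMA loop, now bounded by the 128-iteration default introduced above. Names, sizes, and fragment layouts here are illustrative assumptions, not the repository's exact implementation.

#include <mma.h>          // nvcuda::wmma fragments and mma_sync
#include <cuda_fp16.h>    // half, __float2half
#include <cuda_runtime.h>

template <int m_, int n_, int k_, int repetitions_ = 128>
__device__ inline void tops_f16f16_loop_sketch() {
    using namespace nvcuda;
    // Register-resident tiles: two inputs and one accumulator, all in fp16.
    wmma::fragment<wmma::matrix_a, m_, n_, k_, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, m_, n_, k_, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, m_, n_, k_, half> c_frag;
    wmma::fill_fragment(a_frag, __float2half(1.0f));
    wmma::fill_fragment(b_frag, __float2half(1.0f));
    wmma::fill_fragment(c_frag, __float2half(0.0f));
    // Back-to-back Tensor Core MMAs with no memory traffic in the loop body.
    for (int i = 0; i != repetitions_; ++i)
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}

__global__ void tops_f16f16_loop_sketch_kernel() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
    tops_f16f16_loop_sketch<16, 16, 16>(); // 128 iterations by default
#endif
}

int main() {
    tops_f16f16_loop_sketch_kernel<<<1, 32>>>(); // one warp is enough for WMMA
    return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}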

less_slow_sm70.ptx

Lines changed: 3 additions & 3 deletions
@@ -37,7 +37,7 @@
 .target sm_70 // Target architecture (SM 7.0 - Volta GPUs)
 .address_size 64 // 64-bit addressing
 
-.visible .entry tops_f16f16_sm70tc_16x16x16_1024loop_ptx_kernel()
+.visible .entry tops_f16f16_sm70tc_16x16x16_loop128_ptx_kernel()
 {
     // Accumulator registers used for both input and output of the MMA operation
     .reg .b32 accum_0, accum_1, accum_2, accum_3;
@@ -58,7 +58,7 @@
 
     // Set up loop counter and loop limit
     mov.u32 loop_counter, 0;
-    mov.u32 loop_limit, 1024;
+    mov.u32 loop_limit, 128;
 
     // Zero-initialize the accumulator registers
     mov.f32 accum_0, 0.0;
@@ -89,7 +89,7 @@
     mov.b32 matrix_b_6, packed_const;
     mov.b32 matrix_b_7, packed_const;
 
-    // The main loop will repeat for 1024 iterations
+    // The main loop will repeat for 128 iterations
 loop_start:
     setp.ge.u32 exit_predicate, loop_counter, loop_limit;
     @exit_predicate bra loop_end;
