@@ -180,7 +180,7 @@ void reverse_and_sort_with_cub(std::uint32_t *device_pointer, std::size_t array_
 * @brief A CUDA kernel that @b repeatedly computes the product of two small
 * matrices of size MxK and KxN using Tensor Cores.
 */
-template <typename input_type_, typename output_type_, int m_, int n_, int k_, int repetitions_>
+template <typename input_type_, typename output_type_, int m_, int n_, int k_, int repetitions_ = 128>
 __device__ inline void tops_tc_cuda_kernel() {
     using namespace nvcuda;
     wmma::fragment<wmma::matrix_a, m_, n_, k_, input_type_, wmma::row_major> a_frag;
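
The first hunk truncates before the loop that the new `repetitions_ = 128` default feeds. As a minimal sketch, assuming the elided body follows the canonical WMMA pattern of re-issuing a single `mma_sync` per iteration (the `matrix_b` layout and the absence of a result store are assumptions, not shown in the diff):

```cuda
// Sketch of the presumed full template; the loop body is assumed, not taken from the diff.
template <typename input_type_, typename output_type_, int m_, int n_, int k_, int repetitions_ = 128>
__device__ inline void tops_tc_cuda_kernel() {
    using namespace nvcuda;
    wmma::fragment<wmma::matrix_a, m_, n_, k_, input_type_, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, m_, n_, k_, input_type_, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, m_, n_, k_, output_type_> c_frag;
    // Re-issue the same MxK-by-KxN multiply-accumulate to saturate the Tensor Cores;
    // fragment contents are irrelevant when only throughput is being measured.
    for (int i = 0; i != repetitions_; ++i)
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
```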
@@ -210,7 +210,7 @@ __device__ inline void tops_tc_cuda_kernel() {
 *
 * @see Docs: https://docs.nvidia.com/cuda/cuda-c-programming-guide/#sub-byte-operations
 */
-template <typename input_type_, typename output_type_, int m_, int n_, int k_, int repetitions_>
+template <typename input_type_, typename output_type_, int m_, int n_, int k_, int repetitions_ = 128>
 __device__ inline void binary_tops_tc_cuda_kernel( //
     nvcuda::wmma::experimental::bmmaBitOp bit_op, nvcuda::wmma::experimental::bmmaAccumulateOp acc_op) {
     using namespace nvcuda;
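
This body is again cut off after `using namespace nvcuda;`. A hedged sketch of the likely remainder, mirroring the kernel above but dispatching through the experimental `bmma_sync`, which takes the bit and accumulate operations as trailing arguments:

```cuda
// Sketch only: fragment shapes follow the b1/int32 instantiations below.
wmma::fragment<wmma::matrix_a, m_, n_, k_, input_type_, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, m_, n_, k_, input_type_, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, m_, n_, k_, output_type_> c_frag;
for (int i = 0; i != repetitions_; ++i)
    wmma::bmma_sync(c_frag, a_frag, b_frag, c_frag, bit_op, acc_op); // e.g. XOR or AND, then POPC
```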
@@ -225,48 +225,48 @@ __device__ inline void binary_tops_tc_cuda_kernel( //
 
 #pragma region Volta
 
-__global__ void tops_f16f16_sm70tc_16x16x16_1024unroll_cuda_kernel() {
+__global__ void tops_f16f16_sm70tc_16x16x16_loop128_cuda_kernel() {
     // ? On Volta: 8x8x4.
     // ? On Turing: 8x8x4 / 16x8x8 / 16x8x16.
     // ? On Ampere: 16x8x8 / 16x8x16.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
-    tops_tc_cuda_kernel<half, half, 16, 16, 16, 1024>();
+    tops_tc_cuda_kernel<half, half, 16, 16, 16>();
 #endif
 }
-__global__ void tops_f16f32_sm70tc_16x16x16_1024unroll_cuda_kernel() {
+__global__ void tops_f16f32_sm70tc_16x16x16_loop128_cuda_kernel() {
     // ? On Volta: 8x8x4.
     // ? On Turing: 8x8x4 / 16x8x8 / 16x8x16.
     // ? On Ampere: 16x8x8 / 16x8x16.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
-    tops_tc_cuda_kernel<half, float, 16, 16, 16, 1024>();
+    tops_tc_cuda_kernel<half, float, 16, 16, 16>();
 #endif
 }
 
 #pragma endregion
 
 #pragma region Turing
 
-__global__ void tops_u8i32_sm75tc_16x16x16_1024unroll_cuda_kernel() {
+__global__ void tops_u8i32_sm75tc_16x16x16_loop128_cuda_kernel() {
     // ? On Turing: 8x8x16.
     // ? On Ampere: 8x8x16 / 16x8x16 / 16x8x32.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
-    tops_tc_cuda_kernel<std::uint8_t, int32_t, 16, 16, 16, 1024>();
+    tops_tc_cuda_kernel<std::uint8_t, int32_t, 16, 16, 16>();
 #endif
 }
-__global__ void tops_u4i32_sm75tc_8x8x32_1024unroll_cuda_kernel() {
+__global__ void tops_u4i32_sm75tc_8x8x32_loop128_cuda_kernel() {
     // ! The 16x16x16 won't compile, 8x8x32 will.
     // ? On Turing: 8x8x32.
     // ? On Ampere: 8x8x32 / 16x8x32 / 16x8x64.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
-    tops_tc_cuda_kernel<nvcuda::wmma::experimental::precision::u4, int32_t, 8, 8, 32, 1024>();
+    tops_tc_cuda_kernel<nvcuda::wmma::experimental::precision::u4, int32_t, 8, 8, 32>();
 #endif
 }
-__global__ void tops_b1i32xor_sm75tc_8x8x128_1024unroll_cuda_kernel() {
+__global__ void tops_b1i32xor_sm75tc_8x8x128_loop128_cuda_kernel() {
     // ! The 16x16x16 won't compile, 8x8x128 will.
     // ? On Turing: 8x8x128.
     // ? On Ampere: 8x8x128 / 16x8x128 / 16x8x256.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
-    binary_tops_tc_cuda_kernel<nvcuda::wmma::experimental::precision::b1, int32_t, 8, 8, 128, 1024>(
+    binary_tops_tc_cuda_kernel<nvcuda::wmma::experimental::precision::b1, int32_t, 8, 8, 128>(
         nvcuda::wmma::experimental::bmmaBitOp::bmmaBitOpXOR,
         nvcuda::wmma::experimental::bmmaAccumulateOp::bmmaAccumulateOpPOPC);
 #endif
@@ -276,32 +276,32 @@ __global__ void tops_b1i32xor_sm75tc_8x8x128_1024unroll_cuda_kernel() {
 
 #pragma region Ampere
 
-__global__ void tops_bf16f32_sm80tc_16x16x16_1024unroll_cuda_kernel() {
+__global__ void tops_bf16f32_sm80tc_16x16x16_loop128_cuda_kernel() {
     // ? On Ampere: 16x8x8 / 16x8x16.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
-    tops_tc_cuda_kernel<__nv_bfloat16, float, 16, 16, 16, 1024>();
+    tops_tc_cuda_kernel<__nv_bfloat16, float, 16, 16, 16>();
 #endif
 }
-__global__ void tops_tf32f32_sm80tc_16x16x8_1024unroll_cuda_kernel() {
+__global__ void tops_tf32f32_sm80tc_16x16x8_loop128_cuda_kernel() {
     // ! The 16x16x16 won't compile, 16x16x8 will.
     // ? On Ampere: 16x8x4.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
-    tops_tc_cuda_kernel<nvcuda::wmma::precision::tf32, float, 16, 16, 8, 1024>();
+    tops_tc_cuda_kernel<nvcuda::wmma::precision::tf32, float, 16, 16, 8>();
 #endif
 }
-__global__ void tops_f64f64_sm80tc_8x8x4_1024unroll_cuda_kernel() {
+__global__ void tops_f64f64_sm80tc_8x8x4_loop128_cuda_kernel() {
     // ! The 16x16x16 won't compile, 8x8x4 will.
     // ? On Ampere: 8x8x4.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
-    tops_tc_cuda_kernel<double, double, 8, 8, 4, 1024>();
+    tops_tc_cuda_kernel<double, double, 8, 8, 4>();
 #endif
 }
 
-__global__ void tops_b1i32and_sm80tc_8x8x128_1024unroll_cuda_kernel() {
+__global__ void tops_b1i32and_sm80tc_8x8x128_loop128_cuda_kernel() {
     // ! The 16x16x16 won't compile, 8x8x128 will.
     // ? On Ampere: 8x8x128 / 16x8x128 / 16x8x256.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
-    binary_tops_tc_cuda_kernel<nvcuda::wmma::experimental::precision::b1, int32_t, 8, 8, 128, 1024>(
+    binary_tops_tc_cuda_kernel<nvcuda::wmma::experimental::precision::b1, int32_t, 8, 8, 128>(
         nvcuda::wmma::experimental::bmmaBitOp::bmmaBitOpAND,
         nvcuda::wmma::experimental::bmmaAccumulateOp::bmmaAccumulateOpPOPC);
 #endif
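
The diff does not show how these kernels are launched or how a TOPS figure is derived from them. A self-contained host-side sketch, to be compiled together with the kernels above; the grid and block sizes, and the 2*M*N*K operations-per-MMA accounting, are assumptions rather than part of the commit:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    // Assumed launch shape: each block carries several warps, and each warp
    // issues `repetitions` (the new default, 128) MxNxK Tensor Core MMAs.
    constexpr int blocks = 1024, warps_per_block = 8;
    constexpr long long m = 16, n = 16, k = 16, repetitions = 128;

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start);
    tops_f16f16_sm70tc_16x16x16_loop128_cuda_kernel<<<blocks, warps_per_block * 32>>>();
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float milliseconds = 0;
    cudaEventElapsedTime(&milliseconds, start, stop);

    // One MxNxK multiply-accumulate amounts to 2*M*N*K scalar operations.
    double total_ops = 2.0 * m * n * k * repetitions * blocks * warps_per_block;
    std::printf("~%.2f TOPS\n", total_ops / (milliseconds * 1e-3) / 1e12);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return 0;
}
```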