@@ -61,7 +61,7 @@ void launch_tinygemm2(__nv_bfloat16* gA, __nv_bfloat16* gB, __nv_bfloat16* gC, _
     int smem_size
         = STAGES * STAGE_UNROLL * (TILE_M * TILE_K * sizeof(__nv_bfloat16) + TILE_N * TILE_K * sizeof(__nv_bfloat16));

-    gpuErrChk(cudaFuncSetAttribute(kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES, STAGE_UNROLL, PROFILE>,
+    gpuErrChk(cudaFuncSetAttribute(tinygemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES, STAGE_UNROLL, PROFILE>,
         cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

     int tiles_m = (output_features + TILE_M - 1) / TILE_M;
@@ -82,8 +82,8 @@ void launch_tinygemm2(__nv_bfloat16* gA, __nv_bfloat16* gB, __nv_bfloat16* gC, _
     attrs[0].val.programmaticStreamSerializationAllowed = 1;
     config.numAttrs = 1;

-    cudaLaunchKernelEx(&config, &kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES, STAGE_UNROLL, PROFILE>, gC, gA, gB,
-        bias, output_features, batch_size, input_features, weight_map, activation_map, nullptr);
+    cudaLaunchKernelEx(&config, &tinygemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES, STAGE_UNROLL, PROFILE>,
+        gC, gA, gB, bias, output_features, batch_size, input_features, weight_map, activation_map, nullptr);
 }

 torch::Tensor tinygemm2_cuda_forward(torch::Tensor input, torch::Tensor weight, torch::Tensor bias)
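Both hunks make the same change: the kernel function pointer passed to `cudaFuncSetAttribute` and `cudaLaunchKernelEx` is renamed from `kernel` to `tinygemm_kernel`. The surrounding launch pattern (raising the kernel's dynamic shared memory limit, then launching through `cudaLaunchKernelEx` with a programmatic stream serialization attribute) is unchanged. Below is a minimal sketch of that pattern for reference; `toy_kernel`, `launch_toy`, and the sizes in it are hypothetical stand-ins, not code from this repository.

```cuda
// Sketch of the cudaLaunchKernelEx + launch-attribute pattern used above.
// toy_kernel / launch_toy and all sizes are hypothetical; only the CUDA runtime
// calls (cudaFuncSetAttribute, cudaLaunchKernelEx) mirror the diff.
#include <cuda_runtime.h>

__global__ void toy_kernel(float* out, const float* in, int n)
{
    extern __shared__ float smem[];                 // dynamic shared memory
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    smem[threadIdx.x] = (i < n) ? in[i] : 0.0f;     // stage input through shared memory
    __syncthreads();
    if (i < n)
        out[i] = smem[threadIdx.x] * 2.0f;
}

void launch_toy(float* out, const float* in, int n, cudaStream_t stream)
{
    int threads = 256;
    int smem_size = threads * (int) sizeof(float);  // hypothetical dynamic smem requirement

    // Raise the kernel's dynamic shared memory limit. Strictly needed only when the
    // request exceeds the default (48 KB); shown here to mirror the diff's use of
    // cudaFuncAttributeMaxDynamicSharedMemorySize.
    cudaFuncSetAttribute(toy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

    // Extended launch configuration: grid/block dims, dynamic smem, stream, attributes.
    cudaLaunchConfig_t config = {};
    config.gridDim = dim3((n + threads - 1) / threads);
    config.blockDim = dim3(threads);
    config.dynamicSmemBytes = smem_size;
    config.stream = stream;

    // Programmatic stream serialization attribute, matching
    // attrs[0].val.programmaticStreamSerializationAllowed = 1 in the diff.
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = 1;
    config.attrs = attrs;
    config.numAttrs = 1;

    // Kernel arguments are passed directly; cudaLaunchKernelEx deduces their types from
    // the kernel's signature, which is why the call in the diff must name the renamed
    // kernel template explicitly.
    cudaLaunchKernelEx(&config, toy_kernel, out, in, n);
}
```

With `programmaticStreamSerializationAllowed` set, the kernel may be scheduled before the preceding kernel on the stream has fully retired, so a kernel that reads an upstream kernel's output would normally call `cudaGridDependencySynchronize()` before doing so.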