66
77#include < c10/util/BFloat16.h>
88#include < c10/util/Half.h>
9+ #ifndef USE_MUSA
910#include < c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
11+ #else
12+ #include " torch_musa/csrc/core/MUSAException.h" // For C10_MUSA_CHECK and C10_MUSA_KERNEL_LAUNCH_CHECK
13+ #endif
1014
15+ #include " vendor.h"
1116#include " fast_hadamard_transform.h"
1217#include " fast_hadamard_transform_common.h"
1318#include " fast_hadamard_transform_special.h"
@@ -28,7 +33,7 @@ struct fast_hadamard_transform_kernel_traits {
2833 using vec_t = typename BytesToType<kNBytes * kNElts >::Type;
2934 static constexpr int kNChunks = N / (kNElts * kNThreads );
3035 // We don't want to use more than 32 KB of shared memory.
31- static constexpr int kSmemExchangeSize = std::min (N * 4 , 32 * 1024 );
36+ static constexpr int kSmemExchangeSize = MIN (N * 4 , 32 * 1024 );
3237 static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize ;
3338 static_assert (kNExchangeRounds * kSmemExchangeSize == N * 4 );
3439 static constexpr int kSmemSize = kSmemExchangeSize ;
@@ -51,7 +56,7 @@ struct fast_hadamard_transform_12N_kernel_traits {
5156 static constexpr int kNChunks = N / (kNElts * kNThreads );
5257 static_assert (kNChunks == 12 );
5358 // We don't want to use more than 24 KB of shared memory.
54- static constexpr int kSmemExchangeSize = std::min (N * 4 , 24 * 1024 );
59+ static constexpr int kSmemExchangeSize = MIN (N * 4 , 24 * 1024 );
5560 static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize ;
5661 static_assert (kNExchangeRounds * kSmemExchangeSize == N * 4 );
5762 static constexpr int kSmemSize = kSmemExchangeSize ;
@@ -74,7 +79,7 @@ struct fast_hadamard_transform_20N_kernel_traits {
7479 static constexpr int kNChunks = N / (kNElts * kNThreads );
7580 static_assert (kNChunks == 20 );
7681 // We don't want to use more than 40 KB of shared memory.
77- static constexpr int kSmemExchangeSize = std::min (N * 4 , 40 * 1024 );
82+ static constexpr int kSmemExchangeSize = MIN (N * 4 , 40 * 1024 );
7883 static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize ;
7984 static_assert (kNExchangeRounds * kSmemExchangeSize == N * 4 );
8085 static constexpr int kSmemSize = kSmemExchangeSize ;
@@ -97,7 +102,7 @@ struct fast_hadamard_transform_28N_kernel_traits {
97102 static constexpr int kNChunks = N / (kNElts * kNThreads );
98103 static_assert (kNChunks == 28 );
99104 // We don't want to use more than 28 KB of shared memory.
100- static constexpr int kSmemExchangeSize = std::min (N * 4 , 28 * 1024 );
105+ static constexpr int kSmemExchangeSize = MIN (N * 4 , 28 * 1024 );
101106 static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize ;
102107 static_assert (kNExchangeRounds * kSmemExchangeSize == N * 4 );
103108 static constexpr int kSmemSize = kSmemExchangeSize ;
@@ -120,7 +125,7 @@ struct fast_hadamard_transform_40N_kernel_traits {
120125 static constexpr int kNChunks = N / (kNElts * kNThreads );
121126 static_assert (kNChunks == 40 );
122127 // We don't want to use more than 40 KB of shared memory.
123- static constexpr int kSmemExchangeSize = std::min (N * 4 , 40 * 1024 );
128+ static constexpr int kSmemExchangeSize = MIN (N * 4 , 40 * 1024 );
124129 static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize ;
125130 static_assert (kNExchangeRounds * kSmemExchangeSize == N * 4 );
126131 static constexpr int kSmemSize = kSmemExchangeSize ;
@@ -163,7 +168,7 @@ void fast_hadamard_transform_kernel(HadamardParamsBase params) {
163168
164169 constexpr int kLogNElts = cilog2 (Ktraits::kNElts );
165170 static_assert (1 << kLogNElts == kNElts , " kNElts must be a power of 2" );
166- constexpr int kWarpSize = std::min (kNThreads , 32 );
171+ constexpr int kWarpSize = MIN (kNThreads , WARP_SIZE );
167172 constexpr int kLogWarpSize = cilog2 (kWarpSize );
168173 static_assert (1 << kLogWarpSize == kWarpSize , " Warp size must be a power of 2" );
169174 constexpr int kNWarps = kNThreads / kWarpSize ;
@@ -234,10 +239,12 @@ void fast_hadamard_transform_launch(HadamardParamsBase &params, cudaStream_t str
234239 constexpr int kSmemSize = Ktraits::kSmemSize ;
235240 dim3 grid (params.batch );
236241 auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
242+ #ifndef USE_ROCM
237243 if (kSmemSize >= 48 * 1024 ) {
238244 C10_CUDA_CHECK (cudaFuncSetAttribute (
239245 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize ));
240246 }
247+ #endif
241248 kernel<<<grid, Ktraits::kNThreads , kSmemSize , stream>>> (params);
242249 C10_CUDA_KERNEL_LAUNCH_CHECK ();
243250}
@@ -279,10 +286,12 @@ void fast_hadamard_transform_12N_launch(HadamardParamsBase &params, cudaStream_t
279286 constexpr int kSmemSize = Ktraits::kSmemSize ;
280287 dim3 grid (params.batch );
281288 auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
289+ #ifndef USE_ROCM
282290 if (kSmemSize >= 48 * 1024 ) {
283291 C10_CUDA_CHECK (cudaFuncSetAttribute (
284292 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize ));
285293 }
294+ #endif
286295 kernel<<<grid, Ktraits::kNThreads , kSmemSize , stream>>> (params);
287296 C10_CUDA_KERNEL_LAUNCH_CHECK ();
288297}
@@ -307,7 +316,7 @@ void fast_hadamard_transform_12N_cuda(HadamardParamsBase &params, cudaStream_t s
307316 fast_hadamard_transform_12N_launch<128 , 9 , input_t >(params, stream);
308317 } else if (params.log_N == 10 ) {
309318 fast_hadamard_transform_12N_launch<256 , 10 , input_t >(params, stream);
310- }
319+ }
311320}
312321
313322template <int kNThreads , int kLogN , typename input_t >
@@ -316,10 +325,12 @@ void fast_hadamard_transform_20N_launch(HadamardParamsBase &params, cudaStream_t
316325 constexpr int kSmemSize = Ktraits::kSmemSize ;
317326 dim3 grid (params.batch );
318327 auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
328+ #ifndef USE_ROCM
319329 if (kSmemSize >= 48 * 1024 ) {
320330 C10_CUDA_CHECK (cudaFuncSetAttribute (
321331 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize ));
322332 }
333+ #endif
323334 kernel<<<grid, Ktraits::kNThreads , kSmemSize , stream>>> (params);
324335 C10_CUDA_KERNEL_LAUNCH_CHECK ();
325336}
@@ -353,10 +364,12 @@ void fast_hadamard_transform_28N_launch(HadamardParamsBase &params, cudaStream_t
353364 constexpr int kSmemSize = Ktraits::kSmemSize ;
354365 dim3 grid (params.batch );
355366 auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
367+ #ifndef USE_ROCM
356368 if (kSmemSize >= 48 * 1024 ) {
357369 C10_CUDA_CHECK (cudaFuncSetAttribute (
358370 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize ));
359371 }
372+ #endif
360373 kernel<<<grid, Ktraits::kNThreads , kSmemSize , stream>>> (params);
361374 C10_CUDA_KERNEL_LAUNCH_CHECK ();
362375}
@@ -390,10 +403,12 @@ void fast_hadamard_transform_40N_launch(HadamardParamsBase &params, cudaStream_t
390403 constexpr int kSmemSize = Ktraits::kSmemSize ;
391404 dim3 grid (params.batch );
392405 auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
406+ #ifndef USE_ROCM
393407 if (kSmemSize >= 48 * 1024 ) {
394408 C10_CUDA_CHECK (cudaFuncSetAttribute (
395409 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize ));
396410 }
411+ #endif
397412 kernel<<<grid, Ktraits::kNThreads , kSmemSize , stream>>> (params);
398413 C10_CUDA_KERNEL_LAUNCH_CHECK ();
399414}
0 commit comments