
Commit da3c2f4

Multi-backend support
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
1 parent f134af6 commit da3c2f4

File tree

8 files changed (+234 / -76 lines changed)


.gitignore

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+__pycache__/
+build/
+dist/
+*.egg-info/
+
+# hipified files
+*.hip
+*_hip.*

benchmarks/benchmark_fast_hadamard_transform.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@
 dim = 16384 * 2
 dtype = torch.float16
 device = "cuda"
+if hasattr(torch, "musa"):
+    device = "musa"
 
 torch.random.manual_seed(0)
 x = torch.randn(batch_size, seqlen, dim, dtype=dtype, device=device)

csrc/fast_hadamard_transform.cpp

Lines changed: 26 additions & 0 deletions
@@ -2,11 +2,17 @@
  * Copyright (c) 2023, Tri Dao.
  ******************************************************************************/
 
+#ifndef USE_MUSA
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#else
+#include "torch_musa/csrc/aten/musa/MUSAContext.h"
+#include "torch_musa/csrc/core/MUSAGuard.h"
+#endif
 #include <torch/extension.h>
 #include <vector>
 
+#include "vendor.h"
 #include "fast_hadamard_transform.h"
 
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
@@ -74,7 +80,11 @@ fast_hadamard_transform(at::Tensor &x, float scale) {
     auto input_type = x.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
 
+#ifndef USE_MUSA
     TORCH_CHECK(x.is_cuda());
+#else
+    TORCH_CHECK(x.is_privateuseone());
+#endif
 
     const auto shapes_og = x.sizes();
     const int dim_og = x.size(-1);
@@ -117,7 +127,11 @@ fast_hadamard_transform_12N(at::Tensor &x, float scale) {
     auto input_type = x.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
 
+#ifndef USE_MUSA
     TORCH_CHECK(x.is_cuda());
+#else
+    TORCH_CHECK(x.is_privateuseone());
+#endif
 
     const auto shapes_og = x.sizes();
     const int dim_og = x.size(-1);
@@ -160,7 +174,11 @@ fast_hadamard_transform_20N(at::Tensor &x, float scale) {
     auto input_type = x.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
 
+#ifndef USE_MUSA
     TORCH_CHECK(x.is_cuda());
+#else
+    TORCH_CHECK(x.is_privateuseone());
+#endif
 
     const auto shapes_og = x.sizes();
     const int dim_og = x.size(-1);
@@ -203,7 +221,11 @@ fast_hadamard_transform_28N(at::Tensor &x, float scale) {
     auto input_type = x.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
 
+#ifndef USE_MUSA
     TORCH_CHECK(x.is_cuda());
+#else
+    TORCH_CHECK(x.is_privateuseone());
+#endif
 
     const auto shapes_og = x.sizes();
     const int dim_og = x.size(-1);
@@ -246,7 +268,11 @@ fast_hadamard_transform_40N(at::Tensor &x, float scale) {
    auto input_type = x.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

+#ifndef USE_MUSA
    TORCH_CHECK(x.is_cuda());
+#else
+    TORCH_CHECK(x.is_privateuseone());
+#endif

    const auto shapes_og = x.sizes();
    const int dim_og = x.size(-1);
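Note: the backend-specific device check added above is repeated verbatim in all five binding functions (fast_hadamard_transform and its 12N/20N/28N/40N variants). Below is a minimal sketch of how the same check could be centralized; the CHECK_INPUT_DEVICE helper is illustrative and not part of this commit. torch_musa registers its tensors under PyTorch's PrivateUse1 dispatch key, which is why x.is_privateuseone() stands in for x.is_cuda() on that backend.

#include <torch/extension.h>

// Hypothetical helper (not in this commit): one macro for the per-backend
// device check used by every binding function above.
#ifndef USE_MUSA
#define CHECK_INPUT_DEVICE(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#else
// torch_musa exposes its device through the PrivateUse1 dispatch key.
#define CHECK_INPUT_DEVICE(x) TORCH_CHECK((x).is_privateuseone(), #x " must be a MUSA tensor")
#endif
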
csrc/fast_hadamard_transform_common.h

Lines changed: 1 addition & 2 deletions
@@ -4,8 +4,7 @@
 
 #pragma once
 
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
+#include "vendor.h"
 
 #define FULL_MASK 0xffffffff
 
csrc/fast_hadamard_transform_cuda.cu

Lines changed: 22 additions & 7 deletions
@@ -6,8 +6,13 @@
 
 #include <c10/util/BFloat16.h>
 #include <c10/util/Half.h>
+#ifndef USE_MUSA
 #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
+#else
+#include "torch_musa/csrc/core/MUSAException.h"  // For C10_MUSA_CHECK and C10_MUSA_KERNEL_LAUNCH_CHECK
+#endif
 
+#include "vendor.h"
 #include "fast_hadamard_transform.h"
 #include "fast_hadamard_transform_common.h"
 #include "fast_hadamard_transform_special.h"
@@ -28,7 +33,7 @@ struct fast_hadamard_transform_kernel_traits {
     using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
     static constexpr int kNChunks = N / (kNElts * kNThreads);
     // We don't want to use more than 32 KB of shared memory.
-    static constexpr int kSmemExchangeSize = std::min(N * 4, 32 * 1024);
+    static constexpr int kSmemExchangeSize = MIN(N * 4, 32 * 1024);
     static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
     static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
     static constexpr int kSmemSize = kSmemExchangeSize;
@@ -51,7 +56,7 @@ struct fast_hadamard_transform_12N_kernel_traits {
     static constexpr int kNChunks = N / (kNElts * kNThreads);
     static_assert(kNChunks == 12);
     // We don't want to use more than 24 KB of shared memory.
-    static constexpr int kSmemExchangeSize = std::min(N * 4, 24 * 1024);
+    static constexpr int kSmemExchangeSize = MIN(N * 4, 24 * 1024);
     static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
     static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
     static constexpr int kSmemSize = kSmemExchangeSize;
@@ -74,7 +79,7 @@ struct fast_hadamard_transform_20N_kernel_traits {
     static constexpr int kNChunks = N / (kNElts * kNThreads);
     static_assert(kNChunks == 20);
     // We don't want to use more than 40 KB of shared memory.
-    static constexpr int kSmemExchangeSize = std::min(N * 4, 40 * 1024);
+    static constexpr int kSmemExchangeSize = MIN(N * 4, 40 * 1024);
     static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
     static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
     static constexpr int kSmemSize = kSmemExchangeSize;
@@ -97,7 +102,7 @@ struct fast_hadamard_transform_28N_kernel_traits {
     static constexpr int kNChunks = N / (kNElts * kNThreads);
     static_assert(kNChunks == 28);
     // We don't want to use more than 28 KB of shared memory.
-    static constexpr int kSmemExchangeSize = std::min(N * 4, 28 * 1024);
+    static constexpr int kSmemExchangeSize = MIN(N * 4, 28 * 1024);
     static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
     static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
     static constexpr int kSmemSize = kSmemExchangeSize;
@@ -120,7 +125,7 @@ struct fast_hadamard_transform_40N_kernel_traits {
     static constexpr int kNChunks = N / (kNElts * kNThreads);
     static_assert(kNChunks == 40);
     // We don't want to use more than 40 KB of shared memory.
-    static constexpr int kSmemExchangeSize = std::min(N * 4, 40 * 1024);
+    static constexpr int kSmemExchangeSize = MIN(N * 4, 40 * 1024);
     static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
     static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
     static constexpr int kSmemSize = kSmemExchangeSize;
@@ -163,7 +168,7 @@ void fast_hadamard_transform_kernel(HadamardParamsBase params) {
 
     constexpr int kLogNElts = cilog2(Ktraits::kNElts);
     static_assert(1 << kLogNElts == kNElts, "kNElts must be a power of 2");
-    constexpr int kWarpSize = std::min(kNThreads, 32);
+    constexpr int kWarpSize = MIN(kNThreads, WARP_SIZE);
     constexpr int kLogWarpSize = cilog2(kWarpSize);
     static_assert(1 << kLogWarpSize == kWarpSize, "Warp size must be a power of 2");
     constexpr int kNWarps = kNThreads / kWarpSize;
@@ -234,10 +239,12 @@ void fast_hadamard_transform_launch(HadamardParamsBase &params, cudaStream_t str
     constexpr int kSmemSize = Ktraits::kSmemSize;
     dim3 grid(params.batch);
     auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
+#ifndef USE_ROCM
     if (kSmemSize >= 48 * 1024) {
         C10_CUDA_CHECK(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
     }
+#endif
     kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
@@ -279,10 +286,12 @@ void fast_hadamard_transform_12N_launch(HadamardParamsBase &params, cudaStream_t
     constexpr int kSmemSize = Ktraits::kSmemSize;
     dim3 grid(params.batch);
     auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
+#ifndef USE_ROCM
     if (kSmemSize >= 48 * 1024) {
         C10_CUDA_CHECK(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
     }
+#endif
     kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
@@ -307,7 +316,7 @@ void fast_hadamard_transform_12N_cuda(HadamardParamsBase &params, cudaStream_t s
         fast_hadamard_transform_12N_launch<128, 9, input_t>(params, stream);
     } else if (params.log_N == 10) {
         fast_hadamard_transform_12N_launch<256, 10, input_t>(params, stream);
-    }
+    }
 }
 
 template<int kNThreads, int kLogN, typename input_t>
@@ -316,10 +325,12 @@ void fast_hadamard_transform_20N_launch(HadamardParamsBase &params, cudaStream_t
     constexpr int kSmemSize = Ktraits::kSmemSize;
     dim3 grid(params.batch);
     auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
+#ifndef USE_ROCM
     if (kSmemSize >= 48 * 1024) {
         C10_CUDA_CHECK(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
     }
+#endif
     kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
@@ -353,10 +364,12 @@ void fast_hadamard_transform_28N_launch(HadamardParamsBase &params, cudaStream_t
     constexpr int kSmemSize = Ktraits::kSmemSize;
     dim3 grid(params.batch);
     auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
+#ifndef USE_ROCM
     if (kSmemSize >= 48 * 1024) {
         C10_CUDA_CHECK(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
     }
+#endif
     kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
@@ -390,10 +403,12 @@ void fast_hadamard_transform_40N_launch(HadamardParamsBase &params, cudaStream_t
     constexpr int kSmemSize = Ktraits::kSmemSize;
     dim3 grid(params.batch);
     auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
+#ifndef USE_ROCM
     if (kSmemSize >= 48 * 1024) {
         C10_CUDA_CHECK(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
     }
+#endif
     kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
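The #ifndef USE_ROCM guards above all protect the same pattern: on CUDA (and on MUSA via the aliases in vendor.h below), a kernel that wants more than the default 48 KB of dynamic shared memory must opt in with cudaFuncSetAttribute before launch, while HIP/ROCm has no such opt-in call. A self-contained sketch of that pattern with a toy kernel follows; the kernel, function names, and sizes are assumptions for illustration, not code from this commit.

#include <cuda_runtime.h>

// Toy kernel: stages data through dynamic shared memory.
__global__ void toy_kernel(float *out) {
    extern __shared__ float smem[];
    smem[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    out[threadIdx.x] = smem[threadIdx.x];
}

void launch_toy(float *out, cudaStream_t stream) {
    constexpr int kSmemSize = 64 * 1024;  // above the 48 KB default cap
#ifndef USE_ROCM
    // CUDA requires an explicit opt-in for >48 KB of dynamic shared memory
    // (and a GPU recent enough to support it); ROCm has no equivalent
    // attribute, which is why the commit compiles this call out.
    cudaFuncSetAttribute(toy_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
#endif
    toy_kernel<<<1, 256, kSmemSize, stream>>>(out);
}
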

csrc/vendor.h

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+#pragma once
+
+#if !defined(USE_MUSA) && !defined(USE_ROCM)
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+#define WARP_SIZE 32
+#define MIN(A, B) std::min((A), (B))
+#elif defined(USE_MUSA)
+#include <musa_bf16.h>
+#include <musa_fp16.h>
+
+#define WARP_SIZE 32
+#define MIN(A, B) std::min((A), (B))
+#define C10_CUDA_CHECK C10_MUSA_CHECK
+#define C10_CUDA_KERNEL_LAUNCH_CHECK C10_MUSA_KERNEL_LAUNCH_CHECK
+#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
+#define cudaFuncSetAttribute musaFuncSetAttribute
+#define cudaStream_t musaStream_t
+
+#include "torch_musa/csrc/core/MUSAGuard.h"
+#include "torch_musa/csrc/core/MUSAStream.h"
+namespace at {
+namespace cuda {
+#ifdef USE_MUSA
+using CUDAGuard = at::musa::MUSAGuard;
+inline at::musa::MUSAStream getCurrentCUDAStream() {
+    return at::musa::getCurrentMUSAStream();
+}
+#endif
+} // namespace cuda
+} // namespace at
+#elif defined(USE_ROCM)
+#define WARP_SIZE 64
+#define MIN(A, B) ((A) < (B) ? (A) : (B))
+#define __shfl_xor_sync(MASK, X, OFFSET) __shfl_xor(X, OFFSET)
+#endif // !defined(USE_MUSA) && !defined(USE_ROCM)
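vendor.h is what lets the other sources keep using CUDA-style names on every backend: the MUSA branch aliases the runtime symbols (cudaStream_t, cudaFuncSetAttribute, the C10 check macros) and shims at::cuda::CUDAGuard and at::cuda::getCurrentCUDAStream onto their torch_musa counterparts, while the ROCm branch switches the warp size to 64, defines MIN as a plain ternary instead of std::min, and maps __shfl_xor_sync to HIP's __shfl_xor. Below is a hedged sketch of a consumer of this header, not taken from the commit; the kernel and wrapper names are invented for illustration.

#ifndef USE_MUSA
#include <ATen/cuda/CUDAContext.h>
#endif
#include <torch/extension.h>

#include "vendor.h"

// One warp per block; WARP_SIZE comes from vendor.h (32 on CUDA/MUSA, 64 on ROCm).
__global__ void warp_sum_kernel(const float *in, float *out) {
    float val = in[blockIdx.x * WARP_SIZE + threadIdx.x];
    // On ROCm, vendor.h rewrites __shfl_xor_sync(mask, x, off) to __shfl_xor(x, off).
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xffffffff, val, offset);
    }
    if (threadIdx.x == 0) { out[blockIdx.x] = val; }
}

void warp_sum(const at::Tensor &in, at::Tensor &out) {
    // On USE_MUSA this call resolves to the torch_musa stream through the
    // namespace shim in vendor.h; on CUDA it is the ordinary ATen call.
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    warp_sum_kernel<<<static_cast<unsigned int>(out.numel()), WARP_SIZE, 0, stream>>>(
        in.data_ptr<float>(), out.data_ptr<float>());
}
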
