Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Python packaging / build artifacts
__pycache__/
build/
dist/
*.egg-info/

# hipified files
*.hip
*_hip.*
2 changes: 2 additions & 0 deletions benchmarks/benchmark_fast_hadamard_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
dim = 16384 * 2
dtype = torch.float16
device = "cuda"
if hasattr(torch, "musa"):
device = "musa"

torch.random.manual_seed(0)
x = torch.randn(batch_size, seqlen, dim, dtype=dtype, device=device)
Expand Down
26 changes: 26 additions & 0 deletions csrc/fast_hadamard_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
* Copyright (c) 2023, Tri Dao.
******************************************************************************/

#ifndef USE_MUSA
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#else
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
#include "torch_musa/csrc/core/MUSAGuard.h"
#endif
#include <torch/extension.h>
#include <vector>

#include "vendor.h"
#include "fast_hadamard_transform.h"

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
Expand Down Expand Up @@ -74,7 +80,11 @@ fast_hadamard_transform(at::Tensor &x, float scale) {
auto input_type = x.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

#ifndef USE_MUSA
TORCH_CHECK(x.is_cuda());
#else
TORCH_CHECK(x.is_privateuseone());
#endif

const auto shapes_og = x.sizes();
const int dim_og = x.size(-1);
Expand Down Expand Up @@ -117,7 +127,11 @@ fast_hadamard_transform_12N(at::Tensor &x, float scale) {
auto input_type = x.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

#ifndef USE_MUSA
TORCH_CHECK(x.is_cuda());
#else
TORCH_CHECK(x.is_privateuseone());
#endif

const auto shapes_og = x.sizes();
const int dim_og = x.size(-1);
Expand Down Expand Up @@ -160,7 +174,11 @@ fast_hadamard_transform_20N(at::Tensor &x, float scale) {
auto input_type = x.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

#ifndef USE_MUSA
TORCH_CHECK(x.is_cuda());
#else
TORCH_CHECK(x.is_privateuseone());
#endif

const auto shapes_og = x.sizes();
const int dim_og = x.size(-1);
Expand Down Expand Up @@ -203,7 +221,11 @@ fast_hadamard_transform_28N(at::Tensor &x, float scale) {
auto input_type = x.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

#ifndef USE_MUSA
TORCH_CHECK(x.is_cuda());
#else
TORCH_CHECK(x.is_privateuseone());
#endif

const auto shapes_og = x.sizes();
const int dim_og = x.size(-1);
Expand Down Expand Up @@ -246,7 +268,11 @@ fast_hadamard_transform_40N(at::Tensor &x, float scale) {
auto input_type = x.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

#ifndef USE_MUSA
TORCH_CHECK(x.is_cuda());
#else
TORCH_CHECK(x.is_privateuseone());
#endif

const auto shapes_og = x.sizes();
const int dim_og = x.size(-1);
Expand Down
3 changes: 1 addition & 2 deletions csrc/fast_hadamard_transform_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "vendor.h"

#define FULL_MASK 0xffffffff

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#ifndef USE_MUSA
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#else
#include "torch_musa/csrc/core/MUSAException.h" // For C10_MUSA_CHECK and C10_MUSA_KERNEL_LAUNCH_CHECK
#endif

#include "vendor.h"
#include "fast_hadamard_transform.h"
#include "fast_hadamard_transform_common.h"
#include "fast_hadamard_transform_special.h"
Expand All @@ -28,7 +33,7 @@ struct fast_hadamard_transform_kernel_traits {
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
static constexpr int kNChunks = N / (kNElts * kNThreads);
// We don't want to use more than 32 KB of shared memory.
static constexpr int kSmemExchangeSize = std::min(N * 4, 32 * 1024);
static constexpr int kSmemExchangeSize = MIN(N * 4, 32 * 1024);
static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
static constexpr int kSmemSize = kSmemExchangeSize;
Expand All @@ -51,7 +56,7 @@ struct fast_hadamard_transform_12N_kernel_traits {
static constexpr int kNChunks = N / (kNElts * kNThreads);
static_assert(kNChunks == 12);
// We don't want to use more than 24 KB of shared memory.
static constexpr int kSmemExchangeSize = std::min(N * 4, 24 * 1024);
static constexpr int kSmemExchangeSize = MIN(N * 4, 24 * 1024);
static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
static constexpr int kSmemSize = kSmemExchangeSize;
Expand All @@ -74,7 +79,7 @@ struct fast_hadamard_transform_20N_kernel_traits {
static constexpr int kNChunks = N / (kNElts * kNThreads);
static_assert(kNChunks == 20);
// We don't want to use more than 40 KB of shared memory.
static constexpr int kSmemExchangeSize = std::min(N * 4, 40 * 1024);
static constexpr int kSmemExchangeSize = MIN(N * 4, 40 * 1024);
static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
static constexpr int kSmemSize = kSmemExchangeSize;
Expand All @@ -97,7 +102,7 @@ struct fast_hadamard_transform_28N_kernel_traits {
static constexpr int kNChunks = N / (kNElts * kNThreads);
static_assert(kNChunks == 28);
// We don't want to use more than 28 KB of shared memory.
static constexpr int kSmemExchangeSize = std::min(N * 4, 28 * 1024);
static constexpr int kSmemExchangeSize = MIN(N * 4, 28 * 1024);
static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
static constexpr int kSmemSize = kSmemExchangeSize;
Expand All @@ -120,7 +125,7 @@ struct fast_hadamard_transform_40N_kernel_traits {
static constexpr int kNChunks = N / (kNElts * kNThreads);
static_assert(kNChunks == 40);
// We don't want to use more than 40 KB of shared memory.
static constexpr int kSmemExchangeSize = std::min(N * 4, 40 * 1024);
static constexpr int kSmemExchangeSize = MIN(N * 4, 40 * 1024);
static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
static constexpr int kSmemSize = kSmemExchangeSize;
Expand Down Expand Up @@ -163,7 +168,7 @@ void fast_hadamard_transform_kernel(HadamardParamsBase params) {

constexpr int kLogNElts = cilog2(Ktraits::kNElts);
static_assert(1 << kLogNElts == kNElts, "kNElts must be a power of 2");
constexpr int kWarpSize = std::min(kNThreads, 32);
constexpr int kWarpSize = MIN(kNThreads, WARP_SIZE);
constexpr int kLogWarpSize = cilog2(kWarpSize);
static_assert(1 << kLogWarpSize == kWarpSize, "Warp size must be a power of 2");
constexpr int kNWarps = kNThreads / kWarpSize;
Expand Down Expand Up @@ -234,10 +239,12 @@ void fast_hadamard_transform_launch(HadamardParamsBase &params, cudaStream_t str
constexpr int kSmemSize = Ktraits::kSmemSize;
dim3 grid(params.batch);
auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
#ifndef USE_ROCM
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
#endif
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
Expand Down Expand Up @@ -279,10 +286,12 @@ void fast_hadamard_transform_12N_launch(HadamardParamsBase &params, cudaStream_t
constexpr int kSmemSize = Ktraits::kSmemSize;
dim3 grid(params.batch);
auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
#ifndef USE_ROCM
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
#endif
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
Expand All @@ -307,7 +316,7 @@ void fast_hadamard_transform_12N_cuda(HadamardParamsBase &params, cudaStream_t s
fast_hadamard_transform_12N_launch<128, 9, input_t>(params, stream);
} else if (params.log_N == 10) {
fast_hadamard_transform_12N_launch<256, 10, input_t>(params, stream);
}
}
}

template<int kNThreads, int kLogN, typename input_t>
Expand All @@ -316,10 +325,12 @@ void fast_hadamard_transform_20N_launch(HadamardParamsBase &params, cudaStream_t
constexpr int kSmemSize = Ktraits::kSmemSize;
dim3 grid(params.batch);
auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
#ifndef USE_ROCM
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
#endif
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
Expand Down Expand Up @@ -353,10 +364,12 @@ void fast_hadamard_transform_28N_launch(HadamardParamsBase &params, cudaStream_t
constexpr int kSmemSize = Ktraits::kSmemSize;
dim3 grid(params.batch);
auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
#ifndef USE_ROCM
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
#endif
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
Expand Down Expand Up @@ -390,10 +403,12 @@ void fast_hadamard_transform_40N_launch(HadamardParamsBase &params, cudaStream_t
constexpr int kSmemSize = Ktraits::kSmemSize;
dim3 grid(params.batch);
auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
#ifndef USE_ROCM
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
#endif
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
Expand Down
37 changes: 37 additions & 0 deletions csrc/vendor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once

// Vendor abstraction layer.
//
// The rest of the sources are written against CUDA-named APIs; this header
// selects exactly one backend at compile time and maps those names onto it:
//   - (default)  native CUDA
//   - USE_MUSA   Moore Threads MUSA (via torch_musa)
//   - USE_ROCM   AMD ROCm/HIP (sources are hipified; see .gitignore)
// USE_MUSA and USE_ROCM are assumed to be mutually exclusive.

#if !defined(USE_MUSA) && !defined(USE_ROCM)
// Native CUDA build: half/bf16 device types come from the CUDA toolkit.
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#define WARP_SIZE 32
#define MIN(A, B) std::min((A), (B))
#elif defined(USE_MUSA)
// MUSA build: half/bf16 types come from the MUSA headers, and the CUDA-named
// runtime symbols used by the launch code are aliased to MUSA equivalents.
#include <musa_bf16.h>
#include <musa_fp16.h>

#define WARP_SIZE 32
#define MIN(A, B) std::min((A), (B))
#define C10_CUDA_CHECK C10_MUSA_CHECK
#define C10_CUDA_KERNEL_LAUNCH_CHECK C10_MUSA_KERNEL_LAUNCH_CHECK
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
#define cudaFuncSetAttribute musaFuncSetAttribute
#define cudaStream_t musaStream_t

#include "torch_musa/csrc/core/MUSAGuard.h"
#include "torch_musa/csrc/core/MUSAStream.h"
namespace at {
namespace cuda {
// Expose the MUSA guard/stream under the at::cuda names the host code uses.
// (No inner #ifdef needed here: this whole branch is only compiled when
// USE_MUSA is defined.)
using CUDAGuard = at::musa::MUSAGuard;
inline at::musa::MUSAStream getCurrentCUDAStream() {
    return at::musa::getCurrentMUSAStream();
}
} // namespace cuda
} // namespace at
#elif defined(USE_ROCM)
// ROCm build: the targeted AMD GPUs use 64-lane wavefronts, so warp-level
// constants and shuffle widths differ from CUDA.
#define WARP_SIZE 64
#define MIN(A, B) ((A) < (B) ? (A) : (B))
// HIP has no __shfl_xor_sync; map it to __shfl_xor, dropping the mask
// argument (all lanes of the wavefront participate).
#define __shfl_xor_sync(MASK, X, OFFSET) __shfl_xor(X, OFFSET)
#endif // !defined(USE_MUSA) && !defined(USE_ROCM)
Loading