268 changes: 129 additions & 139 deletions pufferlib/extensions/cuda/kernels.cu
@@ -17,142 +17,122 @@ inline int seq_size(int N) {
return (N + SEQ_SIZE - 1) / SEQ_SIZE;
}

// If you can get this to work, go ahead. I tried.
// NVCC won't parse templated types in kernel launches
/*
template <template <class> class KernelFn, typename... Args>
void dispatch_and_launch(const at::Tensor& example_tensor, Args... args) {
const int64_t N = example_tensor.numel();
const int64_t block = LAUNCH_BLOCK_SIZE;
const int64_t grid = (N + block - 1) / block;
auto stream = at::cuda::getCurrentCUDAStream();
at::cuda::CUDAGuard device_guard(example_tensor.device());

at::ScalarType dtype = example_tensor.scalar_type();
if (dtype == at::ScalarType::Float) {
KernelFn<float><<<grid, block, 0, stream>>>(args..., N);
} else if (dtype == at::ScalarType::Half) {
KernelFn<__half><<<grid, block, 0, stream>>>(args..., N);
} else if (dtype == at::ScalarType::BFloat16) {
KernelFn<__nv_bfloat16><<<grid, block, 0, stream>>>(args..., N);
} else {
AT_ERROR("Unsupported dtype: ", dtype);
// ===== RMSNorm (B, T, H) =====
// Fused forward/backward with one reduction per (B*T) row.
// Accumulate in float for numerical stability; supports float32 and bfloat16.

template<int BLOCK>
__device__ __forceinline__ float block_sum(float v) {
__shared__ float smem[BLOCK];
smem[threadIdx.x] = v;
__syncthreads();
for (int offset = BLOCK / 2; offset > 0; offset >>= 1) {
if (threadIdx.x < offset) {
smem[threadIdx.x] += smem[threadIdx.x + offset];
}
__syncthreads();
}
return smem[0];
}
*/

template<typename T>
template<typename T, int BLOCK>
__global__ void rmsnorm_forward_kernel(
T* __restrict__ out,
float* __restrict__ inv_norm_buf,
float* __restrict__ inv_norm_buf, // shape [B*T]
const T* __restrict__ x,
const T* __restrict__ weight,
double eps,
const T* __restrict__ weight, // shape [H]
float eps,
int T_total,
int H,
int B
int H
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= B * T_total) return;

int b = idx / T_total;
int t = idx % T_total;
int base = b*T_total*H + t*H;
(void)T_total;
int row = blockIdx.x; // row = b * T + t
int base = row * H;

float sum_sq = 0.0f;
for (int h = 0; h < H; h++) {
int curr = base + h;
float x_val = float(x[curr]);
sum_sq += x_val * x_val;
// Accumulate sum of squares for this row
float sum = 0.0f;
for (int h = threadIdx.x; h < H; h += BLOCK) {
float v = float(x[base + h]);
sum += v * v;
}
sum = block_sum<BLOCK>(sum);

float rms = sqrtf(sum_sq/H + eps);
float inv_rms = 1.0f / rms;
inv_norm_buf[idx] = inv_rms;
float inv_rms = rsqrtf(sum / H + eps);
if (threadIdx.x == 0) {
inv_norm_buf[row] = inv_rms;
}
__syncthreads();

for (int h = 0; h < H; h++) {
int curr = base + h;
out[curr] = T(weight[h] * x[curr] * inv_rms);
// Write normalized output
for (int h = threadIdx.x; h < H; h += BLOCK) {
float v = float(x[base + h]);
float w = float(weight[h]);
out[base + h] = T(v * w * inv_rms);
}
}
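
For reference, the rewritten forward kernel assigns one thread block per (b, t) row, accumulates the sum of squares in float, and caches the inverse RMS for the backward pass:

\[
r_{b,t} = \Big(\tfrac{1}{H}\sum_{h} x_{b,t,h}^{2} + \epsilon\Big)^{-1/2},
\qquad
y_{b,t,h} = \gamma_{h}\, x_{b,t,h}\, r_{b,t},
\]

where \(\gamma\) is weight and \(r_{b,t}\) is the value stored in inv_norm_buf[row].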

template<typename T>
__global__ void rmsnorm_backward_kernel(
template<typename T, int BLOCK>
__global__ void rmsnorm_backward_input_kernel(
T* __restrict__ grad_x,
T* __restrict__ grad_weight,
const T* __restrict__ grad_out,
const float* __restrict__ inv_norm_buf,
const T* __restrict__ x_buf,
const float* __restrict__ inv_norm_buf, // shape [B*T]
const T* __restrict__ x,
const T* __restrict__ weight,
double eps,
int T_total,
int H,
int B
int H
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= T_total*H*B) return;
int base = idx % H;
int norm_idx = idx / H;

float inv_rms = inv_norm_buf[norm_idx];
float inv_rms_3 = inv_rms * inv_rms * inv_rms;

grad_x[idx] = weight[base] * grad_out[idx] * inv_rms;
grad_weight[idx] = grad_out[idx] * inv_rms;

float wg_x = 0.0f;
for (int h=0; h<H; h++) {
float x = x_buf[base + h];
float w = weight[h];
float g = grad_out[base + h];
wg_x += w*g*x;
(void)T_total;
int row = blockIdx.x;
int base = row * H;
float inv_rms = inv_norm_buf[row];

// dot = sum_h grad_out * weight * x
float dot = 0.0f;
for (int h = threadIdx.x; h < H; h += BLOCK) {
float go = float(grad_out[base + h]);
float w = float(weight[h]);
float xv = float(x[base + h]);
dot += go * w * xv;
}
dot = block_sum<BLOCK>(dot);
float coeff = dot * inv_rms * inv_rms * inv_rms / H; // inv_rms^3 / H

// grad_x
for (int h = threadIdx.x; h < H; h += BLOCK) {
float go = float(grad_out[base + h]);
float w = float(weight[h]);
float xv = float(x[base + h]);
float gx = w * go * inv_rms - xv * coeff;
grad_x[base + h] = T(gx);
}
float x = x_buf[idx];
grad_x[idx] -= x*wg_x*inv_rms_3/float(H);
}
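
The input-gradient kernel implements the standard RMSNorm backward using the cached \(r_{b,t}\). Writing \(g\) for grad_out, each row computes

\[
\frac{\partial L}{\partial x_{b,t,h}}
= \gamma_{h}\, g_{b,t,h}\, r_{b,t}
- \frac{x_{b,t,h}\, r_{b,t}^{3}}{H} \sum_{i} \gamma_{i}\, g_{b,t,i}\, x_{b,t,i},
\]

with the sum over \(i\) reduced once per block via block_sum (the coeff term in the code).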

/*
template<typename T>
__global__ void rmsnorm_backward_kernel(
T* grad_x,
T* grad_weight,
const T* grad_out,
const float* inv_norm_buf,
const T* x,
const T* weight,
double eps,
int T_total,
int H,
int B
__global__ void rmsnorm_backward_weight_kernel(
T* __restrict__ grad_weight, // shape [H]
const T* __restrict__ grad_out,
const float* __restrict__ inv_norm, // shape [B*T]
const T* __restrict__ x,
int rows, // B * T
int H
) {
int total_elements = B * T_total * H;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total_elements) return;

int h = idx % H;
int vec_idx = idx / H; // index of the vector (b,t)
int offset = vec_idx * H;

float inv_rms = inv_norm_buf[vec_idx];
float inv_rms3 = inv_rms * inv_rms * inv_rms;
int h = blockIdx.x * blockDim.x + threadIdx.x;
if (h >= H) return;

// ∂L/∂γ_h += grad_out * (x / rms)
float gw = grad_out[idx] * (float)x[idx] * inv_rms;
atomicAdd((float*)&grad_weight[h], gw);

// Compute reduction: sum_h weight[h] * grad_out[h] * x[h]
float sum = 0.0f;
for (int i = 0; i < H; ++i) {
sum += (float)weight[i] * (float)grad_out[offset + i] * (float)x[offset + i];
float acc = 0.0f;
for (int row = 0; row < rows; row++) {
int idx = row * H + h;
acc += float(grad_out[idx]) * float(x[idx]) * inv_norm[row];
}
float reduction = sum * inv_rms; // = Σ_h γ_h · g_h · x̂_h

float dx = (float)weight[h] * (float)grad_out[idx] * inv_rms
- (float)x[idx] * reduction * inv_rms3 / H;
grad_weight[h] = T(acc);
}
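
The weight gradient is a reduction over rows rather than over H, so the kernel above uses one thread per h and loops over all B·T rows:

\[
\frac{\partial L}{\partial \gamma_{h}} = \sum_{b,t} g_{b,t,h}\, x_{b,t,h}\, r_{b,t}.
\]

This avoids the float atomicAdd of the earlier version at the cost of a serial loop over rows in each thread.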

grad_x[idx] = T(dx);
// Heuristic: smaller blocks for small H to avoid wasted threads
inline int rms_block_size(int H) {
if (H <= 64) return 64;
if (H <= 128) return 128;
return 256;
}
*/

template<typename T>
void launch_rmsnorm_forward(
@@ -165,23 +145,26 @@ void launch_rmsnorm_forward(
int H,
int B
) {
int total = B * T_total;
int grid = grid_size(total);

rmsnorm_forward_kernel<T><<<grid, BLOCK_SIZE>>>(
out,
inv_norm_buf,
x,
weight,
eps,
T_total,
H,
B
);
int rows = B * T_total;
int block = rms_block_size(H);

switch (block) {
case 64:
rmsnorm_forward_kernel<T, 64><<<rows, 64>>>(
out, inv_norm_buf, x, weight, static_cast<float>(eps), T_total, H);
break;
case 128:
rmsnorm_forward_kernel<T, 128><<<rows, 128>>>(
out, inv_norm_buf, x, weight, static_cast<float>(eps), T_total, H);
break;
default:
rmsnorm_forward_kernel<T, 256><<<rows, 256>>>(
out, inv_norm_buf, x, weight, static_cast<float>(eps), T_total, H);
}

cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
fprintf(stderr, "CUDA kernel launch error in forward: %s\n", cudaGetErrorString(err));
fprintf(stderr, "CUDA kernel launch error in RMSNorm forward: %s\n", cudaGetErrorString(err));
}
}

@@ -193,32 +176,39 @@ void launch_rmsnorm_backward(
const float* __restrict__ inv_norm_buf,
const T* __restrict__ x_buf,
const T* __restrict__ weight,
double eps,
double eps, // unused but kept for API parity
int T_total,
int H,
int B
) {
// The backward is fully parallel
// since the inv norm is cached
int total = B * T_total * H;
int grid = grid_size(total);
(void)eps;
int rows = B * T_total;
int block = rms_block_size(H);

// Grad w.r.t. x
switch (block) {
case 64:
rmsnorm_backward_input_kernel<T, 64><<<rows, 64>>>(
grad_x, grad_out, inv_norm_buf, x_buf, weight, T_total, H);
break;
case 128:
rmsnorm_backward_input_kernel<T, 128><<<rows, 128>>>(
grad_x, grad_out, inv_norm_buf, x_buf, weight, T_total, H);
break;
default:
rmsnorm_backward_input_kernel<T, 256><<<rows, 256>>>(
grad_x, grad_out, inv_norm_buf, x_buf, weight, T_total, H);
}

rmsnorm_backward_kernel<T><<<grid, BLOCK_SIZE>>>(
grad_x,
grad_weight,
grad_out,
inv_norm_buf,
x_buf,
weight,
eps,
T_total,
H,
B
);
// Grad w.r.t. weight (one reduction over rows for each h)
int threads = 256;
int blocks = (H + threads - 1) / threads;
rmsnorm_backward_weight_kernel<T><<<blocks, threads>>>(
grad_weight, grad_out, inv_norm_buf, x_buf, rows, H);

cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
fprintf(stderr, "CUDA kernel launch error in backward: %s\n", cudaGetErrorString(err));
fprintf(stderr, "CUDA kernel launch error in RMSNorm backward: %s\n", cudaGetErrorString(err));
}
}
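
The forward and backward launchers repeat the same switch over block sizes. Below is a minimal sketch of one way to factor that out, assuming the extension builds with C++14 or later (generic lambdas); dispatch_rms_block and the lambda are illustrative and not part of the existing code:

#include <type_traits>

template <typename F>
inline void dispatch_rms_block(int H, F&& f) {
    // Map the runtime block size from rms_block_size(H) back to a
    // compile-time constant so the kernel template can be instantiated.
    switch (rms_block_size(H)) {
        case 64:  f(std::integral_constant<int, 64>{});  break;
        case 128: f(std::integral_constant<int, 128>{}); break;
        default:  f(std::integral_constant<int, 256>{}); break;
    }
}

// Possible use inside launch_rmsnorm_forward:
// dispatch_rms_block(H, [&](auto block) {
//     constexpr int BLOCK = decltype(block)::value;
//     rmsnorm_forward_kernel<T, BLOCK><<<rows, BLOCK>>>(
//         out, inv_norm_buf, x, weight, static_cast<float>(eps), T_total, H);
// });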

22 changes: 9 additions & 13 deletions pufferlib/extensions/cuda/modules.cu
@@ -142,7 +142,6 @@ torch::autograd::tensor_list log_coeffs_and_values(
return LogCoeffsAndValuesFunction::apply(gate, hidden);
}

/*
class RMSNormFunction: public torch::autograd::Function<RMSNormFunction> {
public:
static torch::autograd::tensor_list forward(
@@ -158,6 +157,10 @@ public:
TORCH_CHECK(weight.dim() == 1, "weight must be (H,)");
TORCH_CHECK(x.size(2) == weight.size(0), "H must match");

// Ensure contiguous for flat indexing
x = x.contiguous();
weight = weight.contiguous();

auto dtype = x.dtype();
auto device = x.device();
auto B = x.size(0);
@@ -167,7 +170,7 @@ public:
auto out = torch::empty({B, T, H}, x.options());

auto options_float = torch::TensorOptions().dtype(torch::kFloat32).device(device);
auto inv_norm = torch::empty({B, T}, options_float);
auto inv_norm = torch::empty({B * T}, options_float);

if (dtype == torch::kFloat32) {
launch_rmsnorm_forward<float>(
@@ -195,13 +198,8 @@ public:
TORCH_CHECK(false, "Unsupported dtype. Only float32 and bfloat16 supported.");
}

// TODO: don't save eps as a tensor
//ctx->saved_data["eps"] = eps; // store in saved_data instead

// Save for backward
auto eps_tensor = torch::tensor(eps);
ctx->save_for_backward({x, weight, out, inv_norm, eps_tensor});

ctx->saved_data["eps"] = eps;
ctx->save_for_backward({x, weight, inv_norm});
return {out};
}
static torch::autograd::tensor_list backward(
@@ -211,9 +209,8 @@ public:
auto saved = ctx->get_saved_variables();
auto x = saved[0].contiguous();
auto weight = saved[1].contiguous();
auto out = saved[2].contiguous();
auto inv_norm = saved[3].contiguous();
double eps = saved[4].item<double>();
auto inv_norm = saved[2].contiguous();
double eps = ctx->saved_data["eps"].to<double>();

auto grad_out = grad_outputs[0].contiguous();
auto dtype = x.dtype();
@@ -266,7 +263,6 @@ torch::autograd::tensor_list rmsnorm(
) {
return RMSNormFunction::apply(x, weight, eps);
}
*/
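
As a quick host-side sanity check of the now-enabled autograd function, something like the following could compare the fused op against a reference built from ATen ops. This is a sketch: the shapes, eps, and tolerances are arbitrary, and it assumes the rmsnorm wrapper keeps the (x, weight, eps) signature shown above.

// Hypothetical smoke test (illustrative values).
auto x = torch::randn({2, 8, 256},
    torch::dtype(torch::kFloat32).device(torch::kCUDA)).requires_grad_(true);
auto w = torch::ones({256}, x.options()).requires_grad_(true);
double eps = 1e-6;

auto y = rmsnorm(x, w, eps)[0];
// Reference: y_ref = x * rsqrt(mean(x^2, dim=-1) + eps) * w
auto y_ref = x * torch::rsqrt(x.pow(2).mean({-1}, /*keepdim=*/true) + eps) * w;
TORCH_CHECK(torch::allclose(y, y_ref, /*rtol=*/1e-4, /*atol=*/1e-5));

y.sum().backward();  // exercises both backward kernels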

/*
class RMSNormImpl : public torch::nn::Module {