
Commit 15a2256

Commit message: cleanups
1 parent ef4ef1a

19 files changed: +42 additions, −266 deletions

setup.py

Lines changed: 5 additions & 0 deletions
@@ -119,6 +119,7 @@ def get_ext_modules():
             [
                 "torch_harmonics/disco/csrc/disco_helpers.cpp",
             ],
+            include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
             extra_compile_args=get_helpers_compile_args(),
         )
     )
@@ -130,6 +131,7 @@ def get_ext_modules():
             [
                 "torch_harmonics/attention/csrc/attention_helpers.cpp",
             ],
+            include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
             extra_compile_args=get_helpers_compile_args(),
         )
     )
@@ -189,6 +191,7 @@ def get_ext_modules():
         CppExtension(
             "torch_harmonics.disco._C",
             disco_sources,
+            include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
             extra_compile_args=get_compile_args("disco")
         )
     )
@@ -212,6 +215,7 @@ def get_ext_modules():
         CUDAExtension(
            "torch_harmonics.attention._C",
            attention_sources,
+           include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
           extra_compile_args=get_compile_args("attention")
        )
     )
@@ -220,6 +224,7 @@ def get_ext_modules():
         CppExtension(
            "torch_harmonics.attention._C",
            attention_sources,
+           include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
           extra_compile_args=get_compile_args("attention")
        )
     )

torch_harmonics/attention/csrc/attention.h

Lines changed: 1 addition & 6 deletions
@@ -36,9 +36,4 @@
 #include <torch/library.h>
 #include <cassert>
 
-#define CHECK_CPU_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCPU)
-#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT_TENSOR(x) CHECK_CONTIGUOUS_TENSOR(x)
-#define CHECK_CPU_INPUT_TENSOR(x) \
-    CHECK_CPU_TENSOR(x); \
-    CHECK_CONTIGUOUS_TENSOR(x)
+#include "cppmacro.h"
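
Note: the new cppmacro.h itself is not shown in this commit. Given the include_dirs entries added to setup.py, it presumably lives under torch_harmonics/utils/csrc/. A minimal sketch of what it would contain, assembled from the CPU check macros removed here and from disco.h (assumed location and contents, not part of this diff):

// torch_harmonics/utils/csrc/cppmacro.h -- assumed; reconstructed from the deleted macros
#pragma once

#include <torch/all.h>

// consolidated CPU-side tensor checks, matching the macros removed from attention.h and disco.h
#define CHECK_CPU_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCPU)
#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT_TENSOR(x) CHECK_CONTIGUOUS_TENSOR(x)
#define CHECK_CPU_INPUT_TENSOR(x) \
    CHECK_CPU_TENSOR(x);          \
    CHECK_CONTIGUOUS_TENSOR(x)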

torch_harmonics/attention/csrc/attention_cpu.h

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@
 #include <array>
 #include <vector>
 
+#include "cppmacro.h"
+
 #define CACHE_BLOCK_SIZE (64)
 
 namespace attention_kernels {

torch_harmonics/attention/csrc/attention_cuda.cuh

Lines changed: 2 additions & 5 deletions
@@ -34,11 +34,8 @@
 #include <cstdint>
 #include <torch/torch.h>
 
-#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA)
-#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_CUDA_INPUT_TENSOR(x) \
-    CHECK_CUDA_TENSOR(x); \
-    CHECK_CONTIGUOUS_TENSOR(x)
+#include "cudamacro.h"
+
 
 namespace attention_kernels {
 
torch_harmonics/attention/csrc/attention_cuda_bwd.cu

Lines changed: 2 additions & 147 deletions
@@ -41,7 +41,7 @@
 #include <cub/cub.cuh>
 #include <limits>
 
-#include "cudamacro.h"
+//#include "cudamacro.h"
 #include "attention_cuda_utils.cuh"
 
 #include <iostream>
@@ -52,153 +52,8 @@
 
 #define MAX_LOCAL_ARR_LEN (16)
 
-namespace attention_kernels {
-
-#if 0
-class ScopeTimer
-{
-  public:
-    explicit ScopeTimer(const std::string &label = "") :
-        label_(label), start_(std::chrono::high_resolution_clock::now())
-    {
-    }
-
-    ~ScopeTimer()
-    {
-        auto end = std::chrono::high_resolution_clock::now();
-        auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
-        std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
-    }
-
-  private:
-    std::string label_;
-    std::chrono::high_resolution_clock::time_point start_;
-};
-
-// easier to understand version of manual shfl_xor_sync, performance appears similar
-static __device__ float __warp_sum_cub(float val)
-{
-    // use cub to reduce within a warp
-    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
-
-    // 1. Compute sum (initially only in lane 0)
-    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
-    // 2. Broadcast sum to all threads
-    sum = __shfl_sync(0xFFFFFFFF, sum, 0);
-    return sum;
-}
-
-// This kernel computes the backward pass for the S2 attention mechanism, using
-// shared memory as a cache and one warp per output point, warp-parallel over
-// channels, which should be layed out in the fastest dimension for coalesced
-// memory access.
-template <int BDIM_X>
-__global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
-    int num_channels, int nlon_in, int nlat_out, int nlon_out,
-    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
-    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
-    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
-    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
-    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
-    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
-    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
-    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
-    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
-    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
-{
 
-    extern __shared__ float sh[];
-    float *sh_alpha_k = sh + threadIdx.y * num_channels * 5;
-    float *sh_alpha_vw = sh_alpha_k + num_channels;
-    float *sh_alpha_kvw = sh_alpha_vw + num_channels;
-    float *sh_dy = sh_alpha_kvw + num_channels;
-    float *sh_qy = sh_dy + num_channels;
-    // (optionally, could use more shared memory for other intermediates)
-
-    const uint64_t batchId = blockIdx.y;
-    const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
-    if (wid >= uint64_t(nlat_out) * nlon_in) return;
-    const int tidx = threadIdx.x;
-    const int ho = wid / nlon_out;
-    const int wo = wid - (ho * nlon_out);
-
-    // Zero shared memory
-    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
-        sh_alpha_k[chan] = 0.0f;
-        sh_alpha_vw[chan] = 0.0f;
-        sh_alpha_kvw[chan] = 0.0f;
-        sh_dy[chan] = dy[batchId][chan][ho][wo];
-        sh_qy[chan] = qy[batchId][chan][ho][wo];
-    }
-    float alpha_sum = 0.0f;
-    float qdotk_max = -FLT_MAX;
-    float integral = 0.0f;
-    __syncthreads();
-
-    const int64_t rbeg = psi_row_offset[ho];
-    const int64_t rend = psi_row_offset[ho + 1];
-    const int rlen = rend - rbeg;
-
-    // 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max.
-    for (int off = 0; off < rlen; off++) {
-        const int64_t col = psi_col_idx[rbeg + off];
-        const int hi = col / nlon_in;
-        const int wi = col - (hi * nlon_in);
-        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
-        float qdotk = 0.0f, gdotv = 0.0f;
-        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
-            qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
-            gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
-        }
-        qdotk = __warp_sum_cub(qdotk);
-        gdotv = __warp_sum_cub(gdotv);
-        float qdotk_max_tmp = max(qdotk_max, qdotk);
-        float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
-        float max_correction = expf(qdotk_max - qdotk_max_tmp);
-        alpha_sum = alpha_sum * max_correction + alpha_inz;
-        integral = integral * max_correction + alpha_inz * gdotv;
-        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
-            float kxval = kx[batchId][chan][hi][wip];
-            sh_alpha_k[chan] = sh_alpha_k[chan] * max_correction + alpha_inz * kxval;
-            sh_alpha_vw[chan] = sh_alpha_vw[chan] * max_correction + alpha_inz * gdotv;
-            sh_alpha_kvw[chan] = sh_alpha_kvw[chan] * max_correction + alpha_inz * kxval * gdotv;
-        }
-        qdotk_max = qdotk_max_tmp;
-    }
-
-    integral /= alpha_sum;
-
-    // Write dydq
-    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
-        dydq[batchId][chan][ho][wo]
-            = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
-    }
-
-    // Third pass: accumulate gradients for k and v
-    for (int off = 0; off < rlen; off++) {
-        const int64_t col = psi_col_idx[rbeg + off];
-        const int hi = col / nlon_in;
-        const int wi = col - (hi * nlon_in);
-        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
-        float qdotk = 0.0f, gdotv = 0.0f;
-        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
-            qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
-            gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
-        }
-        qdotk = __warp_sum_cub(qdotk);
-        gdotv = __warp_sum_cub(gdotv);
-        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
-        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
-            float qyval = qy[batchId][chan][ho][wo];
-            float dyval = sh_dy[chan];
-            atomicAdd(&dydk[batchId][chan][hi][wip], qyval * (alpha_inz / alpha_sum) * (gdotv - integral));
-            atomicAdd(&dydv[batchId][chan][hi][wip], (alpha_inz / alpha_sum) * dyval);
-        }
-    }
-}
-#endif
-
-// BEGIN backward kernels and functions
+namespace attention_kernels {
 
 // called with (blockDim.x=32 and blockDim.y>1, BDIM=blockDim.x*blockDim.y)
 template<int BDIM_X,

torch_harmonics/attention/csrc/attention_cuda_fwd.cu

Lines changed: 1 addition & 2 deletions
@@ -39,14 +39,13 @@
 #include <cub/cub.cuh>
 #include <limits>
 
-#include "cudamacro.h"
+//#include "cudamacro.h"
 #include "attention_cuda_utils.cuh"
 
 #define THREADS (64)
 
 #define MAX_LOCAL_ARR_LEN (16)
 
-// BEGIN - forward kernels and functions
 
 namespace attention_kernels {
 
torch_harmonics/attention/csrc/attention_cuda_utils.cu

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@
 #include <cub/cub.cuh>
 #include <limits>
 
-#include "cudamacro.h"
+//#include "cudamacro.h"
 #include "attention_cuda.cuh"
 
 #define THREADS (64)

torch_harmonics/attention/csrc/attention_cuda_utils.cuh

Lines changed: 2 additions & 2 deletions
@@ -34,9 +34,9 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/CUDAUtils.h>
 
+#include "cudamacro.h"
+
 #define WARP_SIZE (32)
-#define FULL_MASK (0xFFFFFFFF)
-#define DIV_UP(a,b) (((a)+((b)-1))/(b))
 
 namespace attention_kernels {
 
torch_harmonics/disco/csrc/cudamacro.h

Lines changed: 0 additions & 47 deletions
This file was deleted.
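
Note: the deleted disco copy of cudamacro.h is presumably superseded by a shared header under torch_harmonics/utils/csrc/, now pulled in by attention_cuda.cuh and attention_cuda_utils.cuh. A minimal sketch of its likely contents, assembled from the CUDA check macros removed from attention_cuda.cuh and the FULL_MASK / DIV_UP defines removed from attention_cuda_utils.cuh (assumed location and contents, not shown in this commit):

// torch_harmonics/utils/csrc/cudamacro.h -- assumed; reconstructed from the deleted macros
#pragma once

#include <torch/torch.h>

// consolidated CUDA-side tensor checks, matching the macros removed from attention_cuda.cuh
#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA)
#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CUDA_INPUT_TENSOR(x) \
    CHECK_CUDA_TENSOR(x);          \
    CHECK_CONTIGUOUS_TENSOR(x)

// common helpers, matching the defines removed from attention_cuda_utils.cuh
#define FULL_MASK (0xFFFFFFFF)
#define DIV_UP(a,b) (((a)+((b)-1))/(b))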

torch_harmonics/disco/csrc/disco.h

Lines changed: 0 additions & 7 deletions
@@ -35,10 +35,3 @@
 #include <torch/all.h>
 #include <torch/library.h>
 #include <cassert>
-
-#define CHECK_CPU_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCPU)
-#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT_TENSOR(x) CHECK_CONTIGUOUS_TENSOR(x)
-#define CHECK_CPU_INPUT_TENSOR(x) \
-    CHECK_CPU_TENSOR(x); \
-    CHECK_CONTIGUOUS_TENSOR(x)
