41 | 41 | #include <cub/cub.cuh>
42 | 42 | #include <limits>
43 | 43 |
44 | | -#include "cudamacro.h"
| 44 | +//#include "cudamacro.h"
45 | 45 | #include "attention_cuda_utils.cuh"
46 | 46 |
47 | 47 | #include <iostream>
52 | 52 |
53 | 53 | #define MAX_LOCAL_ARR_LEN (16)
54 | 54 |
55 | | -namespace attention_kernels { |
56 | | - |
57 | | -#if 0 |
58 | | -class ScopeTimer |
59 | | -{ |
60 | | - public: |
61 | | - explicit ScopeTimer(const std::string &label = "") : |
62 | | - label_(label), start_(std::chrono::high_resolution_clock::now()) |
63 | | - { |
64 | | - } |
65 | | - |
66 | | - ~ScopeTimer() |
67 | | - { |
68 | | - auto end = std::chrono::high_resolution_clock::now(); |
69 | | - auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_); |
70 | | - std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl; |
71 | | - } |
72 | | - |
73 | | - private: |
74 | | - std::string label_; |
75 | | - std::chrono::high_resolution_clock::time_point start_; |
76 | | -}; |
77 | | - |
78 | | -// easier-to-understand version of a manual __shfl_xor_sync reduction; performance appears similar
79 | | -static __device__ float __warp_sum_cub(float val) |
80 | | -{ |
81 | | - // use cub to reduce within a warp |
82 | | - __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage; |
83 | | - |
84 | | - // 1. Compute sum (initially only in lane 0) |
85 | | - float sum = cub::WarpReduce<float>(temp_storage).Sum(val); |
86 | | - // 2. Broadcast sum to all threads |
87 | | - sum = __shfl_sync(0xFFFFFFFF, sum, 0); |
88 | | - return sum; |
89 | | -} |
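// (A minimal sketch of the manual __shfl_xor_sync butterfly reduction that the
// cub-based helper above replaces; it assumes a full 32-lane warp with all lanes
// active, and the helper name __warp_sum_shfl is illustrative only.)
static __device__ float __warp_sum_shfl(float val)
{
    // XOR butterfly: after log2(WARP_SIZE) steps every lane holds the full sum,
    // so no separate broadcast from lane 0 is needed.
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}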
90 | | - |
91 | | -// This kernel computes the backward pass for the S2 attention mechanism, using |
92 | | -// shared memory as a cache and one warp per output point, warp-parallel over |
93 | | -// channels, which should be laid out in the fastest dimension for coalesced
94 | | -// memory access. |
95 | | -template <int BDIM_X> |
96 | | -__global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel( |
97 | | - int num_channels, int nlon_in, int nlat_out, int nlon_out, |
98 | | - const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx, |
99 | | - const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx, |
100 | | - const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy, |
101 | | - const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy, |
102 | | - torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk, |
103 | | - torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv, |
104 | | - torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq, |
105 | | - const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx, |
106 | | - const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset, |
107 | | - const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) |
108 | | -{ |
109 | 55 |
110 | | - extern __shared__ float sh[]; |
111 | | - float *sh_alpha_k = sh + threadIdx.y * num_channels * 5; |
112 | | - float *sh_alpha_vw = sh_alpha_k + num_channels; |
113 | | - float *sh_alpha_kvw = sh_alpha_vw + num_channels; |
114 | | - float *sh_dy = sh_alpha_kvw + num_channels; |
115 | | - float *sh_qy = sh_dy + num_channels; |
116 | | - // (optionally, could use more shared memory for other intermediates) |
117 | | - |
118 | | - const uint64_t batchId = blockIdx.y; |
119 | | - const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y; |
120 | | - if (wid >= uint64_t(nlat_out) * nlon_in) return; |
121 | | - const int tidx = threadIdx.x; |
122 | | - const int ho = wid / nlon_out; |
123 | | - const int wo = wid - (ho * nlon_out); |
124 | | - |
125 | | - // Zero shared memory |
126 | | - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { |
127 | | - sh_alpha_k[chan] = 0.0f; |
128 | | - sh_alpha_vw[chan] = 0.0f; |
129 | | - sh_alpha_kvw[chan] = 0.0f; |
130 | | - sh_dy[chan] = dy[batchId][chan][ho][wo]; |
131 | | - sh_qy[chan] = qy[batchId][chan][ho][wo]; |
132 | | - } |
133 | | - float alpha_sum = 0.0f; |
134 | | - float qdotk_max = -FLT_MAX; |
135 | | - float integral = 0.0f; |
136 | | - __syncthreads(); |
137 | | - |
138 | | - const int64_t rbeg = psi_row_offset[ho]; |
139 | | - const int64_t rend = psi_row_offset[ho + 1]; |
140 | | - const int rlen = rend - rbeg; |
141 | | - |
142 | | - // 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max. |
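// (Rescaling detail: each time a larger maximum qdotk_max_tmp = max(qdotk_max, qdotk)
// is observed, the sums accumulated so far are multiplied by
// max_correction = expf(qdotk_max - qdotk_max_tmp), so that every term stays
// referenced to the running maximum. This keeps all expf() arguments <= 0; it is the
// usual online-softmax trick that avoids a separate pass to find the global maximum.)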
143 | | - for (int off = 0; off < rlen; off++) { |
144 | | - const int64_t col = psi_col_idx[rbeg + off]; |
145 | | - const int hi = col / nlon_in; |
146 | | - const int wi = col - (hi * nlon_in); |
147 | | - const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; |
148 | | - float qdotk = 0.0f, gdotv = 0.0f; |
149 | | - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { |
150 | | - qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip]; |
151 | | - gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip]; |
152 | | - } |
153 | | - qdotk = __warp_sum_cub(qdotk); |
154 | | - gdotv = __warp_sum_cub(gdotv); |
155 | | - float qdotk_max_tmp = max(qdotk_max, qdotk); |
156 | | - float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi]; |
157 | | - float max_correction = expf(qdotk_max - qdotk_max_tmp); |
158 | | - alpha_sum = alpha_sum * max_correction + alpha_inz; |
159 | | - integral = integral * max_correction + alpha_inz * gdotv; |
160 | | - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { |
161 | | - float kxval = kx[batchId][chan][hi][wip]; |
162 | | - sh_alpha_k[chan] = sh_alpha_k[chan] * max_correction + alpha_inz * kxval; |
163 | | - sh_alpha_vw[chan] = sh_alpha_vw[chan] * max_correction + alpha_inz * gdotv; |
164 | | - sh_alpha_kvw[chan] = sh_alpha_kvw[chan] * max_correction + alpha_inz * kxval * gdotv; |
165 | | - } |
166 | | - qdotk_max = qdotk_max_tmp; |
167 | | - } |
168 | | - |
169 | | - integral /= alpha_sum; |
170 | | - |
171 | | - // Write dydq |
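// (Derivation note: with y_c = (sum_i alpha_i * v_{i,c}) / alpha_sum and
// d(alpha_i)/dq_c = alpha_i * k_{i,c}, applying the quotient rule and contracting
// the Jacobian with the incoming gradient g = dy gives
//   dydq_c = (sum_i alpha_i k_{i,c} (g.v_i)) / alpha_sum
//          - (sum_i alpha_i (g.v_i)) * (sum_i alpha_i k_{i,c}) / alpha_sum^2,
// which is the (sh_alpha_kvw*alpha_sum - sh_alpha_vw*sh_alpha_k) / alpha_sum^2
// expression below.)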
172 | | - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { |
173 | | - dydq[batchId][chan][ho][wo] |
174 | | - = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum); |
175 | | - } |
176 | | - |
177 | | - // Third pass: accumulate gradients for k and v |
178 | | - for (int off = 0; off < rlen; off++) { |
179 | | - const int64_t col = psi_col_idx[rbeg + off]; |
180 | | - const int hi = col / nlon_in; |
181 | | - const int wi = col - (hi * nlon_in); |
182 | | - const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; |
183 | | - float qdotk = 0.0f, gdotv = 0.0f; |
184 | | - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { |
185 | | - qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip]; |
186 | | - gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip]; |
187 | | - } |
188 | | - qdotk = __warp_sum_cub(qdotk); |
189 | | - gdotv = __warp_sum_cub(gdotv); |
190 | | - float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi]; |
191 | | - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { |
192 | | - float qyval = qy[batchId][chan][ho][wo]; |
193 | | - float dyval = sh_dy[chan]; |
194 | | - atomicAdd(&dydk[batchId][chan][hi][wip], qyval * (alpha_inz / alpha_sum) * (gdotv - integral)); |
195 | | - atomicAdd(&dydv[batchId][chan][hi][wip], (alpha_inz / alpha_sum) * dyval); |
196 | | - } |
197 | | - } |
198 | | -} |
199 | | -#endif |
200 | | - |
201 | | -// BEGIN backward kernels and functions |
| 56 | +namespace attention_kernels { |
202 | 57 |
203 | 58 | // called with (blockDim.x=32 and blockDim.y>1, BDIM=blockDim.x*blockDim.y) |
204 | 59 | template<int BDIM_X, |