diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..b39c458 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "accelerated_scan/tk"] + path = accelerated_scan/tk + url = https://github.com/HazyResearch/ThunderKittens diff --git a/README.md b/README.md index 2d42f1d..4d6706a 100644 --- a/README.md +++ b/README.md @@ -65,3 +65,12 @@ forward speed of (8,1536,seqlen), inference mode: When gates and tokens are sampled uniformly from 0..1 the lack of bfloat16 precision dominates the error (compared to the reference implementation): ![max-abs-error.png](max-abs-error.png) + + +## Attention + +``` +(cd accelerated_scan; python3 kitten_setup.py build) +ncu -k causal_attend_kernel python3 ./tests/single.py +python3 ./tests/bench.py --direction forward +``` diff --git a/accelerated_scan/algebra.cu b/accelerated_scan/algebra.cu new file mode 100644 index 0000000..bc6c066 --- /dev/null +++ b/accelerated_scan/algebra.cu @@ -0,0 +1,386 @@ + +#include "tk/src/kittens.cuh" +#include "tk/src/common/pyutils/torch_helpers.cuh" + +using namespace kittens; + +template +__device__ void tileprint(rt reg, char *name) { + auto laneid = kittens::laneid(); + static_assert(reg.height == 1 && reg.width == 1, "height and width must be 1"); + for(int i = 0; i < reg.height; i++) { + for(int j = 0; j < reg.width; j++) { + static_assert(reg.packed_per_thread == 4, "packed_per_thread must be 4"); + + int row_top = laneid / 4; + int row_bottom = row_top + 8; + int col_left = laneid % 4 * 2; // stride 4 + int col_right = col_left + 8; + + auto item_top_left = __bfloat1622float2(reg.tiles[i][j].data[0]); + auto item_bottom_left = __bfloat1622float2(reg.tiles[i][j].data[1]); + auto item_top_right = __bfloat1622float2(reg.tiles[i][j].data[2]); + auto item_bottom_right = __bfloat1622float2(reg.tiles[i][j].data[3]); + printf("lane=%02d " + "%s[%02d,%02d] 0x=% .3f " + "%s[,%02d] 0y=% .3f " + "%s[%02d,%02d] 1x=% .3f " + "%s[,%02d] 1y=% .3f " + "%s[%02d,%02d] 2x=% .3f " + "%s[,%02d] 2y=% .3f " + "%s[%02d,%02d] 3x=% .3f " + "%s[,%02d] 3y=% .3f\n", + laneid, + name, row_top, col_left, item_top_left.x, + name, col_left+1, item_top_left.y, + name, row_bottom, col_left, item_bottom_left.x, + name, col_left+1, item_bottom_left.y, + name, row_top, col_right, item_top_right.x, + name, col_right+1, item_top_right.y, + name, row_bottom, col_right, item_bottom_right.x, + name, col_right+1, item_bottom_right.y); + } + } +} + + +template +__device__ static inline void op_singlerow(T &dst, const T &lhs, const T &rhs, const int row_index) { + const int row_top = laneid() / 4; + const int row_bottom = row_top + 8; + + static_assert(dst.packed_per_tile == 4, "packed_per_tile must be 4"); + using dtype = T::dtype; + + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + if (row_top == row_index) { + dst.tiles[i][j].data[0] = op::template op(lhs.tiles[i][j].data[0], rhs.tiles[i][j].data[0]); + dst.tiles[i][j].data[2] = op::template op(lhs.tiles[i][j].data[2], rhs.tiles[i][j].data[2]); + } else if (row_bottom == row_index) { + dst.tiles[i][j].data[1] = op::template op(lhs.tiles[i][j].data[1], rhs.tiles[i][j].data[1]); + dst.tiles[i][j].data[3] = op::template op(lhs.tiles[i][j].data[3], rhs.tiles[i][j].data[3]); + } + } + } +} + +template +__device__ static inline void zeroexcept( + rt &dst, + const rt &src, + const int row_index +) { + const int row_top = laneid() / 4; + const int row_bottom = row_top + 8; + + static_assert(dst.packed_per_tile == 4, 
"packed_per_tile must be 4"); + //using dtype = T::dtype; + + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + if (row_top != row_index) { + dst.tiles[i][j].data[0] = {0.0, 0.0}; + dst.tiles[i][j].data[2] = {0.0, 0.0}; + } else { + dst.tiles[i][j].data[0] = src.tiles[i][j].data[0]; + dst.tiles[i][j].data[2] = src.tiles[i][j].data[2]; + } + if (row_bottom != row_index) { + dst.tiles[i][j].data[1] = {0.0, 0.0}; + dst.tiles[i][j].data[3] = {0.0, 0.0}; + } else { + dst.tiles[i][j].data[1] = src.tiles[i][j].data[1]; + dst.tiles[i][j].data[3] = src.tiles[i][j].data[3]; + } + } + } +} + + +template +__device__ static inline void reset_trailing_rows(RT &dst, const int row_index, const typename base_types::packing::unpacked_type &val=0) { + const int row_top = laneid() / 4; + const int row_bottom = row_top + 8; + + static_assert(dst.packed_per_tile == 4, "packed_per_tile must be 4"); + + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + if (row_top >= row_index) { + dst.tiles[i][j].data[0].x = val; + dst.tiles[i][j].data[0].y = val; + dst.tiles[i][j].data[2].x = val; + dst.tiles[i][j].data[2].y = val; + } + if (row_bottom >= row_index) { + dst.tiles[i][j].data[1].x = val; + dst.tiles[i][j].data[1].y = val; + dst.tiles[i][j].data[3].x = val; + dst.tiles[i][j].data[3].y = val; + } + } + } +} + + + +/** + * @brief Set a constant to elements of the diagonal in a square register tile. + * + * @tparam T The data type of the register tile elements. + * @tparam _size The size (height and width) of the square register tile. + * @tparam layout The current layout of the register tile. + * @param tile[in,out] Reference to the register tile. + */ +template +__device__ static inline void set_diagonal(RT &dst, const RT &src, const typename base_types::packing::unpacked_type &val=1) { + const typename RT::dtype packed_val = base_types::packing::pack(val); + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + if(j > i || j < i) { // below or above the diagonal, ignore + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k++) { + dst.tiles[i][j].data[k] = src.tiles[i][j].data[k]; + } + } else { // on the diagonal, interesting! 
+ dst.tiles[i][j].data[1] = src.tiles[i][j].data[1]; // below diagonal, copy + dst.tiles[i][j].data[2] = src.tiles[i][j].data[2]; // above diagonal, copy + + if (laneid() == 0 || laneid() == 9 || laneid() == 18 || laneid() == 27) { + // diagonal: every odd row + dst.tiles[i][j].data[0].x = val; + dst.tiles[i][j].data[3].x = val; + } else { + dst.tiles[i][j].data[0].x = src.tiles[i][j].data[0].x; + dst.tiles[i][j].data[3].x = src.tiles[i][j].data[3].x; + } + + if (laneid() == 4 || laneid() == 13 || laneid() == 22 || laneid() == 31) { + // diagonal: every even row + dst.tiles[i][j].data[0].y = val; + dst.tiles[i][j].data[3].y = val; + } else { + dst.tiles[i][j].data[0].y = src.tiles[i][j].data[0].y; + dst.tiles[i][j].data[3].y = src.tiles[i][j].data[3].y; + } + } + } + } +} + + +__device__ void vecprint(rv reg, char *name) { + auto warpid = kittens::warpid(); + auto item0 = __bfloat1622float2(reg.data[0][0]); + printf("warpid=%d tid=%d %s[0] = {%f,%f}\n", warpid, threadIdx.x, name, item0.x, item0.y); + auto item1 = __bfloat1622float2(reg.data[0][1]); + printf("warpid=%d tid=%d %s[1] = {%f,%f}\n", warpid, threadIdx.x, name, item1.x, item1.y); +} + +__device__ void vecprint(rv reg, char *name) { + auto warpid = kittens::warpid(); + auto item0 = __bfloat1622float2(reg.data[0][0]); + printf("warpid=%d tid=%d %s[0] = {%f,%f}\n", warpid, threadIdx.x, name, item0.x, item0.y); +} + +__device__ void vecprint(rv reg, char *name) { + auto warpid = kittens::warpid(); + + #pragma unroll + for(int i = 0; i < reg.outer_dim; i++) { + #pragma unroll + for(int j = 0; j < reg.inner_dim; j++) { + auto item = __bfloat1622float2(reg.data[i][j]); + printf("warpid=%d tid=%d %s[%d][%d] = {%f,%f}\n", warpid, threadIdx.x, name, i, j, item.x, item.y); + } + } +} + +/** + * @brief Bind a list of vectors (k,u) and write them to a dictionary state_delta. + */ +template +__device__ static inline void associate( + rt &state_delta, + /*const*/ rt &v, + /*const*/ rt &k +) { + rt mma_state; + associate(state_delta, v, k, mma_state, false); +} + +/** + * @brief Bind a list of vectors (k,u) and write them to a dictionary state_delta. + */ +template +__device__ static inline void associate( + rt &state_delta, + /*const*/ rt &v, + /*const*/ rt &k, + rt &mma_state, + const bool accum = false +) { + if (accum == false) { + zero(mma_state); + } + auto &k_col = swap_layout_inplace(k); + auto &v_col = swap_layout_inplace(v); + mma_AtB(mma_state, v_col, k_col, mma_state); + swap_layout_inplace(k_col); + swap_layout_inplace(v_col); + copy(state_delta, mma_state); +} + +/** + * @brief Bind a list of vectors (k,u) and write them to a dictionary state_delta. 
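+ *
+ * This overload expects v already swapped to column layout (v_col); when accum is true the
+ * incoming mma_state is kept and accumulated into instead of being zeroed first.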
+ */ +template +__device__ static inline void associate( + rt &state_delta, + /*const*/ rt &v_col, + /*const*/ rt &k, + rt &mma_state, + const bool accum = false +) { + if (accum == false) { + zero(mma_state); + } + auto &k_col = swap_layout_inplace(k); + mma_AtB(mma_state, v_col, k_col, mma_state); + swap_layout_inplace(k_col); + copy(state_delta, mma_state); +} + + +template +__device__ static inline void query( + rt &output_values, + const rt &state, + const rt &query +) { + rt mma; + + zero(mma); + mma_ABt(mma, query, state, mma); // einsum('tk,vk->tv', query, state) + copy(output_values, mma); +} + +template +__device__ static inline void reverse_query( + rt &output_keys, + const rt &value_query, + /*const*/ rt &state +) { + auto &state_col = swap_layout_inplace(state); + reverse_query(output_keys, value_query, state_col); + swap_layout_inplace(state_col); +} + +template +__device__ static inline void reverse_query( + rt &output_keys, + const rt &value_query, + /*const*/ rt &state_col +) { + rt mma; + + zero(mma); + mma_AB(mma, value_query, state_col, mma); // einsum('tv,vk->tk', value_query, state) + copy(output_keys, mma); +} + + + +template +__device__ static inline void kernel( + rt &attention, + const rt &query, + const rt &key +) { + rt qk; + + zero(qk); + mma_ABt(qk, query, key, qk); // mma_TT = einsum('nsk,ntk->nst', q, k) + copy(attention, qk); +} + +template +__device__ static inline void attend( + rt &mixtures, + const rt &attention, + /*const*/ rt &values +) { + auto &values_col = swap_layout_inplace(values); + attend(mixtures, attention, values_col); + swap_layout_inplace(values_col); +} + +template +__device__ static inline void attend( + rt &mixtures, + const rt &attention, + const rt &values_col +) { + rt mma; + + zero(mma); + mma_AB(mma, attention, values_col, mma); + copy(mixtures, mma); +} + +template +__device__ static inline void reverse_attend( + rt &mixtures, + /*const*/ rt &attention, + /*const*/ rt &source_values +) { + auto &source_values_col = swap_layout_inplace(source_values); + reverse_attend(mixtures, attention, source_values_col); + swap_layout_inplace(source_values_col); +} + +template +__device__ static inline void reverse_attend( + rt &mixtures, + /*const*/ rt &attention, + /*const*/ rt &source_values_col +) { + rt mma; + reverse_attend(attention, source_values_col, mma); + copy(mixtures, mma); +} + +template +__device__ static inline void reverse_attend( + /*const*/ rt &attention, + /*const*/ rt &source_values_col, + rt &mma, + const bool accum = false +) { + if (!accum) { + zero(mma); + } + auto &attention_col = swap_layout_inplace(attention); + mma_AtB(mma, attention_col, source_values_col, mma); + swap_layout_inplace(attention_col); +} + + +struct op_negate { + template static __device__ inline T op(const T &x) { return -x; } +}; + +template +__device__ static inline void negate(T &dst) { + unary_map(dst, dst); +} + diff --git a/accelerated_scan/delta.cu b/accelerated_scan/delta.cu new file mode 100644 index 0000000..1c265cd --- /dev/null +++ b/accelerated_scan/delta.cu @@ -0,0 +1,923 @@ +#include +#include +#include +#include + +#include "algebra.cu" + +using namespace kittens; // this kernel only handles headdim=q_reg.cols for simplicity. Also n should be a multiple of 256 here. 
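+
+// Rough outline of the chunked DeltaNet forward pass below: each warp owns one 16-step chunk at a
+// time; decay_values_forward builds the WY-style decayed keys w and values u within the chunk
+// (see tests/deltanet.py and the references cited there), the chunk then attends to its own decayed
+// values, and chunk_forward stitches chunks in order by passing the running state matrix from warp
+// to warp through shared memory, sequenced with cuda::barrier batons.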
+using barrier = cuda::barrier; + +template +static __device__ inline void decay_values_forward( + rt &tt_reg, + rt &bKl_reg, + rt &k_reg, + typename rt::col_vec &beta_reg, + rt &w_reg, + rt &u_reg, + rt &v_reg, + rt &u_bases_reg, // can be aliased to v_reg + rt &bk_reg +) { + rt tk_reg; + rt tv_reg; + + rt tt; + rv decay; + rv decay_1; + rv decay_w; + rv decay_u; + + mul_row(bk_reg, k_reg, beta_reg); // k = einsum('ntk,nt->ntk', k, beta) + + kernel(tt_reg, k_reg, k_reg); + make_causal(tt_reg, tt_reg, 0); + set_diagonal(tt_reg, tt_reg, 0); // tt = tt.tril(diagnal=-1) + + mul_row(bKl_reg, tt_reg, beta_reg); // tt = einsum('nts,nt->nts', tt, beta) + + copy(u_bases_reg, v_reg); + mul_row(v_reg, v_reg, beta_reg); // v = einsum('ntw,nt->ntw', v, beta) + + copy(w_reg, bk_reg); + copy(u_reg, v_reg); + + #pragma unroll + for (auto t = 1; t < w_reg.rows; t++) { + // select row t from bKl + zeroexcept(tt, bKl_reg, t); + col_sum(decay, tt); + copy(decay_1, decay); + + // ALT: reverse_query(tk_reg, bKl_reg, w_reg); + broadcast_row(tk_reg, decay_1); + mul(tk_reg, w_reg, tk_reg); + col_sum(decay_w, tk_reg); // decay_w is 1xK vector + broadcast_col(tk_reg, decay_w); + + op_singlerow(w_reg, bk_reg, tk_reg, t); // w[t] = bk[t] - tk[t] + + // ALT: reverse_query(tv_reg, bKl_reg, u_reg); + broadcast_row(tv_reg, decay_1); + mul(tv_reg, u_reg, tv_reg); + col_sum(decay_u, tv_reg); // decay_u is 1xV vector + broadcast_col(tv_reg, decay_u); + + op_singlerow(u_reg, v_reg, tv_reg, t); // u[t] = bv[t] - tv[t] + } +} + + +template +__global__ void delta_forward_kernel( + const int num_chunks, + const H* __restrict__ __q__, + const H* __restrict__ __k__, + const H* __restrict__ __v__, + const H* __restrict__ __beta__, + H* __restrict__ __y__ +) { + auto warpid = kittens::warpid(); + auto vg = blockIdx.x % _value_groups; + auto block_start = blockIdx.z*(num_chunks*kChunkSize*(_key*TILE_DIM)); + auto beta_block_start = blockIdx.z*(num_chunks*kChunkSize*1); // width is 1 for beta + const T *_q = reinterpret_cast(__q__) + block_start, + *_k = reinterpret_cast(__k__) + block_start, + *_v = reinterpret_cast(__v__) + block_start, + *_beta = reinterpret_cast(__beta__) + beta_block_start; + T *_y = reinterpret_cast(__y__) + block_start; + + extern __shared__ alignment_dummy __shm[]; // this is the CUDA shared memory + shared_allocator al((int*)&__shm[0]); + + st (&shared_states)[1] = al.allocate, 1>(); + st &shared_state = shared_states[0]; + + using col = rt; + /* + * register allocations + */ + rt mma_state; + rt k, w, bk; + rt v, u; + typename rt::col_vec beta_reg; + rt qk; + constexpr int v_num_elements = v.num_elements * _value_groups; + constexpr int v_row_stride = v.cols * _value_groups; + + __shared__ barrier batons[kNumWarps]; + auto block = cooperative_groups::this_thread_block(); + + init(&batons[warpid], 2); + if (block.thread_rank() == 0) { + auto token = batons[0].arrive(); + } + + for (int time_block = 0; time_block < (num_chunks + kNumWarps - 1) / kNumWarps; time_block++) { + int chunk = time_block * kNumWarps + warpid; + if (chunk >= num_chunks) { + break; + } + + /* + * load k, v, beta + */ + + load(k, _k + chunk*k.num_elements, k.cols); + load(beta_reg, _beta + chunk*beta_reg.outer_dim*TILE_DIM); + load(v, _v + chunk*v_num_elements + v.cols * vg, v_row_stride); + + /* + * decay_values_forward: compute w and u + */ + + decay_values_forward(qk, qk, k, beta_reg, w, u, v, v, bk); + + /* + * attend to decayed values + */ + + auto &q = bk; + load(q, _q + chunk*q.num_elements, q.cols); + + kernel(qk, q, k); + 
make_causal(qk, qk, 0); + + auto &y = v; + attend(y, qk, u); + + chunk_forward(shared_state, mma_state, q, k, w, u, qk, y, chunk, batons[warpid], batons[(warpid + 1) % kNumWarps]); + + store(_y + chunk*v_num_elements + y.cols * vg, y, v_row_stride); + } +} + +template +__device__ static inline void chunk_forward_nooutput( + st &shared_state, + rt &mma_state, + rt &k, + rt &w, + rt &u, // become decayed + const int chunk, + barrier &self, + barrier &next +) { + return chunk_forward_impl( + shared_state, mma_state, k, w, u, nullptr, nullptr, nullptr, chunk, self, next + ); +} + +template +__device__ static inline void chunk_forward( + st &shared_state, + rt &mma_state, + rt &q, + rt &k, + rt &w, + rt &u, // become decayed + rt &qk, + rt &y, + const int chunk, + barrier &self, + barrier &next +) { + return chunk_forward_impl( + shared_state, mma_state, k, w, u, q, qk, y, chunk, self, next + ); +} + +template +__device__ static inline void chunk_forward_impl( + st &shared_state, + rt &mma_state, + rt &k, + rt &w, // not necessary for chunk 0 + rt &u, // become decayed + std::conditional_t &, std::nullptr_t> q, + std::conditional_t &, std::nullptr_t> qk, + std::conditional_t &, std::nullptr_t> y, + const int chunk, + barrier &self, + barrier &next +) { + auto laneid = kittens::laneid(); + + rt state; + rt u_old; + rt y_buf; + + // all warps execute sequentially, passing state through the shared memory + if (laneid == 0) { + // wait until our state is ready + auto token = self.arrive(); + self.wait(std::move(token)); + } + __syncwarp(); + + if (chunk > 0) { + load(state, shared_state); + copy(mma_state, state); + + query(u_old, state, w); + sub(u, u, u_old); + + if constexpr (NeedsY) { + query(y_buf, state, q); + } + } else { + zero(mma_state); + } + + if constexpr (NeedsY) { + if (chunk > 0) { + add(y, y, y_buf); + + attend(y_buf, qk, u_old); + sub(y, y, y_buf); + } + } + + associate(state, u, k, mma_state, true); + store(shared_state, state); + + if (laneid == 0) { + // data ƒor the next chunk has arrived + auto token = next.arrive(); + } + __syncwarp(); +} + +/** + * @brief Stitch the chunks backwards + */ +template +__device__ static inline void chunk_backward( + rt &q, // not needed for chunk 0 + rt &k, + rt &w, // not needed for chunk 0 + rt &u, + rt &d_y, + rt &d_q, + rt &d_k, + rt &d_w, + rt &d_u, + const int chunk, + const int num_chunks, + st &shared_state, + st &shared_d_state, + barrier &self, + barrier &prev +) { + auto laneid = kittens::laneid(); + + rt qk; + rt d_state_decays, d_state_decays_buf; + rt tk; + rt tv; + rt state, d_state, state_delta; + + // all warps execute sequentially, passing state through the shared memory + if (laneid == 0) { + // wait until our state is ready + auto token = self.arrive(); + self.wait(std::move(token)); + } + __syncwarp(); + + + load(state, shared_state); + load(d_state, shared_d_state); + + if (chunk == 0) { + zero(d_q); + reverse_query(d_k, u, d_state); + zero(d_w); + query(d_u, d_state, k); + } else { + associate(state_delta, u, k); + /* + * uncompute the state backwards + */ + sub(state, state, state_delta); + query(tv, state, w); + + negate(d_y); + + // d_q, d_k + kernel(qk, d_y, tv); + make_causal(qk, qk, 0); + + // d_q + attend(d_q, qk, k); + reverse_query(tk, d_y, state); + sub(d_q, d_q, tk); + + // d_k + reverse_attend(d_k, qk, q); + if (chunk < num_chunks - 1) { + reverse_query(tk, u, d_state); // otherwise we know d_state is zero + add(d_k, d_k, tk); + } + + // d_u + if (chunk < num_chunks - 1) { + query(d_u, d_state, k); + } else { 
+ // otherwise we know d_state is zero + zero(d_u); + } + + // d_state_decays + kernel(qk, q, k); + make_causal(qk, qk, 0); + + reverse_attend(d_state_decays, qk, d_y); + if (chunk < num_chunks - 1) { + query(d_state_decays_buf, d_state, k); + sub(d_state_decays, d_state_decays, d_state_decays_buf); + } + + // d_w + reverse_query(d_w, d_state_decays, state); + + // backpropagate through time + auto &state_buf = state_delta; // alias + associate(state_buf, d_y, q); + sub(d_state, d_state, state_buf); + associate(state_buf, d_state_decays, w); + add(d_state, d_state, state_buf); + + negate(d_y); // undo + } + + store(shared_state, state); + store(shared_d_state, d_state); + + if (laneid == 0) { + // gradients ƒor the previous chunk have arrived + auto token = prev.arrive(); + } + __syncwarp(); +} + + +template +struct DeltaBackwardArgs { + unsigned long long num_chunks; + const H* __restrict__ __d_out_y__; + const H* __restrict__ __q__; + const H* __restrict__ __k__; + const H* __restrict__ __v__; + const H* __restrict__ __beta__; + H* __restrict__ __d_q__; + H* __restrict__ __d_k__; + H* __restrict__ __d_v__; + H* __restrict__ __d_beta__; + H* __restrict__ __u__; + int* __locks__; +}; + +template +__global__ void delta_backward_kernel(DeltaBackwardArgs args) { + const int num_chunks = args.num_chunks; + auto warpid = kittens::warpid(); + auto laneid = kittens::laneid(); + auto vg = blockIdx.x % _value_groups; + auto block_start = blockIdx.z*(num_chunks*kChunkSize*(_key*TILE_DIM)); + auto beta_block_start = blockIdx.z*(num_chunks*kChunkSize*1); // width is 1 for beta + const T *_d_out_y = reinterpret_cast(args.__d_out_y__) + block_start, + *_q = reinterpret_cast(args.__q__) + block_start, + *_k = reinterpret_cast(args.__k__) + block_start, + *_v = reinterpret_cast(args.__v__) + block_start, + *_beta = reinterpret_cast(args.__beta__) + beta_block_start; + T *_d_q = reinterpret_cast(args.__d_q__) + block_start, + *_d_k = reinterpret_cast(args.__d_k__) + block_start, + *_d_v = reinterpret_cast(args.__d_v__) + block_start, + *_d_beta = reinterpret_cast(args.__d_beta__) + beta_block_start, + *_u = reinterpret_cast(args.__u__) + block_start; + extern __shared__ alignment_dummy __shm[]; // this is the CUDA shared memory + shared_allocator al((int*)&__shm[0]); + + st (&shared_states)[2] = al.allocate, 2>(); + st &shared_state = shared_states[0]; + st &shared_d_state = shared_states[1]; + + /* + * register allocations + */ + rt q, d_q; + rt k_reg, d_w, d_k; + rt v_reg; + rt d_y, d_u; + typename rt::col_vec beta_reg; + rt w_reg, bk_reg; + rt u_reg, u_bases_reg, u_predecay_reg; + rt tt_reg, bKl_reg; + + rt mma_state; + rt mma_TD; + rt mma_TT; + rt mma_TV; + constexpr int v_num_elements = v_reg.num_elements * _value_groups; + constexpr int v_row_stride = v_reg.cols * _value_groups; + + __shared__ barrier batons[kNumWarps]; + __shared__ barrier backward_batons[kNumWarps]; + auto block = cooperative_groups::this_thread_block(); + + init(&batons[warpid], 2); + init(&backward_batons[warpid], 2); + if (block.thread_rank() == 0) { + auto token = batons[0].arrive(); + } + + zero(shared_d_state); + const int time_blocks = (num_chunks + kNumWarps - 1) / kNumWarps; + + for (int time_block = 0; time_block < time_blocks; time_block++) { + const int chunk = time_block * kNumWarps + warpid; + if (chunk >= num_chunks) { + break; + } + + /* + * load k, v, beta + */ + + load(k_reg, _k + chunk*k_reg.num_elements, k_reg.cols); + load(beta_reg, _beta + chunk*beta_reg.outer_dim*TILE_DIM); + load(v_reg, _v + 
chunk*v_num_elements + v_reg.cols * vg, v_row_stride); + + /* + * decay_values_forward: compute w and u + */ + + decay_values_forward(tt_reg, bKl_reg, k_reg, beta_reg, w_reg, u_reg, v_reg, u_bases_reg, bk_reg); + + barrier &forward_self = batons[warpid]; + barrier &forward_next = chunk == num_chunks - 1 ? backward_batons[warpid] : batons[(warpid + 1) % kNumWarps]; + + copy(u_predecay_reg, u_reg); // chunk_forward_nooutput will mutate u_reg + chunk_forward_nooutput(shared_state, mma_state, k_reg, w_reg, u_reg, chunk, forward_self, forward_next); + store(_u + chunk*v_num_elements + u_reg.cols * vg, u_reg, v_row_stride); // store decayed u + } + + for (int time_block = time_blocks - 1; time_block >= 0; time_block--) { + const int chunk = time_block * kNumWarps + warpid; + if (chunk >= num_chunks) { + continue; + } + + if (time_block < time_blocks - 1) { + // reload + load(k_reg, _k + chunk*k_reg.num_elements, k_reg.cols); + load(beta_reg, _beta + chunk*beta_reg.outer_dim*TILE_DIM); + load(v_reg, _v + chunk*v_num_elements + v_reg.cols * vg, v_row_stride); + + // recompute w_reg and u_reg + decay_values_forward(tt_reg, bKl_reg, k_reg, beta_reg, w_reg, u_reg, v_reg, u_bases_reg, bk_reg); + copy(u_predecay_reg, u_reg); // load will mutate u_reg + + load(u_reg, _u + chunk*v_num_elements + u_reg.cols * vg, v_row_stride); + } + + load(q, _q + chunk*q.num_elements, q.cols); + load(d_y, _d_out_y + chunk*v_num_elements + d_y.cols * vg, v_row_stride); + + barrier &backward_self = backward_batons[warpid]; + const int previous = ((warpid - 1) % kNumWarps + kNumWarps) % kNumWarps; + barrier &backward_prev = backward_batons[previous]; + + chunk_backward( + q, k_reg, w_reg, u_reg, + d_y, + d_q, d_k, d_w, d_u, + chunk, num_chunks, + shared_state, shared_d_state, + backward_self, backward_prev + ); + + copy(u_reg, u_predecay_reg); // restore non-decayed u + + rt &q_reg = swap_layout_inplace(q); + decay_values_backward( + q_reg, k_reg, w_reg, tt_reg, bk_reg, u_reg, u_bases_reg, bKl_reg, beta_reg, + d_y, + d_q, d_k, d_w, d_u, + mma_TD, mma_TT, mma_TV, + _d_out_y, + _d_q, _d_k, _d_v, _d_beta, + chunk, + args.__locks__ + ); + + swap_layout_inplace(q_reg); + } +} + +template +__device__ static inline void decay_values_backward( + rt &q_reg, + rt &k_reg, + rt &w_reg, + rt &tt_reg, + rt &bk_reg, + rt &u_reg, + rt &u_bases_reg, + rt &bKl_reg, + rv &beta_reg, + rt &d_out_y_reg, + rt &d_q, + rt &d_k, + rt &d_w_reg, + rt &d_u_reg, + rt &mma_TD, + rt &mma_TT, + rt &mma_TV, + const T *_d_out_y, + T *_d_q, + T *_d_k, + T *_d_v, + T *_d_beta, + const int chunk, + int *locks +) { + auto vg = blockIdx.x % _value_groups; + constexpr int v_num_elements = u_reg.num_elements * _value_groups; + constexpr int v_row_stride = u_reg.cols * _value_groups; + + rt w_bases_reg, tk_reg, d_k_reg, d_out_w_reg; + rt v_reg, tv_reg; + rv d_beta_reg; + rv d_beta_buf_reg; + + attend(w_bases_reg, tt_reg, w_reg); + sub(w_bases_reg, k_reg, w_bases_reg); + + attend(v_reg, tt_reg, u_reg); + sub(u_bases_reg, u_bases_reg, v_reg); + + /* + * causal_attend_backward for d_q, d_k_2, d_u + */ + + + // d_q + kernel(tt_reg, d_out_y_reg, u_reg); + make_causal(tt_reg, tt_reg, 0); + attend(tk_reg, tt_reg, k_reg); + add(d_q, d_q, tk_reg); + + + if constexpr (_value_groups > 1) { + int *lock = &locks[0*kNumWarps*gridDim.z + blockIdx.z*kNumWarps + kittens::warpid()]; + + // lock + if (kittens::laneid() == 0) { + while (atomicCAS(lock, 0, 1) == 1) { + // spin + }; + } + __syncwarp(); + + load(tk_reg, _d_q + chunk*tk_reg.num_elements, tk_reg.cols); + add(d_q, d_q, 
tk_reg); + store(_d_q + chunk*d_q.num_elements, d_q, d_q.cols); + + // unlock + if (kittens::laneid() == 0) { + atomicExch(lock, 0); + } + __syncwarp(); + } else { + store(_d_q + chunk*d_q.num_elements, d_q, d_q.cols); + } + + // d_k + + reverse_attend(d_k_reg, tt_reg, q_reg); + add(d_k, d_k, d_k_reg); + + auto &q_reg_row = swap_layout_inplace(q_reg); + kernel(tt_reg, q_reg_row, k_reg); + //q_reg = swap_layout_inplace(q_reg_row); // won't need it later + make_causal(tt_reg, tt_reg, 0); // tt.tril_() + + reverse_attend(tv_reg, tt_reg, d_out_y_reg); // don't need last swap_layout_inplace of d_out_y_reg + add(d_u_reg, d_u_reg, tv_reg); + + /* + * backward for d_k, d_v, d_beta + */ + + zero(d_k_reg); + + for (auto t = _time * TILE_DIM - 1; t >= 0; t--) { + + auto &k_reg_col = swap_layout_inplace(k_reg); + + // d_k + zero(mma_TD); + { + kernel(tt_reg, w_reg, d_w_reg); + reset_trailing_rows(tt_reg, t); + + reverse_attend(tt_reg, k_reg_col, mma_TD, true); + + kernel(tt_reg, u_reg, d_u_reg); + reset_trailing_rows(tt_reg, t); + + reverse_attend(tt_reg, k_reg_col, mma_TD, true); + copy(tk_reg, mma_TD); + + op_singlerow(d_k_reg, d_k_reg, tk_reg, t); + } + + k_reg = swap_layout_inplace(k_reg_col); + + // backpropagate through time, updating only remaining timestamps + { + zero(tt_reg); + op_singlerow(tt_reg, tt_reg, bKl_reg, t); + auto &tt_reg_col = swap_layout_inplace(tt_reg); + + { + zero(mma_TD); + { + auto &d_w_reg_col = swap_layout_inplace(d_w_reg); + mma_AtB(mma_TD, tt_reg_col, d_w_reg_col, mma_TD); + d_w_reg = swap_layout_inplace(d_w_reg_col); + copy(tk_reg, mma_TD); + } + + sub(d_w_reg, d_w_reg, tk_reg); + + zero(mma_TV); + { + auto &d_u_reg_col = swap_layout_inplace(d_u_reg); + mma_AtB(mma_TV, tt_reg_col, d_u_reg_col, mma_TV); + d_u_reg = swap_layout_inplace(d_u_reg_col); + copy(tv_reg, mma_TV); + } + + sub(d_u_reg, d_u_reg, tv_reg); + } + + tt_reg = swap_layout_inplace(tt_reg_col); + } + + } + + sub(d_k_reg, d_w_reg, d_k_reg); // d_k = d_w - d_k + mul_row(d_k_reg, d_k_reg, beta_reg); // d_k = einsum('ntk,nt->ntk', d_k, beta) + + // decay w and u + zero(mma_TT); + mma_ABt(mma_TT, d_w_reg, w_reg, mma_TT); + mma_ABt(mma_TT, d_u_reg, u_reg, mma_TT); + copy(tt_reg, mma_TT); + make_causal(tt_reg, tt_reg, 0); + set_diagonal(tt_reg, tt_reg, 0); + + zero(mma_TD); + { + auto &tt_reg_col = swap_layout_inplace(tt_reg); + auto &bk_reg_col = swap_layout_inplace(bk_reg); // don't need the swap later + mma_AtB(mma_TD, tt_reg_col, bk_reg_col, mma_TD); + copy(tk_reg, mma_TD); + } + sub(d_k_reg, d_k_reg, tk_reg); + add(d_k, d_k, d_k_reg); + + if constexpr (_value_groups > 1) { + int *lock = &locks[1*kNumWarps*gridDim.z + blockIdx.z*kNumWarps + kittens::warpid()]; + + // lock + if (kittens::laneid() == 0) { + while (atomicCAS(lock, 0, 1) == 1) { + // spin + }; + } + __syncwarp(); + + load(tk_reg, _d_k + chunk*d_k.num_elements, d_k.cols); + add(d_k, d_k, tk_reg); + store(_d_k + chunk*d_k.num_elements, d_k, d_k.cols); + + // unlock + if (kittens::laneid() == 0) { + atomicExch(lock, 0); + } + __syncwarp(); + } else { + store(_d_k + chunk*d_k.num_elements, d_k, d_k.cols); + } + + // d_beta + mul(w_bases_reg, w_bases_reg, d_w_reg); // w_bases = einsum('ntk,ntk->ntk', w_bases, d_w) + mul(u_bases_reg, u_bases_reg, d_u_reg); // u_bases = einsum('ntw,ntw->ntw', u_bases, d_u) + + // d_v using available d_u_reg register + mul_row(d_u_reg, d_u_reg, beta_reg); + store(_d_v + chunk*v_num_elements + d_u_reg.cols * vg, d_u_reg, v_row_stride); + + // continue d_beta + auto &w_bases_col = swap_layout_inplace(w_bases_reg); + auto 
&u_bases_col = swap_layout_inplace(u_bases_reg); + zero(d_beta_reg); + row_sum(d_beta_reg, w_bases_col); // d_beta = einsum('tk->t', w_bases); + row_sum(d_beta_reg, u_bases_col, d_beta_reg); // d_beta += einsum('tw->t', u_bases); + + if constexpr (_value_groups > 1) { + int *lock = &locks[2*kNumWarps*gridDim.z + blockIdx.z*kNumWarps + kittens::warpid()]; + + // lock + if (kittens::laneid() == 0) { + while (atomicCAS(lock, 0, 1) == 1) { + // spin + }; + } + __syncwarp(); + + load(d_beta_buf_reg, _d_beta + chunk*beta_reg.outer_dim*TILE_DIM); + add(d_beta_reg, d_beta_reg, d_beta_buf_reg); + store(_d_beta + chunk*beta_reg.outer_dim*TILE_DIM, d_beta_reg); + + // unlock + if (kittens::laneid() == 0) { + atomicExch(lock, 0); + } + __syncwarp(); + } else { + store(_d_beta + chunk*beta_reg.outer_dim*TILE_DIM, d_beta_reg); + } +} + +// see also: DISPATCH +#define TYPE_DISPATCH(scalar_type, FUNC)\ + switch (scalar_type) {\ + case c10::ScalarType::BFloat16: {\ + using H = c10::BFloat16;\ + using T = bf16;\ + using D = bf16_2;\ + using ACCUM = float2;\ + FUNC;\ + }\ + break;\ + default:\ + TORCH_CHECK(false, "Unsupported type! Try bfloat16");\ + } + +#define DISPATCH_ME(d, seqlen) \ + if (d == 16) { \ + TYPE_DISPATCH(scalar_type, DELTA_DISPATCH(1, 1, 1, 8)); \ + } else if (d == 32) { \ + TYPE_DISPATCH(scalar_type, DELTA_DISPATCH(1, 2, 1, 8)); \ + } else if (d == 64) { \ + TYPE_DISPATCH(scalar_type, DELTA_DISPATCH(1, 2, 2, 4)); \ + } else if (d == 128) { \ + TYPE_DISPATCH(scalar_type, DELTA_DISPATCH(1, 2, 4, 2)); \ + } else if (d == 256) { \ + TYPE_DISPATCH(scalar_type, DELTA_DISPATCH(1, 2, 8, 2)); \ + } else if (d == 512) { \ + TYPE_DISPATCH(scalar_type, DELTA_DISPATCH(1, 2, 16, 2)); \ + } else if (d == 1024) { \ + TYPE_DISPATCH(scalar_type, DELTA_DISPATCH(1, 2, 32, 2)); \ + } else { \ + TORCH_CHECK(false, "[qkv].size(2) should be 16, 32, 64, 128, 256, 512 or 1024"); \ + } + +#define DISPATCH_ME_FLAT(d, seqlen) \ + if (d == 16) { \ + TYPE_DISPATCH(scalar_type, DELTA_DISPATCH(1, 1, 1, 8)); \ + } else if (d == 32) { \ + TYPE_DISPATCH(scalar_type, DELTA_DISPATCH(1, 2, 1, 8)); \ + } else if (d == 64) { \ + TYPE_DISPATCH(scalar_type, DELTA_DISPATCH(1, 2, 2, 4)); \ + } else if (d == 128) { \ + TYPE_DISPATCH(scalar_type, DELTA_DISPATCH(1, 1, 8, 4)); \ + } else { \ + TORCH_CHECK(false, "[qkv].size(2) should be 16, 32, 64, 128"); \ + } + +void +forward( + torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor beta, + torch::Tensor y +) { + CHECK_INPUT(q); + CHECK_INPUT(k); + CHECK_INPUT(v); + CHECK_INPUT(beta); + CHECK_INPUT(y); + + auto scalar_type = k.scalar_type(); + TORCH_CHECK(q.scalar_type() == scalar_type, "q type mismatch"); + TORCH_CHECK(v.scalar_type() == scalar_type, "v type mismatch"); + TORCH_CHECK(beta.scalar_type() == scalar_type, "beta type mismatch"); + TORCH_CHECK(y.scalar_type() == scalar_type, "y type mismatch"); + + auto batch = k.size(0); + auto seqlen = k.size(1); + auto d = k.size(2); + bool same = true; + for(auto i = 0; i < 3; i++) { + same &= k.size(i) == v.size(i); + same &= q.size(i) == v.size(i); + } + TORCH_CHECK(same, "Q, K and V should be same size"); + constexpr int kChunkSize = 16; + auto num_chunks = seqlen / kChunkSize; + + // kHeight: tiles per sequence block, 2 means 2*16 = 32 sequence elements per warp + // kWidth: tiles per vector, 2 means head dimension is 2*16 = 32 +#define DELTA_DISPATCH(_kHeight, _kWidth, _kWidthGroups, _kNumWarps) \ + constexpr int kHeight = _kHeight; \ + constexpr int kWidth = _kWidth; \ + constexpr int kWidthGroups = _kWidthGroups; \ + 
constexpr int kNumWarps = _kNumWarps; \ + constexpr int kKey = kWidth * kWidthGroups; \ + auto threads = kNumWarps * kittens::WARP_THREADS; \ + dim3 blocks(kWidthGroups, 1, batch); \ + unsigned long mem_size = sizeof(st) + kNumWarps*sizeof(barrier); \ + delta_forward_kernel<<>>( \ + (int)num_chunks, \ + q.data_ptr(), k.data_ptr(), v.data_ptr(), beta.data_ptr(), \ + y.data_ptr()) + + DISPATCH_ME(d, seqlen); +#undef DELTA_DISPATCH +} + + +void +backward( + torch::Tensor d_out_y, + torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor beta, + torch::Tensor d_q, torch::Tensor d_k, torch::Tensor d_v, torch::Tensor d_beta, + torch::Tensor u +) { + CHECK_INPUT(d_out_y); + CHECK_INPUT(q); + CHECK_INPUT(k); + CHECK_INPUT(v); + CHECK_INPUT(beta); + CHECK_INPUT(d_k); + CHECK_INPUT(d_v); + CHECK_INPUT(d_beta); + CHECK_INPUT(u); + + auto scalar_type = d_out_y.scalar_type(); + TORCH_CHECK(q.scalar_type() == scalar_type, "q type mismatch"); + TORCH_CHECK(k.scalar_type() == scalar_type, "k type mismatch"); + TORCH_CHECK(v.scalar_type() == scalar_type, "v type mismatch"); + TORCH_CHECK(beta.scalar_type() == scalar_type, "beta type mismatch"); + TORCH_CHECK(d_k.scalar_type() == scalar_type, "d_k type mismatch"); + TORCH_CHECK(d_v.scalar_type() == scalar_type, "d_v type mismatch"); + TORCH_CHECK(d_beta.scalar_type() == scalar_type, "d_beta type mismatch"); + TORCH_CHECK(u.scalar_type() == scalar_type, "u type mismatch"); + + auto batch = k.size(0); + auto seqlen = k.size(1); + auto d = k.size(2); + bool same = true; + for(auto i = 0; i < 3; i++) { + same &= k.size(i) == v.size(i); + same &= q.size(i) == v.size(i); + } + TORCH_CHECK(same, "Q, K and V should be same size"); + constexpr int kChunkSize = 16; + unsigned long long num_chunks = seqlen / kChunkSize; + // TORCH_CHECK(num_chunks <= 16, "num_chunks should be <= 32 (chunk size is 16)"); + + // kHeight: tiles per sequence block, 2 means 2*16 = 32 sequence elements per warp + // kWidth: tiles per vector, 2 means head dimension is 2*16 = 32 +#define DELTA_DISPATCH(_kHeight, _kWidth, _kWidthGroups, _kNumWarps) \ + constexpr int kHeight = _kHeight; \ + constexpr int kWidth = _kWidth; \ + constexpr int kWidthGroups = _kWidthGroups; \ + constexpr int kNumWarps = _kNumWarps; \ + constexpr int kKey = kWidth * kWidthGroups; \ + int *locks; \ + cudaMalloc(&locks, 3 * batch * kNumWarps * sizeof(int)); \ + cudaMemset(locks, 0, 3 * batch * kNumWarps * sizeof(int)); \ + auto threads = kNumWarps * kittens::WARP_THREADS; \ + dim3 gridDim(kWidthGroups, 1, batch); \ + dim3 blockDim(threads, 1, 1); \ + size_t mem_size = 2*sizeof(st) + 2*kNumWarps*sizeof(barrier); \ + auto kernel = delta_backward_kernel; \ + CHECK_CUDA_ERROR(cudaFuncSetAttribute(delta_backward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, mem_size)); \ + DeltaBackwardArgs args = {num_chunks, \ + d_out_y.data_ptr(), \ + q.data_ptr(), k.data_ptr(), v.data_ptr(), beta.data_ptr(), \ + d_q.data_ptr(), d_k.data_ptr(), d_v.data_ptr(), d_beta.data_ptr(), \ + u.data_ptr(), locks}; \ + void *args_ptr[] = {&args}; \ + cudaLaunchKernel((const void *)&delta_backward_kernel, gridDim, blockDim, args_ptr, mem_size, at::cuda::getCurrentCUDAStream().stream()); \ + cudaFree(locks) + + DISPATCH_ME_FLAT(d, seqlen); +#undef DELTA_DISPATCH +} diff --git a/accelerated_scan/kitten.cu b/accelerated_scan/kitten.cu new file mode 100644 index 0000000..56ca770 --- /dev/null +++ b/accelerated_scan/kitten.cu @@ -0,0 +1,130 @@ +//#include "../../../src/kittens.cuh" +#include "tk/src/kittens.cuh" +#include 
"tk/src/common/pyutils/torch_helpers.cuh" + +using namespace kittens; // this kernel only handles headdim=q_reg.cols for simplicity. Also n should be a multiple of 256 here. + + +template +__global__ void causal_attend_kernel( + int seqlen, + const H* __restrict__ __q__, + const H* __restrict__ __k__, + const H* __restrict__ __v__, + const float* __restrict__ __f__, + H* __o__ +) { + auto warpid = kittens::warpid(); + auto block_start = blockIdx.x*(seqlen*_width*TILE_DIM); + const bf16 *_q = reinterpret_cast(__q__) + block_start, + *_k = reinterpret_cast(__k__) + block_start, + *_v = reinterpret_cast(__v__) + block_start; + bf16 *_o = reinterpret_cast(__o__) + block_start; + + extern __shared__ alignment_dummy __shm[]; // this is the CUDA shared memory + shared_allocator al((int*)&__shm[0]); + + // K and V live in shared memory -- this is about all that will fit. + st_bf<_height, _width, ducks::st_layout::swizzle> (&k_smem)[kNumWorkers] = al.allocate, kNumWorkers>(); + st_bf<_height, _width, ducks::st_layout::swizzle> (&v_smem)[kNumWorkers] = al.allocate, kNumWorkers>(); + + // Initialize all of the register tiles. + rt_bf<_height, _width> q_reg, k_reg, v_reg; // v_reg need to be swapped into col_l + rt_fl<_height, _height> att_block; + rt_bf<_height, _height> att_block_mma; + rt_fl<_height, _width> o_reg; + + const int qo_blocks = seqlen / (q_reg.rows*kNumWorkers); + + for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) { + // each warp loads its own Q tile of 16x16 + auto q_index = q_blk*kNumWorkers + warpid; + load(q_reg, _q + q_index*q_reg.num_elements, q_reg.cols); + + zero(o_reg); // zero flash attention O register. + + // iterate over k, v for these q's that have been loaded + for(auto kv_blk = q_blk; kv_blk >= 0; kv_blk--) { + int kv_warp_index = kv_blk*kNumWorkers + warpid; + if (kv_warp_index <= q_index) { // ensure causality + // each warp loads its own chunk of k, v into shared memory + load(v_smem[warpid], _v + kv_warp_index*q_reg.num_elements, q_reg.cols); + load(k_smem[warpid], _k + kv_warp_index*q_reg.num_elements, q_reg.cols); + } + __syncthreads(); // we need to make sure all memory is loaded before we can begin the compute phase + + // now each warp goes through all of the subtiles, loads them, and then does the flash attention internal alg. + for(int subtile = kNumWorkers-1; subtile >= 0; subtile--) { + int kv_subtile_index = kv_blk*kNumWorkers + subtile; + if (!(kv_subtile_index <= q_index)) { // ensure causality + continue; + } + load(k_reg, k_smem[subtile]); // load k from shared into registers + + zero(att_block); // zero 16x16 attention tile + mma_ABt(att_block, q_reg, k_reg, att_block); // Q@K.T + + copy(att_block_mma, att_block); // convert to bf16 for mma_AB + + if (kv_subtile_index == q_index) { + make_causal(att_block_mma, att_block_mma, 0); + } + + load(v_reg, v_smem[subtile]); // load v from shared into registers. + rt_bf<_height, _width, ducks::rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg); // this is a reference and the call has invalidated v_reg + + mma_AB(o_reg, att_block_mma, v_reg_col, o_reg); // mfma onto o_reg with the local attention@V matmul. + } + __syncthreads(); // we need to make sure all warps are done before we can start loading the next kv chunk + } + + store(_o + (q_blk*kNumWorkers + warpid)*q_reg.num_elements, o_reg, q_reg.cols); // write out o. 
compiler has an issue with register usage if d is made constexpr q_reg.rows :/ + } +} + +void +attend(torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor f, torch::Tensor o_small) { + CHECK_INPUT(q); + CHECK_INPUT(k); + CHECK_INPUT(v); + CHECK_INPUT(f); + CHECK_INPUT(o_small); + + auto batch = q.size(0); + auto head = q.size(1); + auto seqlen = q.size(2); + auto d = q.size(3); + auto dv = v.size(3); + bool k_same = true; + for(auto i = 0; i < 4; i++) { + k_same &= q.size(i) == k.size(i); + } + // This is just a restriction of what we're doing now... + TORCH_CHECK(k_same, "Q and K should be same size"); + TORCH_CHECK(q.scalar_type() == c10::ScalarType::BFloat16, "Q is a Bfloat"); + TORCH_CHECK(k.scalar_type() == c10::ScalarType::BFloat16, "K is a Bfloat"); + TORCH_CHECK(v.scalar_type() == c10::ScalarType::BFloat16, "V is a Bfloat"); + + using H = c10::BFloat16; + constexpr int kHeight = 2; // tiles per sequence block, 4 means 4*16 = 64 sequence elements per warp + constexpr int kWidth = 2; // tiles per vector, 4 means head dimension is 4*16 = 64 + TORCH_CHECK(d == 32, "q.size(3) and k.size(3) should be 32"); + TORCH_CHECK(dv == 32, "v.size(3) should be 32"); + + constexpr int kNumWorkers = 16; + + unsigned long mem_size = 2*kNumWorkers*sizeof(st_bf); + + TORCH_CHECK(seqlen % (kNumWorkers*kittens::TILE_DIM) == 0, "The number of elements should be divisible the number of workers times stored fragments"); + + auto threads = kNumWorkers * kittens::WARP_THREADS; + //printf("[causal_attend] Requesting %lu bytes of memory for %d worker warps (%d threads)\n", mem_size, kNumWorkers, threads); + CHECK_CUDA_ERROR(cudaFuncSetAttribute(causal_attend_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, mem_size)); + + causal_attend_kernel<<>>((int)seqlen, q.data_ptr(), k.data_ptr(), v.data_ptr(), f.data_ptr(), o_small.data_ptr()); + + CHECK_CUDA_ERROR(cudaDeviceSynchronize()); +} + +//#include "harness.impl" + diff --git a/accelerated_scan/kitten.py b/accelerated_scan/kitten.py new file mode 100644 index 0000000..aa26bee --- /dev/null +++ b/accelerated_scan/kitten.py @@ -0,0 +1,125 @@ +from pathlib import Path + +import torch +from torch.utils.cpp_extension import load_inline + +module = load_inline( + name='kitten', + cpp_sources=[""" +extern void attend(torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor f, torch::Tensor o); +"""], + cuda_sources=[(Path(__file__).parent / 'kitten.cu').read_text()], + functions=['attend'], + verbose=True, + extra_cuda_cflags=[ + "-O3", + "-std=c++20", + "--ptxas-options=-v", + "-lineinfo", + #"--fmad", "false", + "--use_fast_math", + "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-DKITTENS_4090", + f"-I{str(Path(__file__).parent)}" + ] +) + +delta_module = load_inline( + name='delta', + cpp_sources=["""\ +extern void forward(torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor beta, + torch::Tensor y); +extern void backward(torch::Tensor d_out_y, + torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor beta, + torch::Tensor d_q, torch::Tensor d_k, torch::Tensor d_v, torch::Tensor d_beta, + torch::Tensor u); + """], + cuda_sources=[(Path(__file__).parent / 'delta.cu').read_text()], + functions=['forward', 'backward'], + verbose=True, + extra_cuda_cflags=[ + "-O3", + "-std=c++20", + "--ptxas-options=-v", + "-lineinfo", + #"--fmad", "false", + "--use_fast_math", + "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", 
+ "-U__CUDA_NO_BFLOAT16_OPERATORS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-DKITTENS_4090", + f"-I{str(Path(__file__).parent)}" + ] +) + +attend = module.attend +delta_forward = delta_module.forward +delta_backward = delta_module.backward + + +class Delta(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, beta): + N, H, T, D = q.shape + NH = N * H + assert k.shape == v.shape == (N, H, T, D) + assert beta.shape == (N, H, T) + assert q.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert beta.is_contiguous() + + ctx.save_for_backward(q, k, v, beta) + q = q.view(NH, T, D) + k = k.view(NH, T, D) + v = v.view(NH, T, D) + beta = beta.view(NH, T) + y = torch.empty_like(q) + delta_forward(q, k, v, beta, y) + return y.view(N, H, T, D) + + @staticmethod + def backward(ctx, d_y): + q, k, v, beta = ctx.saved_tensors + N, H, T, D = q.shape + NH = N * H + + d_y = d_y.view(NH, T, D) + q = q.view(NH, T, D) + k = k.view(NH, T, D) + v = v.view(NH, T, D) + beta = beta.view(NH, T) + d_q = q.new_zeros(NH, T, D) + d_k = k.new_zeros(NH, T, D) + d_v = v.new_zeros(NH, T, D) + d_beta = beta.new_zeros(NH, T) + u = v.new_zeros(NH, T, D) # buffer + delta_backward( + d_y, + q, k, v, beta, + d_q, d_k, d_v, d_beta, + u + ) + return d_q.view(N, H, T, D), d_k.view(N, H, T, D), d_v.view(N, H, T, D), d_beta.view(N, H, T) + + +def delta(query, key, value, beta): + """Delta rule compressive attention. + + Maintains the state matrix for a linear model using online stochastic gradient descent + on the mean squared error objective. Beta is the learning rate, key is the input and value is the target. + + At every step, the difference between current prediction given a key and the value is added to the state. + When beta is constant ones this is equivalent to causal linear attention, + which always stores the complete value. 
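+
+    In symbols, with state matrix S_t (the reference recurrence is forward_loop in tests/deltanet.py):
+
+        S_t = S_{t-1} + beta_t * (v_t - S_{t-1} k_t) k_t^T
+        y_t = S_t q_t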
+ + Arguments: + query (torch.Tensor): shape (N, H, T, D), regression monitoring inputs + key (torch.Tensor): shape (N, H, T, D), regression inputs + value (torch.Tensor): shape (N, H, T, D), regression outputs + beta (torch.Tensor): shape (N, H, T), learning rate + + Returns: + torch.Tensor: shape (N, H, T, D), query outputs + """ + return Delta.apply(query, key, value, beta) \ No newline at end of file diff --git a/accelerated_scan/tk b/accelerated_scan/tk new file mode 160000 index 0000000..dc13fbb --- /dev/null +++ b/accelerated_scan/tk @@ -0,0 +1 @@ +Subproject commit dc13fbbc3c3eb897934865f2d0a1546ee72a2e7e diff --git a/pyproject.toml b/pyproject.toml index d652fd5..c7570ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,9 @@ build-backend = "hatchling.build" name = "accelerated-scan" dynamic = ["version"] dependencies = [ - "torch>=2.1.0" + "torch>=2.1.0", + "ninja", # compiling kernels on the fly + "matplotlib" # triton perf_report ] authors = [ { name="Volodymyr Kyrylov", email="vol@wilab.org.ua" }, diff --git a/tests/bench.py b/tests/bench.py index 88eebb5..db20b50 100644 --- a/tests/bench.py +++ b/tests/bench.py @@ -11,10 +11,10 @@ def init(B, C, T, *, device, requires_grad=False): return gates, tokens -def make_benchmark(plot_name, *, direction, max_exponent=17): +def make_benchmark(plot_name, *, direction, batch_size=82, dim=64, max_exponent=15): return triton.testing.Benchmark( x_names=["SEQUENCE_LENGTH"], # argument names to use as an x-axis for the plot - x_vals=[2**i for i in range(7, max_exponent)], + x_vals=[2**i for i in range(8, max_exponent)], xlabel='sequence length', ylabel='ms', x_log=True, @@ -22,11 +22,26 @@ def make_benchmark(plot_name, *, direction, max_exponent=17): line_arg="provider", # argument name whose value corresponds to a different line in the plot #line_names=["triton", "ref", "warp"], #line_vals=["triton", "ref", "warp"], - line_names=["warp"], - line_vals=["warp"], + #line_names=["flash", "kittenexp", "warp"], + #line_vals=["flash", "kittenexp", "warp"], + #line_names=["linear", "flash2", "delta", "scan"], + #line_vals=["kitten", "flash", "delta", "warp"], + #line_names=["delta", "fla"], + #line_vals=["delta", "fla"], + + #line_names=["linear", "flash2", "delta", "fla-delta", "warpscan"], + #line_vals=["kitten", "flash", "delta", "fla", "warp"], + + line_names=["flash2", "delta", "fla-delta", "warpscan"], + line_vals=["flash", "delta", "fla", "warp"], + + # line_names=["flash2", "delta", "fla", "scan"], + # line_vals=["flash", "delta", "fla", "warp"], plot_name=plot_name, args={ "direction": direction, + "dim": dim, + "batch_size": batch_size, } ) @@ -36,9 +51,12 @@ def grad2(f, x, y, grad_out): sum(x.sum().item() for x in grad) -def bench(provider, SEQUENCE_LENGTH, device="cuda", direction: Literal["forward", "backward", "train"] = "forward"): - B, C, T = 8, 1536, SEQUENCE_LENGTH - gates, tokens = init(B, C, T, device=device, requires_grad=direction=="train") +from collections import defaultdict +c = defaultdict(int) + +def bench(provider, SEQUENCE_LENGTH, device="cuda", batch_size: int = 82, dim: int = 32, direction: Literal["forward", "backward", "train"] = "forward"): + B, H, D, T = 1, batch_size, dim, SEQUENCE_LENGTH + gates, tokens = init(B, H*D, T, device=device, requires_grad=direction=="train") outputs = torch.empty_like(tokens) grad_outputs = torch.empty_like(tokens) @@ -48,10 +66,10 @@ def bench(provider, SEQUENCE_LENGTH, device="cuda", direction: Literal["forward" match direction: case "forward": from accelerated_scan.triton 
import forward_scan - scan = lambda: forward_scan[(B,C)](gates, tokens, outputs, SEQUENCE_LENGTH, enable_fp_fusion=False) + scan = lambda: forward_scan[(B,H*D)](gates, tokens, outputs, SEQUENCE_LENGTH, enable_fp_fusion=False) case "backward": from accelerated_scan.triton import backward_scan - scan = lambda: backward_scan[(B,C)](gates, tokens, outputs, SEQUENCE_LENGTH, enable_fp_fusion=False) + scan = lambda: backward_scan[(B,H*D)](gates, tokens, outputs, SEQUENCE_LENGTH, enable_fp_fusion=False) case "train": # note that these measurements include time for memory allocation for forward output tensors from accelerated_scan.triton import scan as train_scan @@ -71,7 +89,11 @@ def bench(provider, SEQUENCE_LENGTH, device="cuda", direction: Literal["forward" match direction: case "forward": from accelerated_scan.warp import warpscan_forward - scan = lambda: warpscan_forward(gates, tokens, outputs, False) + o = torch.empty_like(tokens) + def scan(): + warpscan_forward(gates, tokens, outputs, False) + return o * outputs + scan# = lambda: warpscan_forward(gates, tokens, outputs, False) case "backward": from accelerated_scan.warp import warpscan_forward scan = lambda: warpscan_forward(gates, tokens, outputs, True) @@ -79,6 +101,94 @@ def bench(provider, SEQUENCE_LENGTH, device="cuda", direction: Literal["forward" # note that these measurements include time for memory allocation for forward output tensors from accelerated_scan.warp import scan as train_scan scan = lambda: grad2(train_scan, gates, tokens, grad_outputs) + + case "kitten": + print(f"Running {provider} with sequence length {SEQUENCE_LENGTH} {direction}") + from accelerated_scan.kitten import attend + + gates, tokens = init(B, H, T, device=device, requires_grad=direction=="train") + + k = tokens.unsqueeze(-1).expand(B, H, T, D).bfloat16().contiguous() + q = torch.ones_like(k).bfloat16().contiguous() + v = torch.ones_like(q).bfloat16().contiguous() + f = gates.float().contiguous() + + match direction: + case "forward": + def scan(): + o = torch.empty_like(v).bfloat16().contiguous() + attend(q, k, v, f, o) + + case "flash": + print(f"Running {provider} with sequence length {SEQUENCE_LENGTH} {direction}") + from torch.nn.functional import scaled_dot_product_attention + + gates, tokens = init(B, H, T, device=device, requires_grad=direction=="train") + + k = tokens.unsqueeze(-1).expand(B, H, T, D).bfloat16().contiguous().requires_grad_() + q = torch.ones_like(k).bfloat16().contiguous().requires_grad_() + v = torch.ones_like(q).bfloat16().contiguous().requires_grad_() + f = gates.float().contiguous() + o = torch.empty_like(v).bfloat16().contiguous() + do = torch.randn_like(o) + + match direction: + case "forward": + scan = lambda: scaled_dot_product_attention(q, k, v, is_causal=True) + case "train": + def scan(): + grad = torch.autograd.grad(scaled_dot_product_attention(q, k, v), (q, k, v), do) + sum(x.sum().item() for x in grad) + case "delta": + print(f"Running {provider} with sequence length {SEQUENCE_LENGTH} {direction}") + from accelerated_scan.kitten import delta_forward, delta_backward, delta + + gates, tokens = init(B, H, T, device=device, requires_grad=direction=="train") + + k = tokens.unsqueeze(-1).expand(B, H, T, D).bfloat16().contiguous().requires_grad_() + q = torch.ones_like(k).bfloat16().contiguous().requires_grad_() + v = torch.ones_like(q).bfloat16().contiguous().requires_grad_() + f = gates.bfloat16().contiguous() + o = torch.empty_like(v).bfloat16().contiguous() + do = torch.randn_like(o) + + match direction: + case 
"forward": + def scan(): + delta(q, k, v, f) + case "train": + def scan(): + grad = torch.autograd.grad(delta(q, k, v, f), (q, k, v, f), do) + sum(x.sum().item() for x in grad) + + case "fla": + print(f"Running {provider} with sequence length {SEQUENCE_LENGTH} {direction}") + from fla.ops.delta_rule.chunk_fuse import fused_chunk_delta_rule + gates, tokens = init(B, H, T, device=device, requires_grad=direction=="train") + + k = tokens.unsqueeze(-1).expand(B, H, T, D).bfloat16().contiguous() + q = torch.ones_like(k).bfloat16().contiguous() + v = torch.ones_like(q).bfloat16().contiguous() + f = gates.bfloat16().contiguous() + o = torch.empty_like(v).bfloat16().contiguous() + + k = k.view(B, H, T, D).requires_grad_() + q = q.view(B, H, T, D).requires_grad_() + v = v.view(B, H, T, D).requires_grad_() + f = f.view(B, H, T).requires_grad_() + o = o.view(B, H, T, D) + do = torch.randn_like(o) + + match direction: + case "forward": + def scan(): + fused_chunk_delta_rule(q, k, v, f, BT=16) + + case "train": + def scan(): + y, _ = fused_chunk_delta_rule(q, k, v, f, BT=16) + grad = torch.autograd.grad(y, (q, k, v, f), do) + sum(x.sum().item() for x in grad) case _: raise ValueError(f"Unknown provider {provider}") @@ -94,13 +204,16 @@ def bench(provider, SEQUENCE_LENGTH, device="cuda", direction: Literal["forward" if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() + parser.add_argument("--dim", type=int, default=32) + parser.add_argument("--batch-size", type=int, default=82) + parser.add_argument("--max-exponent", type=int, default=15) parser.add_argument("--direction", choices=["forward", "backward", "train", "all"], default="all") args = parser.parse_args() directions = { - 'forward': make_benchmark("accelerated_scan: forward speed of (8,1536,seqlen), inference mode", direction="forward"), - 'backward': make_benchmark("accelerated_scan: backward speed of (8,1536,seqlen), inference mode", direction="backward"), - 'train': make_benchmark("accelerated_scan: training speed of (8,1536,seqlen)", direction="train", max_exponent=15), + 'forward': make_benchmark("accelerated_scan: forward speed", dim=args.dim, direction="forward", max_exponent=args.max_exponent, batch_size=args.batch_size), + 'backward': make_benchmark(f"accelerated_scan: backward speed of ({args.batch_size},{args.dim},seqlen), inference mode", dim=args.dim, direction="backward", max_exponent=args.max_exponent, batch_size=args.batch_size), + 'train': make_benchmark(f"accelerated_scan: training speed of ({args.batch_size},{args.dim},seqlen)", direction="train", dim=args.dim, max_exponent=args.max_exponent, batch_size=args.batch_size), } benchmarks = [] @@ -112,4 +225,7 @@ def bench(provider, SEQUENCE_LENGTH, device="cuda", direction: Literal["forward" case dir: benchmarks.append(directions[dir]) - triton.testing.perf_report(benchmarks)(bench).run(save_path=".", print_data=True) + try: + triton.testing.perf_report(benchmarks)(bench).run(save_path=".", print_data=True) + finally: + print(f"{dict(c)=}") diff --git a/tests/chunk_size.py b/tests/chunk_size.py new file mode 100644 index 0000000..28eaea6 --- /dev/null +++ b/tests/chunk_size.py @@ -0,0 +1,46 @@ +#%% +import matplotlib.pyplot as plt +import matplotlib.cm as cm +import numpy as np + +plt.rcParams['font.family'] = 'serif' + + +dim = np.array([16,32,64,128,256,512,1024]) +chunk_size = np.array([8,16,32,64,128,256,512]) + +colormap = plt.get_cmap('viridis') +num_colors = len(dim) +colors = [colormap(i) for i in np.linspace(0, 1, num_colors)] + + 
+plt.figure(figsize=(10,7)) + +words = chunk_size[:,None] * dim +state = dim*dim + +# plot horizontal lines for state +for i in range(len(dim)): + plt.axhline(y=state[i], color=colors[i], linestyle=':', label=f'${dim[i]}^2$', linewidth=1) + d = dim[i] + + # plt.plot(chunk_size, 2 * chunk_size * d, color=colors[i], label=f'$2C\cdot {dim[i]}$', linewidth=1) + # c = d / 2 + # plt.scatter(c, 2*c*d, color=colors[i], label=f'breakeven chunk for dim {dim[i]}', s=50) + + plt.plot(chunk_size, 3 * chunk_size * d, color=colors[i], label=f'$3C\cdot {dim[i]}$', linewidth=1) + c = d / 3 + plt.scatter(c, 3*c*d, color=colors[i], label=f'saturation length for dim {dim[i]} is {int(c)}', s=50) + +plt.title('Until when does chunking reduce memory usage?') + +# powers of 2 y axis +plt.yscale('log', base=2) +# integers x axis +plt.xticks(chunk_size) +plt.xlabel('chunk size') +plt.ylabel('used memory words') + +plt.legend() + +plt.savefig('chunk_size.pdf', bbox_inches='tight', dpi=300) \ No newline at end of file diff --git a/tests/deltacu.py b/tests/deltacu.py new file mode 100644 index 0000000..5e057de --- /dev/null +++ b/tests/deltacu.py @@ -0,0 +1,114 @@ +#%% + +import os +import torch +from torch import einsum, randn, allclose, stack, eye, manual_seed, no_grad, set_float32_matmul_precision, compile, arange + +import pytest + +from tests.deltanet import tileprint +from tests.deltanet import shape, make_example +from tests.deltanet import backward as backward_ref +from tests.deltanet import forward as forward_ref + + + +@pytest.mark.parametrize('T', [16, 32, 64, 128, 256]) +@pytest.mark.parametrize('D', [16, 32, 64, 128]) +def test_backward(T, D): + NH, T, D = 1, T, D + + q, k, v, beta = make_example(NH, T, D, device='cuda', dtype=torch.bfloat16) + d_out_w = torch.zeros_like(k) # placeholder + d_out_u = torch.zeros_like(v) # placeholder + d_out_y = torch.randn_like(v) / D**0.5 # actual gradient + + d_q0, d_k0, d_v0, d_beta0 = backward_ref( + d_out_w.clone(), d_out_u.clone(), d_out_y.clone(), q, k, v, beta, + chunk_size=16 + ) + + d_q1 = q.new_zeros(NH, T, D) + d_k1 = k.new_zeros(NH, T, D) + d_v1 = v.new_zeros(NH, T, D) + d_beta1 = beta.new_zeros(NH, T) + u1 = v.new_zeros(NH, T, D) # placeholder + from accelerated_scan import kitten + kitten.delta_backward( + d_out_y.clone(), + q, k, v, beta, + d_q1, d_k1, d_v1, d_beta1, + u1 + ) + + torch.set_printoptions(precision=4, sci_mode=False, linewidth=300) + #torch.set_printoptions(linewidth=300) + + #print(d_q0 - d_q1, 'd_q diff') + err = (d_q0 - d_q1).abs().max().item() + assert allclose(d_q0, d_q1, atol=5e-2), f'd_q is wrong: {err}' + + #print(d_k0, 'd_k ref') + #print(d_k1, 'd_k hyp') + #print((d_k1 - d_k0).abs().topk(10), 'd_k diff') + err = (d_k0 - d_k1).abs().max().item() + assert allclose(d_k0, d_k1, atol=5e-2), f'd_k is wrong: {err}' # ??? + err = (d_v0 - d_v1).abs().max().item() + assert allclose(d_v0, d_v1, atol=5e-2), f'd_v is wrong: {err}' + + #print(d_beta0, 'ref') + #print(d_beta1, 'hyp') + # XXX: atol=1e-2 might be too low. cast to float32? 
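+    # everything here is bfloat16 (~8 bits of mantissa), hence the loose 5e-2 absolute tolerances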
+ err = (d_beta0 - d_beta1).abs().max().item() + assert allclose(d_beta0, d_beta1, atol=5e-2), f'd_beta is wrong: {err}' + + +@pytest.mark.parametrize('T', [16, 32, 64, 128, 256, 512, 1024]) +@pytest.mark.parametrize('D', [16, 32, 64, 128]) +def test_forward(T, D): + NH, T, D = 1, T, D + + chunk_size = 16 + C = T // chunk_size + + q, k, v, beta = make_example(NH, T, D, device='cuda', dtype=torch.bfloat16) + + # q, k, v, beta = ( + # q.view(NH, C, chunk_size, D).view(NH*C,chunk_size,D), + # k.view(NH, C, chunk_size, D).view(NH*C,chunk_size,D), + # v.view(NH, C, chunk_size, D).view(NH*C,chunk_size,D), + # beta.view(NH, C, chunk_size).view(NH*C,chunk_size) + # ) + + torch.set_printoptions(linewidth=300, sci_mode=False, precision=3) + w, u, y = forward_ref(q, k, v, beta, chunk_size=chunk_size) + + # w, u, y = ( + # w.view(NH, C, chunk_size, D).view(NH,T,D), + # u.view(NH, C, chunk_size, D).view(NH,T,D), + # y.view(NH, C, chunk_size, D).view(NH,T,D) + # ) + + y2 = y.new_zeros(NH, T, D) + from accelerated_scan import kitten + kitten.delta_forward(q, k, v, beta, y2) + + assert allclose(y, y2, atol=2e-2), 'y2 is wrong' + + +@pytest.mark.parametrize('T', [2048, 4096, 8192, 16384]) +@pytest.mark.parametrize('D', [16, 32, 64, 128]) +def test_longf(T, D): + return test_forward(T, D) + +@pytest.mark.parametrize('T', [2048, 4096, 8192, 16384]) +@pytest.mark.parametrize('D', [16, 32, 64, 128]) +def test_longb(T, D): + return test_backward(T, D) + + +if __name__ == '__main__': + #pytest.main([__file__, "--disable-warnings", "-v"]) + #test_forward(T=128, D=16) + test_backward(T=128, D=16) + #test_forward(T=32, D=16) diff --git a/tests/deltanet.py b/tests/deltanet.py new file mode 100644 index 0000000..7838708 --- /dev/null +++ b/tests/deltanet.py @@ -0,0 +1,478 @@ +""" +DeltaNet implementation reference for Accelerated Scan. DeltaNet performs efficient management of a large fixed-sized memory. + +`forward` is inspired by Yang 2024. It applies single chunk version pointwise and then performs chunk-level stitching. +`forward_loop` is the reference implementation of the original recurrence. + +References: +[1] The WY Representation for Products of Householder Matrices (Bischof and Van Loan 1985) + +Method 1, section 3 guides `decay_values`. +https://ecommons.cornell.edu/items/92a11030-dca1-45d4-a0ba-732cf962b2b2 + +[2] Parallelizing Linear Transformers with the Delta Rule over Sequence Length (Yang et al 2024) + +- equation 5 is a specialization of method 1 of [1] is in `decay_values` +- equation 6 is application of decayed keys to values is also in `decay_values` +- `forward_chunkwise` uses the distributed form of equation 7 and 8 + (actually look the two equations before it instead, they are easier to read) + +https://arxiv.org/abs/2406.06484 + +[3] Linear Transformers Are Secretly Fast Weight Programmers (Schlag et al 2021) + +Introduction to Transformers as RNNs. Ignore all of the kernel stuff. 
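+
+In the notation of `forward_loop` below, the delta rule update being chunked here is
+w_t = w_{t-1} + beta_t * (v_t - w_{t-1} k_t) k_t^T.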
+https://arxiv.org/abs/2102.11174 +""" +#%% + +import os +os.environ['TORCH_LOGS'] = 'output_code' +import torch +from torch import einsum, randn, allclose, stack, eye, manual_seed, no_grad, set_float32_matmul_precision, compile, arange + +#set_float32_matmul_precision('high') + + +def tileprint(K, name='K'): + "format matches tileprint in tk code so you can diff it" + assert K.shape == (16, 16), f'K shape is {K.shape}' + for laneid in range(32): + row_top = laneid // 4 + row_bottom = row_top + 8 + col_left = laneid % 4 * 2 + col_right = col_left + 8 + + def fmt(r,c,tag): + odd = "y" in tag + if odd: # do not print r for odd rows because cuda printf silently runs out of function arguments + return f"{name}[,{c:02}] {tag}={K[r,c]: .3f}" + else: + return f"{name}[{r:02},{c:02}] {tag}={K[r,c]: .3f}" + + print(f"lane={laneid:02}", " ".join([ + " ".join([fmt(row_top, col_left, "0x"), fmt(row_top, col_left+1, "0y")]), + " ".join([fmt(row_bottom, col_left, "1x"), fmt(row_bottom, col_left+1, "1y")]), + " ".join([fmt(row_top, col_right, "2x"), fmt(row_top, col_right+1, "2y")]), + " ".join([fmt(row_bottom, col_right, "3x"), fmt(row_bottom, col_right+1, "3y")]) + ])) + + +def decay_values(q, k, v, beta, chunk_size=2): + NH, T, D = shape(q, k, v, beta) + C = T // chunk_size + q_, k_, v_, beta_ = ( + q.view(NH*C, chunk_size, D), k.view(NH*C, chunk_size, D), + v.view(NH*C, chunk_size, D), beta.view(NH*C, chunk_size) + ) + + # evaluate all chunks in parallel + beta__ = beta_.unsqueeze(-1) + w = beta__ * k_.clone() + u = beta__ * v_.clone() + K = einsum('nsd,ntd->nst', k_, k_) # (chunk_size,chunk_size) matrix + + for t in range(1,chunk_size): + w[:, t] -= beta__[:, t] * einsum('nt,ntd->nd', K[:, :t, t], w[:, :t].clone()) + u[:, t] -= beta__[:, t] * einsum('nt,ntd->nd', K[:, :t, t], u[:, :t].clone()) + + # attend to decayed values + qk = einsum("nsk,ntk->nst", q_, k_) + qk.tril_() + y = einsum("nst,ntj->nsj", qk, u) + + return w, u, y + + +def forward(q, k, v, beta, chunk_size=2): + "decay values applying deltanet forgetting rules, then stitch chunks" + NH, T, D = shape(q, k, v, beta) + C = T // chunk_size + + w, u, y = decay_values(q, k, v, beta, chunk_size=chunk_size) + + # stitch chunks sequentially + q_ = q.view(NH, C, chunk_size, D) + k_ = k.view(NH, C, chunk_size, D) + u = u.view(NH, C, chunk_size, D) + w = w.view(NH, C, chunk_size, D) + y = y.view(NH, C, chunk_size, D) + + state = u.new_zeros(NH, D, D) + + for c in range(C): + qc = q_[:, c] # load q + kc = k_[:, c] # load k + wc = w[:, c] # load w + uc = u[:, c] # load u + yc = y[:, c] # load y + + if c: + u_old = einsum('ntk,nvk->ntv', wc, state) # DDT + + # attend to old values + qk = einsum("nsi,nti->nst", qc, kc) # TDT + qk = qk.tril() + + y_prev = einsum("nst,ntv->nsv", qk, u_old) # TTD + yc = yc - y_prev + + y_cur = einsum('nsk,nvk->nsv', qc, state) # DDT + yc = yc + y_cur + + uc = uc - u_old + + state = state + einsum('ntv,ntk->nvk', uc, kc) + y[:, c] = yc # store + + w = w.view(NH, T, D) + u = u.view(NH, T, D) + y = y.view(NH, T, D) + + return w, u, y + + +def forward_loop(q, k, v, beta): + "reference: w_t = w_{t-1} + beta_t (v_t - w_t k_t) k_t" + NH, T, D = shape(q, k, v, beta) + + w = k.new_zeros(NH, D, D) + y = [] + + for t in range(T): + q_ = q[:, t] + k_ = k[:, t] + v_ = v[:, t] + beta_ = beta[:, t].unsqueeze(-1) + + v_old = einsum("nij,nj->ni", w, k_) + delta = beta_ * (v_ - v_old) + w = w + einsum("ni,nj->nij", delta, k_) + + y.append(einsum("nij,nj->ni", w, q_)) + + return stack(y, dim=1) + + +def shape(q, k, v, beta=None): + NH, T, 
D = (q if q is not None else k).shape + if q is not None: + assert q.shape == (NH, T, D) + if v is not None: + assert k.shape == v.shape + if beta is not None: + assert beta.shape == (NH, T) + return NH, T, D + + +def make_example(NH, T, D, device='cpu', dtype=torch.float32): + manual_seed(0) + q = randn(NH, T, D, device=device, dtype=dtype) / D**0.5 + q.requires_grad_() + k = randn(NH, T, D, device=device, dtype=dtype) / D**0.5 + k.requires_grad_() + v = randn(NH, T, D, device=device, dtype=dtype) / D**0.5 + v.requires_grad_() + beta = randn(NH, T, device=device, dtype=dtype).sigmoid() + beta.requires_grad_() + return q, k, v, beta + + +@no_grad() +def backward(d_out_w_long, d_out_u_long, d_out_y_long, q_long, k_long, v_long, beta_long, chunk_size=2): + NH, T, D = shape(q_long, k_long, v_long, beta_long) + + C = T // chunk_size + q, k, v, beta, d_out_y = ( + q_long.view(NH*C, chunk_size, D), k_long.view(NH*C, chunk_size, D), + v_long.view(NH*C, chunk_size, D), beta_long.view(NH*C, chunk_size), + d_out_y_long.view(NH*C, chunk_size, D) + ) + + # + # allocations + # + + # this group is loaded from global memory + q = q.clone() # load q + k = k.clone() # load k + v = v.clone() # load v + beta = beta.clone() # load beta + #d_out_w = d_out_w.clone() # ntk # placeholders + #d_out_y = d_out_y.clone() # ntv # placeholders + + w = k.new_zeros(NH*C, chunk_size, D) # ntk + u = v.new_zeros(NH*C, chunk_size, D) # ntw + w_bases = w.clone() # ntk + u_bases = u.clone() # ntw + + bk = einsum('nt,ntk->ntk', beta, k) + + bKl = k.new_zeros(NH*C, chunk_size, chunk_size) + tt = k.new_zeros(NH*C, chunk_size, chunk_size) + + d_k = k.new_zeros(NH*C, chunk_size, D) # nsk + tk = k.new_zeros(NH*C, chunk_size, D) # ntk + + # + # forward + # + + tt = einsum('ntk,nsk->nts', k, k) + tt = tt.tril(diagonal=-1) # make_causal(0); set_diagonal(0) + bKl = einsum('nt,nts->nts', beta, tt) # multiply each row of K by beta + + u_bases = v + v = einsum('nt,ntw->ntw', beta, v) + + for t in range(chunk_size): + tk = einsum('nts,nsk->ntk', bKl, w) # matmul for the sake of one row + w[:, t] = bk[:, t, :] - tk[:, t, :] + tk = einsum('nts,nsw->ntw', bKl, u) # matmul for the sake of one row + u[:, t] = v[:, t, :] - tk[:, t, :] + + w.clone() # store w + u.clone() # store u + + # + # stitch_backward + # + w_long = w.view(NH, T, D) + u_long = u.view(NH, T, D) + d_q_1_long, d_k_1_long, d_out_w_long, d_out_u_long = stitch_backward(d_out_y_long, q_long, k_long, w_long, u_long, C, chunk_size) + + d_out_w, d_out_u = ( + d_out_w_long.view(NH*C, chunk_size, D), d_out_u_long.view(NH*C, chunk_size, D) + ) + + w_bases = einsum('nts,nsk->ntk', tt, w) + w_bases = k - w_bases + v = einsum('nts,nsw->ntw', tt, u) + u_bases = u_bases - v + + # + # causal_attend_backward for d_q, d_k_2, d_out_u + # + + tt = einsum('nsv,ntv->nst', d_out_y, u) + tt = tt.tril() + d_q = einsum('nst,ntk->nsk', tt, k) + d_q.clone() # store + + d_k_2 = einsum('nst,nsk->ntk', tt, q) + d_k_2.clone() # store to shared memory? 
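+
+    # The next block forms tril(q k^T) and routes the output gradient into the
+    # decayed values: y = tril(q k^T) u within a chunk, so d_u += tril(q k^T)^T d_y.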
+ + tt = einsum('nsk,ntk->nst', q, k) + tt = tt.tril() + + v.zero_() # reuse register space of v for d_out_u + d_out_u = d_out_u.clone() # load ntw + d_out_u += einsum('nst,nsv->ntv', tt, d_out_y) + + # + # backward for d_k, d_v, d_beta + # + + d_k.zero_() + + for t in range(chunk_size-1,-1,-1): + # d_k + tt = einsum('njw,ntw->njt', w, d_out_w) # matmul for the sake of one column t + tt[:, t:, :] = 0 + tk = einsum('njt,njk->ntk', tt, k) + + tt = einsum('njv,ntv->njt', u, d_out_u) # matmul for the sake of one column t + tt[:, t:, :] = 0 + tk += einsum('njt,njk->ntk', tt, k) + + d_k[:, t] += tk[:, t] + + # backpropagate through time, updating only remaining timestamps + tt.zero_() + tt[:, t] += bKl[:, t] + tk = einsum('ntj,ntk->njk', tt, d_out_w) + d_out_w = d_out_w - tk + tk = einsum('ntj,ntk->njk', tt, d_out_u) + d_out_u = d_out_u - tk + + d_k = d_out_w - d_k + d_k = einsum('ntk,nt->ntk', d_k, beta) + + # decay w and u + tt = einsum('ntw,njw->ntj', d_out_w, w) + tt += einsum('ntw,njw->ntj', d_out_u, u) + tt.tril_(diagonal=-1) + + tk = einsum('ntj,ntk->njk', tt, bk) + d_k = d_k - tk + d_k_2 = d_k_2.clone() # load from shared memory + d_k = d_k_2 + d_k + d_k = d_k.clone() # store + + # d_beta + w_bases = einsum('ntk,ntk->ntk', w_bases, d_out_w) + u_bases = einsum('ntw,ntw->ntw', u_bases, d_out_u) + + # d_v using d_out_u register + d_out_u = einsum('nt,ntv->ntv', beta, d_out_u) + d_v = d_out_u.clone() # store + + # continue d_beta reusing the beta register + beta = einsum('ntk->nt', w_bases) + beta += einsum('ntv->nt', u_bases) + d_beta = beta.clone() # store + + d_q_long = d_q.view(NH, T, D) + d_q_1_long + d_k_long = d_k.view(NH, T, D) + d_k_1_long + d_v_long = d_v.view(NH, T, D) + d_beta_long = d_beta.view(NH, T) + + return d_q_long, d_k_long, d_v_long, d_beta_long + + +def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size): + NH, T, D = shape(q, k, None, None) + + # outputs + d_q_ = q.new_zeros(NH, C, chunk_size, D) + d_k_ = k.new_zeros(NH, C, chunk_size, D) + d_w = w.new_zeros(NH, C, chunk_size, D) + d_u = u.new_zeros(NH, C, chunk_size, D) + + # chunked inputs + d_y_delta = d_y_delta.view(NH, C, chunk_size, D) + q_ = q.view(NH, C, chunk_size, D) + k_ = k.view(NH, C, chunk_size, D) + w = w.view(NH, C, chunk_size, D) + + # shared memory copy + u = u.view(NH, C, chunk_size, D).clone() + + state = w.new_zeros(NH, D, D) + d_state = w.new_zeros(NH, D, D) # NHVK + state_delta = w.new_zeros(NH, D, D) # NHVK # can this be float32? 
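+
+    # state is the running sum of per-chunk outer products u_c^T k_c (value-by-key),
+    # built while walking the chunks forwards and uncomputed while walking them
+    # backwards; d_state carries the gradient flowing into that running state.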
+ + qk = k.new_zeros(NH, chunk_size, C) + tk = k.new_zeros(NH, chunk_size, D) + + # materialize the state for the leading chunk + state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0]) + + # stitch forward + for c in range(1, C): + tk = einsum('nvk,ntk->ntv', state, w[:, c]) + u[:, c] = u[:, c] - tk + state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c]) + if c < C-1: + state = state + state_delta # walk the state forwards + + # from now on, u's are decayed + + # stitch backward + for c in range(C-1, 0, -1): + if c < C-1: + state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c]) + state = state - state_delta # uncompute the state backwards + tk = einsum('nvk,ntk->ntv', state, w[:, c]) # state_decay + + d_y_delta_c = d_y_delta[:, c] + d_y_delta_c = -d_y_delta_c # neg + + # d_q, d_k + qk = einsum('nsv,ntv->nst', d_y_delta_c, tk) + qk.tril_() + + # d_q + tk = einsum('nst,ntk->nsk', qk, k_[:, c]) # causal_attend_backward for delta + tk.sub_(einsum('nsv,nvk->nsk', d_y_delta_c, state)) # prev_output + d_q_[:, c] = tk + + # d_k + tk = einsum('nst,nsk->ntk', qk, q_[:, c]) + if c < C-1: + tk.add_(einsum('nvk,ntv->ntk', d_state, u[:, c])) # state_add + else: + # d_state is zero + pass + d_k_[:, c] = tk + + # d_u + if c < C-1: + d_u[:, c] = einsum('nvk,ntk->ntv', d_state, k_[:, c]) # state_add + else: + # d_state is zero + pass + + # d_state_decays + qk = einsum('nsk,ntk->nst', q_[:, c], k_[:, c]) + qk.tril_() + d_state_decays = einsum('nsv,nst->ntv', d_y_delta_c, qk) + if c < C-1: + d_state_decays.sub_(einsum('nvk,ntk->ntv', d_state, k_[:, c])) # state_add + + # d_w + tk = einsum('ntv,nvk->ntk', d_state_decays, state) + d_w[:, c] = tk # state_decays + + # backpropagate through time + d_state.sub_(einsum('nsv,nsk->nvk', d_y_delta_c, q_[:, c])) # prev_output + d_state.add_(einsum('ntv,ntk->nvk', d_state_decays, w[:, c])) # state_decays + + tk = einsum('nvk,ntk->ntv', d_state, k_[:, 0]) + d_u[:, 0] = tk # state_add + tk = einsum('nvk,ntv->ntk', d_state, u[:, 0]) + d_k_[:, 0] = tk # state_add + + return d_q_.view(NH, T, D), d_k_.view(NH, T, D), d_w.view(NH, T, D), d_u.view(NH, T, D) + + +class Delta(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, beta, chunk_size): + w, u, y = forward(q, k, v, beta, chunk_size) + ctx.save_for_backward(q, k, v, beta) + ctx.chunk_size = chunk_size + return y + + @staticmethod + def backward(ctx, d_y): + q, k, v, beta = ctx.saved_tensors + NH, T, D = shape(q, k, v, beta) + d_w = k.new_zeros(NH, T, D) + d_u = v.new_zeros(NH, T, D) + + d_q, d_k, d_v, d_beta = backward(d_w, d_u, d_y, q, k, v, beta, chunk_size=ctx.chunk_size) + return d_q, d_k, d_v, d_beta, None + + +def test_delta(): + NH, T, D = 1, 64, 16 + q1, k1, v1, beta1 = make_example(NH, T, D) + + y0 = forward_loop(q1, k1, v1, beta1) + chunk_size = 8 + + w1, u1, y1 = forward(q1, k1, v1, beta1, chunk_size=chunk_size) + (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward() + + assert allclose(y0, y1, atol=1e-5), 'y1 is wrong' + + q, k, v, beta = make_example(NH, T, D) + y = Delta.apply(q, k, v, beta, chunk_size) + (y - torch.ones_like(y).detach()).pow(2).mean().backward() + + assert allclose(y1, y, atol=1e-5), 'y is wrong' + + # print(beta1.grad - beta.grad, 'beta.grad diff') + # print(q1.grad - q.grad, 'q.grad diff') + # print(k1.grad - k.grad, 'k.grad diff') + # print(v1.grad - v.grad, 'v.grad diff') + + assert allclose(q1.grad, q.grad, atol=1e-5), 'q.grad is wrong' + assert allclose(beta1.grad, beta.grad, atol=1e-5), 'beta.grad is wrong' + assert allclose(k1.grad, k.grad, atol=1e-5), 'k.grad 
is wrong' + assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong' + + +if __name__ == '__main__': + test_delta() \ No newline at end of file diff --git a/tests/gated_rnn.py b/tests/gated_rnn.py new file mode 100644 index 0000000..fed60d7 --- /dev/null +++ b/tests/gated_rnn.py @@ -0,0 +1,114 @@ +""" +Gated Linear RNNs with state expansion are Linear Transformers with data-dependent cumulative masking +""" +#%% +import torch +from torch import tensor + + +def pascal1(g): + "compute the mask: prescan gates in log space (explicit)" + N, T = g.shape + l = g.new_zeros(N, T, T) + float('-inf') + + for t in range(T): + l[:, t, t] = 0 + for s in range(t): + l[:, s, t] = sum(g[:, k] for k in range(s+1, t)) + g[:, t] + return l + + +def pascal(g): + "compute the mask: prescan gates in log space (dynamic programming)" + N, T = g.shape + l = g.new_zeros(N, T, T) + float('-inf') + + for t in range(T): + l[:, t, t] = 0 + for s in range(t-1, -1, -1): + l[:, s, t] = l[:, s+1, t] + g[:, s+1] + return l + + +@torch.no_grad() +def attend_backward(q, k, v, g): + d_q = torch.einsum('ntk,ntv,nts->nsk', k, v, g.exp()) + d_k = torch.einsum('nsk,ntv,nts->ntk', q, v, g.exp()) + d_v = torch.einsum('nsk,ntk,nts->nt', q, k, g.exp()).unsqueeze(-1).expand_as(v) + g_mask = g.exp().bool().float() + d_g = torch.einsum('nsk,ntk,ntv,nts->nts', q, k, v, g_mask) + return d_q, d_k, d_v, d_g + + +def attend(q, k, v, g): + "masked linear attention: mask is data dependent, no softmax -- can be an RNN" + y = torch.einsum('nsk,ntk,ntv,nts->nsv', q, k, v, g.exp()) + return y + + +def lscan(q, k, v, f): + "linear time scan: fast weight programmer-style loop" + N, D, T = k.shape + N, T = f.shape + h = k.new_zeros(N, D, D, T) + y = k.new_zeros(N, D, T) + h[..., 0] = torch.einsum('nk,nv->nkv', k[..., 0], v[..., 0]) + y[..., 0] = torch.einsum('nk,nkv->nv', q[..., 0], h[..., 0]) + for i in range(1, T): + h[..., i] = f[:, None, None, i] * h[..., i-1] + torch.einsum('nk,nv->nkv', k[..., i], v[..., i]) + y[..., i] = torch.einsum('nk,nkv->nv', q[..., i], h[..., i]) + return y + + +if __name__ == '__main__': + primes = tensor([1, 2, 3, 5, 7, 11, 13, 17, 19])[None, :] + a = pascal(primes.log()).exp() + b = primes.cumprod(dim=-1).float() + assert torch.allclose(a[:, 0, :], b), f'{a[:,0,:]} != {b}' + + torch.manual_seed(0) + + N, T, D = 1, 512, 64 + q = torch.randn(N, T, D, requires_grad=True) / D**0.5 / D**0.25 + q.retain_grad() + k = (torch.randn(N, T, D, requires_grad=True) / D**0.5 / D**0.25).sigmoid() + k.retain_grad() + v = torch.randn(N, T, D, requires_grad=True) / D**0.5 + v.retain_grad() + + #f = torch.rand(N, T) # token-level forget gates: "Gated RNN" with outer product state expansion + f = torch.ones(N, T, requires_grad=True)*0.999 # same sequence-level forget gate: FWP with Decay + f.retain_grad() + + k = (1-f).unsqueeze(-1).expand_as(k).clone().detach().requires_grad_(True) + + g = pascal(f.log()) # Prescan of all gates + g = g.clone().detach().requires_grad_(True) + ## add gradient hook to g + #def g_hook(grad): + # g_hook.grad = grad + #g.register_hook(g_hook) + g_hook = g + + y1 = attend(q, k, v, g) + print(y1, 'gated_attend') # N,T,D + y2 = lscan(q.mT, k.mT, v.mT, f).mT + print(y2, 'lscan') + assert torch.allclose(y1, y2, atol=1e-5), 'gate_attend and lscan should be the same' + + def y_hook(grad): + print(grad, 'y.grad') + y_hook.grad = grad + y1.register_hook(y_hook) + y1.sum().backward() + print(g_hook.grad, 'g.grad') + + d_q, d_k, d_v, d_g = attend_backward(q, k, v, g) + print(d_g, 'd_g') + + assert 
torch.allclose(q.grad, d_q, atol=1e-5), 'q.grad is wrong' + assert torch.allclose(k.grad, d_k, atol=1e-5), 'k.grad is wrong' + assert torch.allclose(v.grad, d_v, atol=1e-5), 'v.grad is wrong' + # print((g_hook.grad - d_g).pow(2).mean(), 'error') + # print((g_hook.grad - d_g).abs().max(), 'max abs error') + # assert torch.allclose(g_hook.grad, d_g, atol=1e-1), 'g.grad is wrong' diff --git a/tests/single.py b/tests/single.py new file mode 100644 index 0000000..70a1f46 --- /dev/null +++ b/tests/single.py @@ -0,0 +1,68 @@ +import torch + +torch.set_grad_enabled(False) + +def init(B, C, T, *, device, requires_grad=False): + torch.manual_seed(12312323) + gates = 0.999 + 0.001 * torch.rand(B, C, T, device=device, requires_grad=requires_grad) + gates = gates.half().float() + #tokens = torch.rand(B, C, T, device=device, requires_grad=requires_grad) + tokens = torch.ones(B, C, T, device=device, requires_grad=requires_grad) + + return gates, tokens + +device = 'cuda' + +#for SEQUENCE_LENGTH in [512,1024,2048,4096]: +for SEQUENCE_LENGTH in [256]: + direction = "forward" + provider = "kittenexp" + #B, H, D, T = 32, 64, 16, SEQUENCE_LENGTH + B, H, D, T = 1, 1, 32, SEQUENCE_LENGTH + #B, H, D, T = 1, 1, 64, SEQUENCE_LENGTH + print("running", T) + + # from triton: + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + + gates, tokens = init(B, H, T, device=device, requires_grad=direction=="train") + + k = tokens.unsqueeze(-1).expand(B, H, T, D).bfloat16().contiguous() + q = torch.ones_like(k).bfloat16().contiguous() + v = torch.ones_like(q).bfloat16().contiguous() + f = gates.float().contiguous() + o = torch.empty_like(v).bfloat16().contiguous() + + from accelerated_scan.kitten import attend + + if B <= 32: + qk = torch.einsum('bhsd,bhtd->bhst', q, k) + causal_mask = torch.tril(torch.ones(T, T, device=device)) + causal_mask = causal_mask[None, None].expand(B, H, T, T) + causal_mask = causal_mask.to(dtype=qk.dtype) + y = torch.einsum('bhst,bhte->bhse', qk * causal_mask, v) + + torch.cuda.synchronize() + for _ in range(1): + cache.zero_() + attend(q, k, v, f, o) + from torch.nn.functional import scaled_dot_product_attention + scaled_dot_product_attention(q, k, v, is_causal=True) + print('flash') + + if B <= 32: + try: + assert torch.allclose(y, o, atol=1e-3, rtol=1e-3) + except: + print(y[:,:,:,0], 'ref') + print(o[:,:,:,0], 'ker') + # print(y[:,:,0], 'ref, t=0') + # print(o[:,:,0], 'ker, t=0') + # for t in range(SEQUENCE_LENGTH): + # if (y[:,:,t,:]- o[:,:,t,:]).pow(2).mean().item() > 0: + # print(t, y[:,:,t,:]- o[:,:,t,:]) + raise + torch.cuda.synchronize() diff --git a/tests/tile_layout.pdf b/tests/tile_layout.pdf new file mode 100644 index 0000000..006ef3f Binary files /dev/null and b/tests/tile_layout.pdf differ diff --git a/tests/tile_layout.png b/tests/tile_layout.png new file mode 100644 index 0000000..34c9a65 Binary files /dev/null and b/tests/tile_layout.png differ diff --git a/tests/tile_layout.py b/tests/tile_layout.py new file mode 100644 index 0000000..23ab866 --- /dev/null +++ b/tests/tile_layout.py @@ -0,0 +1,163 @@ + +#%% + +import numpy as np +import matplotlib.pyplot as plt + +def tileprint(K, name='K'): + "format matches tileprint in tk code" + assert K.shape == (16, 16) + for laneid in range(32): + row_top = laneid // 4 + row_bottom = row_top + 8 + col_left = laneid % 4 * 2 + col_right = col_left + 8 + + def fmt(r,c,tag): 
+ odd = "y" in tag + if odd: # do not print r for odd rows because cuda printf silently runs out of function arguments + return f"{name}[,{c:02}] {tag}={K[r,c]: .3f}" + else: + return f"{name}[{r:02},{c:02}] {tag}={K[r,c]: .3f}" + + print(f"lane={laneid:02}", " ".join([ + " ".join([fmt(row_top, col_left, "0x"), fmt(row_top, col_left+1, "0y")]), + " ".join([fmt(row_bottom, col_left, "1x"), fmt(row_bottom, col_left+1, "1y")]), + " ".join([fmt(row_top, col_right, "2x"), fmt(row_top, col_right+1, "2y")]), + " ".join([fmt(row_bottom, col_right, "3x"), fmt(row_bottom, col_right+1, "3y")]) + ])) + +""" +template +__device__ void tileprint(rt reg, char *name) { + auto laneid = kittens::laneid(); + static_assert(reg.height == 1 && reg.width == 1, "height and width must be 1"); + for(int i = 0; i < reg.height; i++) { + for(int j = 0; j < reg.width; j++) { + static_assert(reg.packed_per_thread == 4, "packed_per_thread must be 4"); + + int row_top = laneid / 4; + int row_bottom = row_top + 8; + int col_left = laneid % 4 * 2; // stride 4 + int col_right = col_left + 8; + + auto item_top_left = __bfloat1622float2(reg.tiles[i][j].data[0]); + auto item_bottom_left = __bfloat1622float2(reg.tiles[i][j].data[1]); + auto item_top_right = __bfloat1622float2(reg.tiles[i][j].data[2]); + auto item_bottom_right = __bfloat1622float2(reg.tiles[i][j].data[3]); + printf("lane=%02d " + "%s[%02d,%02d] 0x=% .3f " + "%s[,%02d] 0y=% .3f " + "%s[%02d,%02d] 1x=% .3f " + "%s[,%02d] 1y=% .3f " + "%s[%02d,%02d] 2x=% .3f " + "%s[,%02d] 2y=% .3f " + "%s[%02d,%02d] 3x=% .3f " + "%s[,%02d] 3y=% .3f\n", + laneid, + name, row_top, col_left, item_top_left.x, + name, col_left+1, item_top_left.y, + name, row_bottom, col_left, item_bottom_left.x, + name, col_left+1, item_bottom_left.y, + name, row_top, col_right, item_top_right.x, + name, col_right+1, item_top_right.y, + name, row_bottom, col_right, item_bottom_right.x, + name, col_right+1, item_bottom_right.y); + } + } +} +""" + +plt.rcParams["figure.autolayout"] = True +plt.rcParams["font.family"] = "serif" +plt.rcParams["toolbar"] = "None" +plt.rcParams["axes.linewidth"] = 0.5 +plt.rcParams['xtick.major.size'] = 0 +plt.rcParams['xtick.major.width'] = 0 +plt.rcParams['ytick.major.size'] = 0 +plt.rcParams['ytick.major.width'] = 0 +#plt.rcParams['text.usetex'] = True + +dotdata = np.arange(8)//2 +xy = ['x', 'y'] * 4 +xticklabels = [f'{d}{x}' for d, x in zip(dotdata, xy)] + +@np.vectorize +def coord(laneid, bottom, right): + row_top = laneid // 4 + row_bottom = row_top + 8 + colL = laneid % 4 * 2 # stride 4 + colR = colL + 8 + row = row_bottom if bottom else row_top + col = colR if right else colL + return row, col + +@np.vectorize +def tdcoord_to_laneid(row, col): + return (row % 8) * 4 + (col % 8) // 2 + +lanes = np.arange(32)[:, None, None] +top = np.arange(2)[None, None, :] # 0 for top or 1 for bottom +left = np.arange(2)[None, :, None] # 0 for left or 1 for right + +coords = coord(lanes, top, left) +rows, cols = coords +rows = rows.reshape(32, -1) +cols = cols.reshape(32, -1) + +rows_xy = np.repeat(rows, 2, axis=1) +cols_xy = np.repeat(cols, 2, axis=1) + +#fig, ax = plt.subplots(figsize=(8, 8)) +fig, (ax, bx) = plt.subplots(1, 2, figsize=(16, 8)) +fig.suptitle('rt layout') + +ax.matshow(rows_xy, aspect='auto', cmap='coolwarm', vmin=-2, vmax=18) +ax.set_yticks(np.arange(32)) +ax.set_title(f'Lane and Register to [Row, Column]') +ax.set_ylabel('thread lane') +ax.set_xlabel('thread registers, $0x$ stands for .data[0].x') +ax.set_xticks(np.arange(8)) +# remove spine 
+ax.spines['top'].set_visible(False) +ax.spines['right'].set_visible(False) +ax.spines['bottom'].set_visible(False) +ax.spines['left'].set_visible(False) +ax.set_xticklabels([f'${xt}$' for xt in xticklabels]) + +x, y = np.meshgrid(np.arange(16), np.arange(16)) +@np.vectorize +def cat(x, y): + return f'{x},{y}' +rowcol_to_loc = cat(x,y) +rowcol_to_r = np.zeros((16,16)) + +for (i, j), r in np.ndenumerate(rows_xy): + c = cols_xy[i,j] + j%2 + laneid = tdcoord_to_laneid(r, c) + rowcol_to_loc[r,c] = f'{laneid}:{xticklabels[j]}' + rowcol_to_r[r,c] = r + ax.text(j, i, f'[{str(r).rjust(2)},{str(c).rjust(2)}]', ha='center', va='center') + +# highlight the diagonal +for i in range(16): + rowcol_to_r[i,i] = 8.5 + +bx.matshow(rowcol_to_r, aspect='auto', cmap='coolwarm', vmin=-2, vmax=18) +bx.set_xticks(np.arange(16)) +bx.set_yticks(np.arange(16)) +# remove spine +bx.spines['top'].set_visible(False) +bx.spines['right'].set_visible(False) +bx.spines['bottom'].set_visible(False) +bx.spines['left'].set_visible(False) +bx.set_title('Tile Location to lane:register, $0x$ stands for .data[0].x') +bx.set_xlabel('tile columns') +bx.set_ylabel('tile rows') + +for (i, j), r in np.ndenumerate(rowcol_to_loc): + bx.text(j, i, r, ha='center', va='center') + +plt.savefig('tile_layout.pdf', dpi=300, bbox_inches='tight') +plt.savefig('tile_layout.png', dpi=300, bbox_inches='tight') +# %% diff --git a/tests/warpsim.py b/tests/warpsim.py new file mode 100644 index 0000000..e6f8306 --- /dev/null +++ b/tests/warpsim.py @@ -0,0 +1,54 @@ +# timesteps_per_tile = rows of q +# NUM_WORKERS = number of warps + +def load(*args): + pass + +# all to all attention +def sim_full(warpid, n=4, timesteps_per_tile=1, NUM_WORKERS=4): + qo_blocks = n//(timesteps_per_tile*NUM_WORKERS) + kv_blocks = n//(timesteps_per_tile*NUM_WORKERS) + for q_blk in range(qo_blocks): + q_seq = (q_blk * NUM_WORKERS + warpid) * timesteps_per_tile + #print(f"{warpid=} {q_blk=} {q_seq=}") + for kv_idx in range(kv_blocks): + kv_seq = (kv_idx * NUM_WORKERS + warpid) * timesteps_per_tile # # one warp loads the whole kv block to memory + #print(f"{warpid=} {q_seq=} {kv_seq=}") + for subtile in range(NUM_WORKERS): + k_index = kv_idx * NUM_WORKERS + subtile # every warp now accesses every block in share memory + print(f"{warpid=} {kv_seq=} q[{q_seq}]@k[{k_index}]") + # store here + +# causal attention +def sim(warpid, NUM_WORKERS, seqlen=4, timestep_tiles_per_thread=1, timesteps_per_tile=1): + time_stride = timestep_tiles_per_thread*timesteps_per_tile + qo_blocks = seqlen//(time_stride*NUM_WORKERS) + + for q_blk in range(qo_blocks): + q_seq = (q_blk * NUM_WORKERS + warpid) * time_stride + q_end = q_seq + time_stride + load(q_seq) + + #kv_blocks = n//(timesteps_per_tile*NUM_WORKERS) + #kvbs = [x for x in range(kv_blocks) if x <= q_blk] + #print(f"{warpid=} {q_blk=} {q_seq=} {kvbs=}") + + for kv_blk in range(q_blk,-1,-1): + kv_warp_index = kv_blk * NUM_WORKERS + warpid + kv_seq = kv_warp_index * time_stride # one warp loads the whole kv block to memory + if q_seq >= kv_seq: + load(kv_seq) + #print(f"{warpid=} {q_seq=} {kv_blk=} {kv_seq=}") + for subtile in range(NUM_WORKERS,-1,-1): + k_seq = (kv_blk * NUM_WORKERS + subtile) * time_stride # every warp now accesses every block in share memory + k_end = k_seq + time_stride + if q_seq >= k_seq: + load(k_seq) + needs_make_causal = "\\" if q_seq == k_seq else "" + print(f"{warpid=} {kv_seq=} q[{q_seq}:{q_end}]@k[{k_seq}:{k_end}] {needs_make_causal}") + # store here + + +NUM_WORKERS = 4 +for warpid in range(NUM_WORKERS): + 
sim(warpid, NUM_WORKERS)
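+
+# Each printed line is one q-tile/k-tile pairing handled by a warp; a trailing "\"
+# marks a diagonal tile that still needs make_causal. For comparison, the dense
+# all-to-all schedule can be printed the same way (usage sketch):
+#
+#   for warpid in range(NUM_WORKERS):
+#       sim_full(warpid, n=4, timesteps_per_tile=1, NUM_WORKERS=NUM_WORKERS)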