Open
71 commits
adedad8
trying flash attention
proger Jun 16, 2024
5a4941c
kitten: remove frontend scaffolding
proger Jun 16, 2024
2d71795
expand to 255 registers per thread with some spilling:
proger Jun 21, 2024
e0d967c
load_inline: do not set -arch
proger Jun 21, 2024
06f1715
pyproject: depend on ninja and matplotlib out of the box
proger Jun 21, 2024
f0e5418
reference implementation of deltanet
proger Jul 5, 2024
805a8e9
decay_values_backward_kernel with only forward
proger Jul 6, 2024
7abcb48
minor cleanup
proger Jul 6, 2024
3693fa0
tile_layout is a notebook that shows how tiles are mapped to warp reg…
proger Jul 7, 2024
4e38a84
decay_values_backward: compute d_{w_t} / d_{b_s}
proger Jul 7, 2024
b699acb
deltacu: sketch du_t/db_s
proger Jul 7, 2024
688342c
shrink WK and UK expressions a bit
proger Jul 10, 2024
498706b
i hear backpropagation is a nice algorithm
proger Jul 10, 2024
26b7320
backpropagation through time is great
proger Jul 10, 2024
7179aea
deltanet: massage the code
proger Jul 10, 2024
33115c4
pull d_v and d_beta out of the loop
proger Jul 10, 2024
ae10e4a
deltacu cleanup
proger Jul 10, 2024
d68916e
decay_values_backward 16x16 kernel works
proger Jul 11, 2024
f3ca475
delta.cu: bump dimension to 64
proger Jul 11, 2024
be78214
delta.cu: type and dimension dispatch
proger Jul 11, 2024
80d2626
decay_values_forward
proger Jul 11, 2024
1df2cb6
fuse linear attention into decay_values
proger Jul 12, 2024
8781794
almost fused
proger Jul 12, 2024
22894f4
inline attention into decay_values
proger Jul 15, 2024
e950632
when stitching forward don't need to recompute y
proger Jul 15, 2024
d7ad825
stitch backward: when uncomputing state you don't need to store all a…
proger Jul 15, 2024
1a48f7f
stitch_backward: shave off a bit of computation on the boundary
proger Jul 15, 2024
bd9f2c4
whitespace
proger Jul 15, 2024
410b3f1
massage to TK assembly
proger Jul 15, 2024
25a36f7
chunk size vis
proger Jul 15, 2024
7e510a2
seqlen -> num_chunks
proger Jul 15, 2024
6132dc3
delta_forward
proger Jul 16, 2024
235de30
stitch forward and backward
proger Jul 16, 2024
cadd04e
end to end backward for one chunk
proger Jul 16, 2024
b0b649c
two chunks end to end
proger Jul 16, 2024
39c601d
simplify
proger Jul 17, 2024
ed58a6f
test T and D ranges with pytest
proger Jul 17, 2024
685bb54
start measuring delta forward speed
proger Jul 17, 2024
a3f211a
forward: use shared memory for state passing
proger Jul 19, 2024
6cef837
garbage goobers pass batons
proger Jul 19, 2024
2f01fb7
decay_values_forward: try not to saturate the tensor core
proger Jul 20, 2024
2055bfd
zeroexcept
proger Jul 20, 2024
3622c4a
value dimension chunking for forward
proger Jul 20, 2024
5ec885f
move around
proger Jul 20, 2024
e35e284
add fla to benchmark
proger Jul 20, 2024
8d7af9e
loop_impl can avoid computing output
proger Jul 27, 2024
160e701
make kitten bigger
proger Jul 27, 2024
770e8d7
update bench
proger Jul 27, 2024
9a74775
focus on backward
proger Jul 27, 2024
1f4fce0
ref: implement decay_values without tensor cores
proger Jul 27, 2024
6f78a85
add a test for gated_rnn
proger Jul 27, 2024
af39ca0
prepare to speed up backward
proger Jul 27, 2024
676aa85
Refactor chunk_backward function to use shared memory for state and d…
proger Jul 27, 2024
d9f5fa7
backward: fuse forward and backward into one loop, works with only sh…
proger Jul 27, 2024
951dad8
move decay_values_backward out
proger Jul 27, 2024
66862e5
d_v is ok but the rest is not
proger Aug 2, 2024
abf0be6
backward go brrr
proger Aug 2, 2024
ca9307b
api
proger Aug 2, 2024
4b4d2b0
no prints
proger Aug 2, 2024
e1226b7
start benchmarking backward
proger Aug 2, 2024
c688170
bench backward
proger Aug 3, 2024
1c94d10
backward uses less global memory
proger Aug 3, 2024
9809086
properly name benchmarks
proger Aug 3, 2024
85b8f2a
backward uses less global memory and initializes well
proger Aug 3, 2024
a8827ea
prepare values for dimension groups
proger Aug 3, 2024
5b9be61
store d_q through registers
proger Aug 3, 2024
602bcca
try cudaLaunchCooperativeKernel for backward
proger Aug 4, 2024
d37d37c
bench tweaks
proger Aug 4, 2024
f5aca36
more forgiving atol for backward
proger Aug 4, 2024
3152924
implement with locking
proger Aug 5, 2024
249b600
no need for reloading when no value groups
proger Aug 6, 2024
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "accelerated_scan/tk"]
path = accelerated_scan/tk
url = https://github.com/HazyResearch/ThunderKittens
9 changes: 9 additions & 0 deletions README.md
@@ -65,3 +65,12 @@ forward speed of (8,1536,seqlen), inference mode:
When gates and tokens are sampled uniformly from 0..1 the lack of bfloat16 precision dominates the error (compared to the reference implementation):

![max-abs-error.png](max-abs-error.png)
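
The precision claim above can be illustrated without a GPU. The following is a minimal sketch, not the repository's benchmark: it assumes the scan is the first-order recurrence `h[t] = gates[t] * h[t-1] + tokens[t]`, and the helpers `to_bfloat16` and `scan` are hypothetical names introduced here. bfloat16 is emulated by rounding a float32 bit pattern to its upper 16 bits, which is valid only for finite inputs.

```python
import random
import struct


def to_bfloat16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a rounding bias so that truncating the low 16 bits rounds to nearest even.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def scan(gates, tokens, quantize=lambda v: v):
    """Assumed recurrence: h[t] = gates[t] * h[t-1] + tokens[t]."""
    h = 0.0
    out = []
    for g, x in zip(gates, tokens):
        # Quantize the inputs and the stored state, as a low-precision kernel would.
        h = quantize(quantize(g) * h + quantize(x))
        out.append(h)
    return out


random.seed(0)
gates = [random.random() for _ in range(1024)]   # uniform in 0..1
tokens = [random.random() for _ in range(1024)]  # uniform in 0..1

reference = scan(gates, tokens)                  # double precision
low_precision = scan(gates, tokens, to_bfloat16)  # emulated bfloat16

max_abs_error = max(abs(a - b) for a, b in zip(reference, low_precision))
print(f"max abs error: {max_abs_error:.6f}")
```

With only 8 bits of significand, every stored state carries a relative rounding error of up to about 2^-9, so the gap to the reference is set by the number format rather than by the algorithm.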


## Attention

```
(cd accelerated_scan; python3 kitten_setup.py build)
ncu -k causal_attend_kernel python3 ./tests/single.py
python3 ./tests/bench.py --direction forward
```