
Commit aab74c1

[Kernel] Remove all syncs from STA & VSA kernels (#517)
1 parent f89d869 commit aab74c1

12 files changed: +250 -139 lines changed


.github/workflows/pr-test.yml

Lines changed: 97 additions & 26 deletions
@@ -14,13 +14,9 @@ on:
       - ".github/workflows/pr-test.yml"
       - "pyproject.toml"
       - "docker/Dockerfile.python3.12"
+      - "csrc/**"
   workflow_dispatch:
     inputs:
-      custom_image:
-        description: "Custom image from this repository (default: fastvideo-dev:py3.12-latest)"
-        required: false
-        default: "fastvideo-dev:py3.12-latest"
-        type: string
       run_encoder_test:
         description: "Run encoder-test"
         required: false
@@ -56,6 +52,16 @@ on:
         required: false
         default: false
         type: boolean
+      run_precision_test_STA:
+        description: "Run precision-test-STA"
+        required: false
+        default: false
+        type: boolean
+      run_precision_test_VSA:
+        description: "Run precision-test-VSA"
+        required: false
+        default: false
+        type: boolean
       run_nightly_test:
         description: "Run nightly-test"
         required: false
@@ -65,6 +71,7 @@ on:
 env:
   PYTHONUNBUFFERED: "1"

+
 concurrency:
   group: pr-test-${{ github.ref }}
   cancel-in-progress: true
@@ -84,44 +91,69 @@ jobs:
       training-test: ${{ steps.filter.outputs.training-test }}
       training-test-VSA: ${{ steps.filter.outputs.training-test-VSA }}
       inference-test-STA: ${{ steps.filter.outputs.inference-test-STA }}
+      precision-test-STA: ${{ steps.filter.outputs.precision-test-STA }}
+      precision-test-VSA: ${{ steps.filter.outputs.precision-test-VSA }}
     steps:
       - uses: actions/checkout@v4
       - uses: dorny/paths-filter@v3
         id: filter
         with:
           filters: |
+            # Define reusable path patterns
+            common-paths: &common-paths
+              - 'pyproject.toml'
+              - 'docker/Dockerfile.python3.12'
+            sta-kernel-paths: &sta-kernel-paths
+              - 'csrc/attn/st_attn/**'
+              - 'csrc/attn/setup_sta.py'
+              - 'csrc/attn/config_sta.py'
+              - 'csrc/attn/st_attn.cpp'
+            vsa-kernel-paths: &vsa-kernel-paths
+              - 'csrc/attn/vsa/**'
+              - 'csrc/attn/tk/**'
+              - 'csrc/attn/setup_vsa.py'
+              - 'csrc/attn/config_vsa.py'
+              - 'csrc/attn/vsa.cpp'
+            vsa-paths: &vsa-paths
+              - 'fastvideo/v1/**'
+              - *common-paths
+              - *vsa-kernel-paths
+
+            # Actual tests
             encoder-test:
               - 'fastvideo/v1/models/encoders/**'
               - 'fastvideo/v1/models/loaders/**'
               - 'fastvideo/v1/tests/encoders/**'
-              - 'pyproject.toml'
-              - 'docker/Dockerfile.python3.12'
+              - *common-paths
             vae-test:
               - 'fastvideo/v1/models/vaes/**'
               - 'fastvideo/v1/models/loaders/**'
               - 'fastvideo/v1/tests/vaes/**'
-              - 'pyproject.toml'
-              - 'docker/Dockerfile.python3.12'
+              - *common-paths
             transformer-test:
               - 'fastvideo/v1/models/dits/**'
               - 'fastvideo/v1/models/loaders/**'
               - 'fastvideo/v1/tests/transformers/**'
               - 'fastvideo/v1/layers/**'
               - 'fastvideo/v1/attention/**'
-              - 'pyproject.toml'
-              - 'docker/Dockerfile.python3.12'
+              - *common-paths
             training-test:
               - 'fastvideo/v1/**'
-              - 'pyproject.toml'
-              - 'docker/Dockerfile.python3.12'
+              - *common-paths
            training-test-VSA:
               - 'fastvideo/v1/**'
-              - 'pyproject.toml'
-              - 'docker/Dockerfile.python3.12'
+              - *common-paths
+              - *vsa-kernel-paths
             inference-test-STA:
               - 'fastvideo/v1/**'
-              - 'pyproject.toml'
-              - 'docker/Dockerfile.python3.12'
+              - *common-paths
+              - *sta-kernel-paths
+            precision-test-STA:
+              - *common-paths
+              - *sta-kernel-paths
+            precision-test-VSA:
+              - *common-paths
+              - *vsa-kernel-paths

   encoder-test:
     needs: change-filter
@@ -134,7 +166,7 @@ jobs:
       gpu_type: "NVIDIA A40"
       gpu_count: 1
       volume_size: 100
-      image: "ghcr.io/${{ github.repository }}/${{ github.event.inputs.custom_image || 'fastvideo-dev:py3.12-latest' }}"
+      image: "ghcr.io/${{ github.repository }}/fastvideo-dev:py3.12-latest"
       test_command: "uv pip install -e .[test] && pytest ./fastvideo/v1/tests/encoders -s"
       timeout_minutes: 30
     secrets:
@@ -152,7 +184,7 @@ jobs:
       gpu_type: "NVIDIA A40"
       gpu_count: 1
       volume_size: 100
-      image: "ghcr.io/${{ github.repository }}/${{ github.event.inputs.custom_image || 'fastvideo-dev:py3.12-latest' }}"
+      image: "ghcr.io/${{ github.repository }}/fastvideo-dev:py3.12-latest"
       test_command: "uv pip install -e .[test] && pytest ./fastvideo/v1/tests/vaes -s"
       timeout_minutes: 30
     secrets:
@@ -170,7 +202,7 @@ jobs:
       gpu_type: "NVIDIA L40S"
       gpu_count: 1
       volume_size: 100
-      image: "ghcr.io/${{ github.repository }}/${{ github.event.inputs.custom_image || 'fastvideo-dev:py3.12-latest' }}"
+      image: "ghcr.io/${{ github.repository }}/fastvideo-dev:py3.12-latest"
       test_command: "uv pip install -e .[test] && pytest ./fastvideo/v1/tests/transformers -s"
       timeout_minutes: 30
     secrets:
@@ -216,7 +248,7 @@ jobs:
       gpu_count: 4
       volume_size: 100
       disk_size: 100
-      image: "ghcr.io/${{ github.repository }}/${{ github.event.inputs.custom_image || 'fastvideo-dev:py3.12-latest' }}"
+      image: "ghcr.io/${{ github.repository }}/fastvideo-dev:py3.12-latest"
       test_command: "wandb login $WANDB_API_KEY && uv pip install -e .[test] && pytest ./fastvideo/v1/tests/training/Vanilla -srP"
       timeout_minutes: 30
     secrets:
@@ -236,7 +268,7 @@ jobs:
       gpu_count: 1
       volume_size: 100
       disk_size: 100
-      image: "ghcr.io/${{ github.repository }}/${{ github.event.inputs.custom_image || 'fastvideo-dev:py3.12-latest' }}"
+      image: "ghcr.io/${{ github.repository }}/fastvideo-dev:py3.12-latest"
       test_command: "wandb login $WANDB_API_KEY && uv pip install -e .[test] && pytest ./fastvideo/v1/tests/training/VSA -srP"
       timeout_minutes: 30
     secrets:
@@ -256,13 +288,51 @@ jobs:
       gpu_count: 1
       volume_size: 100
       disk_size: 100
-      image: "ghcr.io/${{ github.repository }}/${{ github.event.inputs.custom_image || 'fastvideo-dev:py3.12-latest' }}"
+      image: "ghcr.io/${{ github.repository }}/fastvideo-dev:py3.12-latest"
       test_command: "uv pip install -e .[test] && pytest ./fastvideo/v1/tests/inference/STA -srP"
       timeout_minutes: 30
     secrets:
       RUNPOD_API_KEY: ${{ secrets.RUNPOD_API_KEY }}
       RUNPOD_PRIVATE_KEY: ${{ secrets.RUNPOD_PRIVATE_KEY }}

+  precision-test-STA:
+    needs: change-filter
+    if: >-
+      (github.event_name != 'workflow_dispatch' && github.event.pull_request.draft == false) ||
+      (github.event_name == 'workflow_dispatch' && github.event.inputs.run_precision_test_STA == 'true')
+    uses: ./.github/workflows/runpod-test.yml
+    with:
+      job_id: "precision-test-STA"
+      gpu_type: "NVIDIA H100 NVL"
+      gpu_count: 1
+      volume_size: 100
+      disk_size: 100
+      image: "ghcr.io/${{ github.repository }}/fastvideo-dev:py3.12-latest"
+      test_command: "uv pip install -e .[test] && python csrc/attn/tests/test_sta.py"
+      timeout_minutes: 30
+    secrets:
+      RUNPOD_API_KEY: ${{ secrets.RUNPOD_API_KEY }}
+      RUNPOD_PRIVATE_KEY: ${{ secrets.RUNPOD_PRIVATE_KEY }}
+
+  precision-test-VSA:
+    needs: change-filter
+    if: >-
+      (github.event_name != 'workflow_dispatch' && github.event.pull_request.draft == false) ||
+      (github.event_name == 'workflow_dispatch' && github.event.inputs.run_precision_test_VSA == 'true')
+    uses: ./.github/workflows/runpod-test.yml
+    with:
+      job_id: "precision-test-VSA"
+      gpu_type: "NVIDIA H100 NVL"
+      gpu_count: 1
+      volume_size: 100
+      disk_size: 100
+      image: "ghcr.io/${{ github.repository }}/fastvideo-dev:py3.12-latest"
+      test_command: "uv pip install -e .[test] && python csrc/attn/tests/test_block_sparse.py"
+      timeout_minutes: 30
+    secrets:
+      RUNPOD_API_KEY: ${{ secrets.RUNPOD_API_KEY }}
+      RUNPOD_PRIVATE_KEY: ${{ secrets.RUNPOD_PRIVATE_KEY }}
+
   nightly-test:
     if: >-
       (github.event_name == 'workflow_dispatch' && github.event.inputs.run_nightly_test == 'true')
@@ -273,7 +343,7 @@ jobs:
       gpu_count: 4
       volume_size: 100
       disk_size: 100
-      image: "ghcr.io/${{ github.repository }}/${{ github.event.inputs.custom_image || 'fastvideo-dev:py3.12-latest' }}"
+      image: "ghcr.io/${{ github.repository }}/fastvideo-dev:py3.12-latest"
       test_command: "wandb login $WANDB_API_KEY && uv pip install -e .[test] && pytest ./fastvideo/v1/tests/nightly/test_e2e_overfit_single_sample.py -vs"
       timeout_minutes: 30
     secrets:
@@ -282,7 +352,8 @@ jobs:
       WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}

   runpod-cleanup:
-    needs: [encoder-test, vae-test, transformer-test, ssim-test] # Add other jobs to this list as you create them
+    # Add other jobs to this list as you create them
+    needs: [encoder-test, vae-test, transformer-test, ssim-test, training-test, training-test-VSA, inference-test-STA, precision-test-STA, precision-test-VSA]
     if: ${{ always() && ((github.event_name != 'workflow_dispatch' && github.event.pull_request.draft == false) || github.event_name == 'workflow_dispatch') }}
     runs-on: ubuntu-latest
     steps:
@@ -299,7 +370,7 @@ jobs:

       - name: Cleanup all RunPod instances
         env:
-          JOB_IDS: '["encoder-test", "vae-test", "transformer-test", "ssim-test-py3.10", "ssim-test-py3.11", "ssim-test-py3.12"]'
+          JOB_IDS: '["encoder-test", "vae-test", "transformer-test", "ssim-test-py3.10", "ssim-test-py3.11", "ssim-test-py3.12", "training-test", "training-test-VSA", "inference-test-STA", "precision-test-STA", "precision-test-VSA"]'
           RUNPOD_API_KEY: ${{ secrets.RUNPOD_API_KEY }}
           GITHUB_RUN_ID: ${{ github.run_id }}
         run: python .github/scripts/runpod_cleanup.py
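The `filters:` refactor above leans on plain YAML anchors (`&common-paths`) and aliases (`*common-paths`) so that `pyproject.toml` and the Dockerfile are no longer repeated in every filter. As a minimal sketch (not part of the commit), parsing a trimmed-down filters block with PyYAML shows what an alias expands to; dorny/paths-filter is expected to accept the nested list this produces, as its anchor-based examples suggest.

```python
# Illustration only: how the anchors/aliases in the `filters: |` block expand.
# Requires PyYAML (`pip install pyyaml`).
import yaml

filters_yaml = """
common-paths: &common-paths
  - 'pyproject.toml'
  - 'docker/Dockerfile.python3.12'
training-test-VSA:
  - 'fastvideo/v1/**'
  - *common-paths
"""

parsed = yaml.safe_load(filters_yaml)
# The alias expands to a nested list inside the filter's path list.
print(parsed["training-test-VSA"])
# ['fastvideo/v1/**', ['pyproject.toml', 'docker/Dockerfile.python3.12']]
```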

csrc/attn/README.md

Lines changed: 8 additions & 2 deletions
@@ -4,7 +4,7 @@


 ## Installation
-We test our code on Pytorch 2.5.0 and CUDA>=12.4. Currently we only have implementation on H100.
+We test our code on Pytorch 2.5.0 and CUDA>=12.4. Currently we only support H100/H200, because ThunderKittens uses TMA but doesn't support Blackwell yet.
 First, install C++20 for ThunderKittens:
 ```bash
 sudo apt update
@@ -53,8 +53,14 @@ out = sliding_tile_attention(q, k, v, window_size, 0, False)

 ## Test
 ```bash
-python test/test_sta.py
+python tests/test_sta.py # test STA
+python tests/test_block_sparse.py # test VSA
 ```
+## Benchmark
+```bash
+python benchmarks/bench_sta.py
+```
+

 ## How Does STA Work?
 We give a demo for 2D STA with window size (6,6) operating on a (10, 10) image.
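As a rough companion to that README demo, here is a toy NumPy sketch of the tile-level mask a 2D STA window induces, assuming (2, 2) pixel tiles so the (6, 6) window spans 3x3 tiles; the border handling (shifting the window so it stays inside the image) is an assumption, not necessarily the kernel's exact policy.

```python
# Toy illustration only, not the STA kernel: tile-level mask for a (10, 10)
# image with a (6, 6) window, assuming (2, 2) pixel tiles (so a 5x5 tile grid
# and a 3x3-tile window).
import numpy as np

H = W = 10   # image size in pixels
T = 2        # assumed tile size -> 5x5 grid of tiles
WIN = 3      # window size in tiles (6x6 pixels)
grid = H // T

def window_start(c, grid, win):
    # Center the window on tile c, then shift it so it stays inside the grid
    # (assumed border policy).
    return min(max(c - win // 2, 0), grid - win)

mask = np.zeros((grid * grid, grid * grid), dtype=bool)
for qi in range(grid):
    for qj in range(grid):
        i0 = window_start(qi, grid, WIN)
        j0 = window_start(qj, grid, WIN)
        for ki in range(i0, i0 + WIN):
            for kj in range(j0, j0 + WIN):
                mask[qi * grid + qj, ki * grid + kj] = True

# Every query tile attends to exactly WIN*WIN = 9 key tiles, i.e. dense
# blocks rather than ragged diagonals, which is what makes STA GPU-friendly.
print(mask.sum(axis=1))  # -> all 9
```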

csrc/attn/bench/bench_sta.py renamed to csrc/attn/benchmarks/bench_sta.py

Lines changed: 32 additions & 36 deletions
@@ -5,6 +5,7 @@
 import numpy as np
 import torch
 from st_attn import sliding_tile_attention
+from triton.testing import do_bench


 def flops(batch, seqlen, nheads, headdim, causal, mode="fwd"):
@@ -13,55 +14,48 @@ def flops(batch, seqlen, nheads, headdim, causal, mode="fwd"):
     return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)


-def efficiency(flop, time):
-    flop = flop / 1e12
-    time = time / 1e6
-    return flop / time
+def compute_TFLOPS(flops, ms):
+    flops = flops / 1e12
+    ms = ms / 1e3
+    return flops / ms


 def benchmark_attention(configurations):
     results = {'fwd': defaultdict(list), 'bwd': defaultdict(list)}

-    for B, H, N, D, causal in configurations:
+    for B, H, N, D, causal, dit_seq_shape, window_size in configurations:
         print("=" * 60)
         print(f"Timing forward and backward pass for B={B}, H={H}, N={N}, D={D}, causal={causal}")

         q = torch.randn(B, H, N, D, dtype=torch.bfloat16, device='cuda', requires_grad=False).contiguous()
         k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device='cuda', requires_grad=False).contiguous()
         v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device='cuda', requires_grad=False).contiguous()

-        grad_output = torch.randn_like(q, requires_grad=False).contiguous()
+        # grad_output = torch.randn_like(q, requires_grad=False).contiguous()
+        # qg = torch.zeros_like(q, requires_grad=False, dtype=torch.float).contiguous()
+        # kg = torch.zeros_like(k, requires_grad=False, dtype=torch.float).contiguous()
+        # vg = torch.zeros_like(v, requires_grad=False, dtype=torch.float).contiguous()

-        qg = torch.zeros_like(q, requires_grad=False, dtype=torch.float).contiguous()
-        kg = torch.zeros_like(k, requires_grad=False, dtype=torch.float).contiguous()
-        vg = torch.zeros_like(v, requires_grad=False, dtype=torch.float).contiguous()

-        # Prepare for timing forward pass
-        start_events_fwd = [torch.cuda.Event(enable_timing=True) for _ in range(10)]
-        end_events_fwd = [torch.cuda.Event(enable_timing=True) for _ in range(10)]
-
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
-
-        # Warmup for forward pass
-        for _ in range(10):
-            o = sliding_tile_attention(q, k, v, [[3, 6, 10]] * 24, 0, False, '18x48x80')
+        # # Warmup for forward pass
+        # for _ in range(10):
+        #     o = sliding_tile_attention(q, k, v, [[3, 6, 10]] * 24, 0, False, dit_seq_shape)

-        # Time the forward pass
-        for i in range(10):
-            start_events_fwd[i].record()
-            o = sliding_tile_attention(q, k, v, [[3, 6, 10]] * 24, 0, False, '18x48x80')
-            end_events_fwd[i].record()
+        # # Time the forward pass
+        # for i in range(10):
+        #     start_events_fwd[i].record()
+        #     o = sliding_tile_attention(q, k, v, [[3, 6, 10]] * 24, 0, False, dit_seq_shape)
+        #     end_events_fwd[i].record()
+        ms = do_bench(lambda: sliding_tile_attention(q, k, v, [window_size] * 24, 0, False, dit_seq_shape))

-        torch.cuda.synchronize()
-        times_fwd = [s.elapsed_time(e) for s, e in zip(start_events_fwd, end_events_fwd)]
-        time_us_fwd = np.mean(times_fwd) * 1000
+        # times_fwd = [s.elapsed_time(e) for s, e in zip(start_events_fwd, end_events_fwd)]
+        # time_us_fwd = np.mean(times_fwd) * 1000

-        tflops_fwd = efficiency(flops(B, N, H, D, causal, 'fwd'), time_us_fwd)
+        tflops_fwd = compute_TFLOPS(flops(B, N, H, D, causal, 'fwd'), ms)
         results['fwd'][(D, causal)].append((N, tflops_fwd))

-        print(f"Average time for forward pass in us: {time_us_fwd:.2f}")
-        print(f"Average efficiency for forward pass in TFLOPS: {tflops_fwd}")
+        print(f"Average time for forward pass (ms): {ms:.2f}")
+        print(f"Average TFLOPS: {tflops_fwd}")
         print("-" * 60)

         # torch.cuda.empty_cache()
@@ -85,15 +79,14 @@ def benchmark_attention(configurations):
         # times_bwd = [s.elapsed_time(e) for s, e in zip(start_events_bwd, end_events_bwd)]
         # time_us_bwd = np.mean(times_bwd) * 1000

-        # tflops_bwd = efficiency(flops(B, N, H, D, causal, 'bwd'), time_us_bwd)
+        # tflops_bwd = compute_TFLOPS(flops(B, N, H, D, causal, 'bwd'), ms)
         # results['bwd'][(D, causal)].append((N, tflops_bwd))

-        # print(f"Average time for backward pass in us: {time_us_bwd:.2f}")
-        # print(f"Average efficiency for backward pass in TFLOPS: {tflops_bwd}")
-        print("=" * 60)
+        # print(f"Average time for backward pass(ms): {ms:.2f}")
+        # print(f"Average TFLOPS: {tflops_bwd}")
+        # print("=" * 60)

         torch.cuda.empty_cache()
-        torch.cuda.synchronize()

     return results

@@ -124,7 +117,10 @@ def plot_results(results):

 # Example list of configurations to test
 configurations = [
-    (2, 24, 69120, 128, False),
+    (2, 24, 69120, 128, False, '18x48x80', [3, 6, 10]),
+    (2, 24, 69120, 128, True, '18x48x80', [3, 6, 10]),
+    (2, 24, 82944, 128, False, '36x48x48', [3, 3, 6]), # Stepvideo
+    (2, 24, 82944, 128, True, '36x48x48', [3, 3, 6]),
     # (16, 16, 768*16, 128, False),
     # (16, 16, 768*2, 128, False),
     # (16, 16, 768*4, 128, False),
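In the updated benchmark, the hand-rolled CUDA events and `torch.cuda.synchronize()` calls give way to `triton.testing.do_bench`, which handles warmup, event timing, and synchronization itself and returns a time in milliseconds; `compute_TFLOPS` then converts that to throughput. A standalone sketch of the same pattern is below; the matmul workload and its shapes are placeholders, not the STA call.

```python
# Minimal sketch of the do_bench + TFLOPS pattern used by the benchmark.
# Requires a CUDA GPU and triton; the matmul here is an arbitrary stand-in.
import torch
from triton.testing import do_bench

M = N = K = 4096
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

# do_bench runs warmup iterations and times the callable with CUDA events,
# returning milliseconds, so no explicit torch.cuda.synchronize() is needed.
ms = do_bench(lambda: a @ b)

# Same conversion as compute_TFLOPS: (flops / 1e12) / (ms / 1e3).
matmul_flops = 2 * M * N * K
tflops = (matmul_flops / 1e12) / (ms / 1e3)
print(f"{ms:.3f} ms, {tflops:.1f} TFLOPS")
```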
File renamed without changes.
