Skip to content

Commit d223cb2

Browse files
committed
pytest
1 parent 5a61460 commit d223cb2

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

.github/workflows/test.yml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
name: Pytest
2+
on: [push, pull_request]
3+
4+
jobs:
5+
build:
6+
7+
runs-on: ubuntu-latest
8+
9+
steps:
10+
- uses: actions/checkout@v4
11+
- name: Set up Python 3.10
12+
uses: actions/setup-python@v5
13+
with:
14+
python-version: "3.10"
15+
- name: Install dependencies
16+
run: |
17+
python -m pip install --upgrade pip
18+
python -m pip install -e .[test]
19+
- name: Test with pytest
20+
run: |
21+
python -m pytest tests/

tests/test_sparse_attn.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import pytest
2+
import torch
3+
4+
def test_sparse_attn():
5+
from native_sparse_attention_pytorch import SparseAttention
6+
7+
attn = SparseAttention(
8+
dim = 512,
9+
dim_head = 64,
10+
heads = 8,
11+
sliding_window_size = 2,
12+
compress_block_size = 4,
13+
selection_block_size = 4,
14+
num_selected_blocks = 2
15+
)
16+
17+
tokens = torch.randn(2, 31, 512)
18+
19+
attended = attn(tokens)
20+
21+
assert tokens.shape == attended.shape

0 commit comments

Comments
 (0)