Skip to content

Commit 298cfec

Browse files
committed
quick tests for flex attn masks with the ascii __repr__
1 parent 6f73ceb commit 298cfec

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

test_flex_masks.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import torch
2+
from native_sparse_attention_pytorch.native_sparse_attention import (
3+
create_compress_mask,
4+
create_fine_mask,
5+
create_sliding_mask,
6+
)
7+
8+
# compress
9+
10+
print('compress mask:', create_compress_mask(1024, 256, 4))
11+
12+
# fine
13+
14+
selected_blocks = torch.randint(0, 5, (1, 1, 1024, 2)) # select mostly first few blocks
15+
16+
fine_block_mask = create_fine_mask(1024, 64)(selected_blocks.cuda())
17+
18+
print('fine:', fine_block_mask)
19+
20+
# sliding
21+
22+
print('sliding:', create_sliding_mask(1024, 32))

0 commit comments

Comments
 (0)