Description
When calling the native_sparse_attend() function on a GTX 3090 GPU, it runs normally with dim_head=64, but with dim_head=128 the following error occurs:
Traceback (most recent call last):
File "/root/NPNP/native-sparse-attention-pytorch/test_triton_nsa.py", line 175, in <module>
nsa_out = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask, sel_scale = nsel_scale, return_lse = False, block_dk_dv_use_dot = block_dk_dv_use_dot, return_sliding_window_out = fused_sliding_window)
File "/root/NPNP/native-sparse-attention-pytorch/native_sparse_attention_pytorch/triton_native_sparse_attention.py", line 789, in native_sparse_attend
out, sliding_out, lse, sliding_lse = _native_sparse_attend(
File "/home/vipuser/miniconda3/envs/myenv/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/root/NPNP/native-sparse-attention-pytorch/native_sparse_attention_pytorch/triton_native_sparse_attention.py", line 725, in forward
out, slide_out, lse, slide_lse = native_sparse_attn_forward(
File "/root/NPNP/native-sparse-attention-pytorch/native_sparse_attention_pytorch/triton_native_sparse_attention.py", line 649, in native_sparse_attn_forward
forward_kernel[grid](
File "/home/vipuser/miniconda3/envs/myenv/lib/python3.10/site-packages/triton/runtime/jit.py", line 345, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/home/vipuser/miniconda3/envs/myenv/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 338, in run
return self.fn.run(*args, **kwargs)
File "/home/vipuser/miniconda3/envs/myenv/lib/python3.10/site-packages/triton/runtime/jit.py", line 691, in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
File "/home/vipuser/miniconda3/envs/myenv/lib/python3.10/site-packages/triton/compiler/compiler.py", line 381, in __getattribute__
self._init_handles()
File "/home/vipuser/miniconda3/envs/myenv/lib/python3.10/site-packages/triton/compiler/compiler.py", line 374, in _init_handles
raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 139264, Hardware limit: 101376. Reducing block sizes or num_stages may help.
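For context, here is a rough sizing model I put together for why the shared-memory requirement roughly doubles with dim_head (this is a hypothetical estimate, not Triton's exact accounting; the function name and buffer layout are my own assumptions):

```python
# Hypothetical shared-memory sizing model for a pipelined attention kernel.
# Assumption: the Q tile stays resident while K and V tiles are
# multi-buffered across `num_stages` software-pipeline stages, all in fp16.
def estimated_smem_bytes(block_m, block_n, dim_head, num_stages, dtype_bytes=2):
    q_tile = block_m * dim_head * dtype_bytes               # resident Q tile
    kv_tiles = num_stages * 2 * block_n * dim_head * dtype_bytes  # staged K and V tiles
    return q_tile + kv_tiles

HW_LIMIT = 101376  # shared-memory limit reported in the error message

# Doubling dim_head roughly doubles the footprint, so a config that fits
# at dim_head=64 can overflow at dim_head=128:
print(estimated_smem_bytes(64, 64, 64, 3))    # 57344  -> fits
print(estimated_smem_bytes(64, 64, 128, 3))   # 114688 -> exceeds the limit
# Dropping num_stages from 3 to 2 brings dim_head=128 back under the limit:
print(estimated_smem_bytes(64, 64, 128, 2))   # 81920  -> fits
```

Under this model, the fix the error message suggests (smaller block sizes or fewer pipeline stages) directly shrinks the staged K/V buffers, which dominate the footprint.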
Could you provide some suggestions? Thank you very much! 🩷