
Commit f1b9c2d (parent e582786)

init to sliding window strategy only

File tree: 2 files changed (+7, -2)


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 6 additions & 1 deletion

@@ -4,8 +4,8 @@
 from math import ceil
 
 import torch
-from torch import nn, arange, stack, cat, Tensor
 import torch.nn.functional as F
+from torch import nn, arange, stack, cat, tensor, Tensor
 from torch.nn import Module, ModuleList
 
 from local_attention import LocalAttention

@@ -226,6 +226,11 @@ def __init__(
        if not exists(strategy_combine_mlp):
            strategy_combine_mlp = nn.Linear(dim, 3 * heads)
 
+        # init to sliding windows first, as network tends to pick up on local patterns first before distant ones
+
+        nn.init.zeros_(strategy_combine_mlp.weight)
+        strategy_combine_mlp.bias.data.copy_(tensor([-2., -2., 2.] * heads))
+
        self.to_strategy_combine = nn.Sequential(
            strategy_combine_mlp,
            nn.Sigmoid(),
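
For readers skimming the diff, here is a minimal standalone sketch of what the new initialization does (the dim and heads values below are arbitrary, and the per-head gate ordering, with the sliding window gate last, is inferred from the commit message rather than shown in this diff). With the weights zeroed, the gate output at init is just sigmoid(bias) for every token, so each head starts out weighting its three attention branches at roughly 0.12 / 0.12 / 0.88 in favor of the sliding window branch:

# standalone sketch of the new init; dim and heads are hypothetical, and the
# assumption that the third gate per head is the sliding window branch comes
# from the commit message, not from this diff

import torch
from torch import nn, tensor

dim = 512
heads = 8

strategy_combine_mlp = nn.Linear(dim, 3 * heads)

# zero the weights so the gate output depends only on the bias at init
nn.init.zeros_(strategy_combine_mlp.weight)

# bias pattern [-2, -2, 2] per head -> after sigmoid, ~[0.12, 0.12, 0.88]
strategy_combine_mlp.bias.data.copy_(tensor([-2., -2., 2.] * heads))

to_strategy_combine = nn.Sequential(strategy_combine_mlp, nn.Sigmoid())

x = torch.randn(2, 1024, dim)            # (batch, seq, dim)

with torch.no_grad():
    gates = to_strategy_combine(x)       # (batch, seq, 3 * heads)

print(gates[0, 0, :3])                   # ~ tensor([0.1192, 0.1192, 0.8808])

Because the weights start at zero, the gate is uniform across tokens at init but remains fully learnable, so training can shift weight onto the other branches as the network picks up longer-range patterns.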

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.26"
+version = "0.0.27"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
