Commit 53b0de0

Author: Kye
Commit message: testing suite
1 parent c72e361, commit 53b0de0

File tree: 2 files changed (+10, -9 lines)

LongNet/attention.py
Lines changed: 9 additions & 8 deletions

@@ -5,7 +5,7 @@
 
 from LongNet.utils import XPOS, RelativePositionBias
 
-from LongNet.attend import FlashMHA
+from LongNet.attend import FlashAttention
 
 device = "cuda:0"
 dtype=torch.float16
@@ -34,7 +34,6 @@ def __init__(self, d_model, num_heads, dilation_rate, segment_size, dropout=0.0,
         self.num_heads = num_heads # number of attention heads
         self.dilation_rate = dilation_rate # dilation rate
         self.segment_size = segment_size # segment size
-
         self.dropout = nn.Dropout(dropout)
         # If casual attention is used
         self.casual = casual
@@ -44,13 +43,12 @@ def __init__(self, d_model, num_heads, dilation_rate, segment_size, dropout=0.0,
         self.use_rel_pos_bias = use_rel_pos_bias
         self.distributed = Distributed
 
+
         # Initialize attention for each head with dilation
-        # Initialize the attention heads with or without DataParallel based on the value of 'distributed'
         if self.distributed:
-            self.attentions = nn.ModuleList([DataParallel(FlashMHA(embed_dim=d_model, num_heads=num_heads, device=device, dtype=dtype)) for _ in range(self.dilation_rate)])
+            self.attentions = nn.ModuleList([DataParallel(FlashAttention(causal=self.casual, dropout=dropout)) for _ in range(self.dilation_rate)])
         else:
-            self.attentions = nn.ModuleList([FlashMHA(embed_dim=d_model, num_heads=num_heads, device=device, dtype=dtype) for _ in range(self.dilation_rate)])
-
+            self.attentions = nn.ModuleList([FlashAttention(causal=self.casual, dropout=dropout) for _ in range(self.dilation_rate)])
 
         # If using positional encoding, initialize it
         if use_xpos:
@@ -104,8 +102,6 @@ def forward(self, x):
         #option2
         # elements_attns = [attention(element.to(dtype), element.to(dtype), element.to(dtype)) for element in x_]
         # attn_output = torch.cat(elements_attns, dim=1)
-
-
 
         # If using relative positional bias, add it
         if self.use_rel_pos_bias:
@@ -137,6 +133,11 @@ def forward(self, x):
 
 
 
+class MultiHeadDilatedAttention:
+    def __init__():
+        pass
+
+
 
 
 
LongNet/model.py
Lines changed: 1 addition & 1 deletion

@@ -91,7 +91,7 @@ def __init__(self):
             distributed = False # whether to distribute attention for DilatedAttention
         )
 
-    def forward(self, text_tokens, temperature: int = None, filter_thres: int = None, **kwargs):
+    def generate(self, text_tokens, temperature: int = None, filter_thres: int = None, **kwargs):
         sampled = self.model.generate(temperature=temperature, filter_thres=filter_thres)
         return sampled
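In model.py the sampling entry point is renamed from forward to generate. A hedged usage sketch follows: LongNetWrapper and _DummySampler are placeholder names (the commit does not show the surrounding class), and the token shapes and vocabulary size are arbitrary choices made so the example runs end to end.

import torch

class _DummySampler:
    # Stand-in for self.model; only implements the generate() call used in the diff.
    def generate(self, temperature=None, filter_thres=None):
        return torch.randint(0, 32000, (1, 16))  # fake sampled token ids

class LongNetWrapper:  # placeholder name; the real class name is not shown in this diff
    def __init__(self):
        self.model = _DummySampler()

    # Mirrors the committed rename: sampling is exposed as generate(), not forward().
    def generate(self, text_tokens, temperature: int = None, filter_thres: int = None, **kwargs):
        sampled = self.model.generate(temperature=temperature, filter_thres=filter_thres)
        return sampled

tokens = torch.randint(0, 32000, (1, 64))  # assumed prompt shape and vocab size
out = LongNetWrapper().generate(tokens, temperature=0.8, filter_thres=0.9)
print(out.shape)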

0 commit comments