 
 from LongNet.utils import XPOS, RelativePositionBias
 
-from LongNet.attend import FlashMHA
+from LongNet.attend import FlashAttention
 
 device = "cuda:0"
 dtype = torch.float16
@@ -34,7 +34,6 @@ def __init__(self, d_model, num_heads, dilation_rate, segment_size, dropout=0.0,
         self.num_heads = num_heads  # number of attention heads
         self.dilation_rate = dilation_rate  # dilation rate
         self.segment_size = segment_size  # segment size
-
         self.dropout = nn.Dropout(dropout)
         # If casual attention is used
         self.casual = casual
@@ -44,13 +43,12 @@ def __init__(self, d_model, num_heads, dilation_rate, segment_size, dropout=0.0,
         self.use_rel_pos_bias = use_rel_pos_bias
         self.distributed = Distributed
 
+
         # Initialize attention for each head with dilation
-        # Initialize the attention heads with or without DataParallel based on the value of 'distributed'
         if self.distributed:
-            self.attentions = nn.ModuleList([DataParallel(FlashMHA(embed_dim=d_model, num_heads=num_heads, device=device, dtype=dtype)) for _ in range(self.dilation_rate)])
+            self.attentions = nn.ModuleList([DataParallel(FlashAttention(causal=self.casual, dropout=dropout)) for _ in range(self.dilation_rate)])
         else:
-            self.attentions = nn.ModuleList([FlashMHA(embed_dim=d_model, num_heads=num_heads, device=device, dtype=dtype) for _ in range(self.dilation_rate)])
-
+            self.attentions = nn.ModuleList([FlashAttention(causal=self.casual, dropout=dropout) for _ in range(self.dilation_rate)])
 
         # If using positional encoding, initialize it
         if use_xpos:
@@ -104,8 +102,6 @@ def forward(self, x):
         #option2
         # elements_attns = [attention(element.to(dtype), element.to(dtype), element.to(dtype)) for element in x_]
         # attn_output = torch.cat(elements_attns, dim=1)
-
-
 
         # If using relative positional bias, add it
         if self.use_rel_pos_bias:
@@ -137,6 +133,11 @@ def forward(self, x):
 
 
 
+class MultiHeadDilatedAttention:
+    def __init__(self):
+        pass
+
+
 
 
 
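
A note on the substitution above: FlashMHA was a full multi-head attention module that owned its QKV projections (hence the embed_dim, num_heads, device, and dtype arguments), whereas FlashAttention as constructed here is configured only with causality and dropout, so the surrounding module becomes responsible for projecting the input and splitting heads before calling it. Below is a minimal sketch of driving one of the new attention modules; it assumes FlashAttention's forward accepts (q, k, v) tensors of shape (batch, heads, seq_len, head_dim), and the to_qkv projection and all shapes are illustrative stand-ins rather than the repository's actual forward pass.

import torch
import torch.nn as nn
from LongNet.attend import FlashAttention

device = "cuda:0"
dtype = torch.float16

batch, seq_len, d_model, num_heads = 2, 128, 512, 8
head_dim = d_model // num_heads

x = torch.randn(batch, seq_len, d_model, device=device, dtype=dtype)

# Hypothetical projection: unlike FlashMHA, FlashAttention does not own one,
# so q, k, v must be produced by the caller.
to_qkv = nn.Linear(d_model, 3 * d_model, bias=False).to(device=device, dtype=dtype)
q, k, v = to_qkv(x).chunk(3, dim=-1)

def split_heads(t):
    # (batch, seq_len, d_model) -> (batch, heads, seq_len, head_dim)
    return t.reshape(batch, seq_len, num_heads, head_dim).transpose(1, 2)

attn = FlashAttention(causal=True, dropout=0.0)
out = attn(split_heads(q), split_heads(k), split_heads(v))  # assumed (q, k, v) signature
out = out.transpose(1, 2).reshape(batch, seq_len, d_model)  # merge heads back to (batch, seq_len, d_model)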