
Commit 0941bb8

Author: Kye (committed)
parallel wrapper
1 parent b6bcb8b commit 0941bb8

File tree

5 files changed: +109 −2 lines changed


LongNet/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 
-from LongNet.attention import DilatedAttention
+from LongNet.attention import ParallelWrapper, DilatedAttention
 # from LongNet.model import LongNetTokenizer, LongNet, DecoderConfig, Decoder, DilatedLongNet
 
 # from LongNet.iterations import DynamicDilatedAttention, DilatedAttentionOld, DilatedAttentionOP
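
With this re-export in place, both symbols also resolve at the package root rather than only from the submodule; a minimal sketch (not part of the commit):

# both names are re-exported by LongNet/__init__.py
from LongNet import ParallelWrapper, DilatedAttention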

LongNet/attention.py

Lines changed: 61 additions & 0 deletions
@@ -14,6 +14,12 @@
 dtype=torch.float16
 
 
+
+
+
+
+
+
 def SparsifyIndices(
     x: torch.Tensor, ws: List[int], rs: List[int], head_idx: int
 ) -> Tuple[int, torch.Tensor, Optional[torch.Tensor]]:
@@ -104,6 +110,45 @@ def MixOutputs(
 
 
 
+
+class ParallelWrapper:
+    """
+    A simple wrapper to enable easy usage of data parallelism.
+
+    Arguments:
+        model: The neural network model to be parallelized.
+        device (optional): The device to which the model should be moved. Default: "cuda".
+        use_data_parallel (optional): A boolean flag indicating whether to use data parallelism. Default: True.
+    """
+    def __init__(
+        self,
+        model,
+        device="cuda",
+        use_data_parallel=True
+    ):
+        self.model = model.to(device)
+        self.use_data_parallel = use_data_parallel
+        self.device = device
+
+        if self.use_data_parallel and torch.cuda.device_count() > 1:
+            print(f"Using {torch.cuda.device_count()} GPUs")
+            self.model = nn.DataParallel(self.model)
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def to(self, device):
+        self.device = device
+        self.model = self.model.to(device)
+        return self
+
+    def __getattr__(self, name):
+        # redirect attribute access to the internal model to allow direct access to its methods and props
+        return getattr(self.model, name)
+
+
+
+
 #add alibi, qk layer norm, one write head, multiway,
 class DilatedAttentionNew(nn.Module):
     """
@@ -319,6 +364,22 @@ def forward(self, x):
 
 
 
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
 
 
 class MultiHeadDilatedAttention:
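
For reference, a minimal usage sketch of the new ParallelWrapper; the nn.Linear placeholder module and the tensor sizes are illustrative only, not part of this commit:

import torch
import torch.nn as nn
from LongNet.attention import ParallelWrapper

# any nn.Module can be wrapped; nn.Linear stands in for a real LongNet module here
net = nn.Linear(512, 512)

# moves the module to the chosen device and, when more than one GPU is visible,
# wraps it in nn.DataParallel
device = "cuda" if torch.cuda.is_available() else "cpu"
wrapped = ParallelWrapper(net, device=device)

x = torch.randn(8, 512, device=device)
out = wrapped.forward(x)  # the wrapper is a plain class, not callable, so forward() is invoked explicitly
print(out.shape)          # torch.Size([8, 512])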

mh_example.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+#

parallel_example.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+import timeit
+import torch
+from LongNet.attention import ParallelWrapper, DilatedAttention
+
+# model config
+d_model = 512
+num_heads = 8
+dilation_rate = 2
+segment_size = 64
+
+
+device = "cuda:0"
+dtype = torch.float16
+
+# inputs
+batch_size = 32
+seq_len = 8192
+
+
+# create model
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+model = DilatedAttention(
+    d_model,
+    num_heads,
+    dilation_rate,
+    segment_size
+)
+parallel_model = ParallelWrapper(model, device=device)
+
+x = torch.randn((batch_size, seq_len, d_model), device=device, dtype=dtype)
+
+# test forward pass through the wrapper (ParallelWrapper is not callable, so use .forward)
+with torch.no_grad():
+    output = parallel_model.forward(x)
+    print(f"Output shape: {output.shape}")  # expected: (batch_size, seq_len, d_model)
+
+# benchmark model
+num_runs = 1000
+start_time = timeit.default_timer()
+for _ in range(num_runs):
+    parallel_model.forward(x)
+
+
+elapsed_time = timeit.default_timer() - start_time
+print(f"Average forward pass time: {elapsed_time / num_runs:.6f} seconds")

setup.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 setup(
   name = 'LongNet',
   packages = find_packages(exclude=[]),
-  version = '0.4.3',
+  version = '0.4.8',
   license='MIT',
   description = 'LongNet - Pytorch',
   author = 'Kye Gomez',
