Commit 5b9df45

SWA reference script
1 parent df6e87a commit 5b9df45

1 file changed: 290 additions, 0 deletions
@@ -0,0 +1,290 @@
import time
import typing

import tiktoken
import torch
import torch.nn as nn
from torch import cuda

import attention_helpers


class MultiHeadAttentionWithSWA(nn.Module):
    def __init__(
        self,
        d_in: int,
        d_out: int,
        dropout: float,
        num_heads: int,
        sliding_window_size: int,
        dtype: typing.Optional[torch.dtype] = None,
        qkv_bias: bool = False,
    ) -> None:
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert sliding_window_size > 0, "sliding_window_size must be positive"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.sliding_window_size = int(sliding_window_size)

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias, dtype=dtype)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias, dtype=dtype)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
        self.dropout = nn.Dropout(dropout)

        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)
        self.ptr_current_pos = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, num_tokens, _ = x.shape

        keys_new = self.W_key(x)
        values_new = self.W_value(x)
        queries = self.W_query(x)

        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(
            1, 2
        )
        keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim).transpose(
            1, 2
        )
        values_new = values_new.view(
            b, num_tokens, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # 1. Update the Cache
        if self.cache_k is None:
            self.cache_k, self.cache_v = keys_new, values_new
        else:
            self.cache_k = torch.cat([self.cache_k, keys_new], dim=2)
            self.cache_v = torch.cat([self.cache_v, values_new], dim=2)

        # 2. Apply Sliding Window (Truncate)
        #
        # We check the current size (after adding new tokens).
        #
        # If it exceeds the window, we physically delete the oldest tokens.
        if self.cache_k.size(2) > self.sliding_window_size:
            self.cache_k = self.cache_k[:, :, -self.sliding_window_size :, :]
            self.cache_v = self.cache_v[:, :, -self.sliding_window_size :, :]
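        # For example, with sliding_window_size=4: a 3-token prefill leaves 3
        # cached positions, the next token brings the cache to 4, and the token
        # after that would make it 5, so the oldest entry is dropped and the
        # cache stays at 4 positions from then on.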
        keys, values = self.cache_k, self.cache_v

        attn_scores = queries @ keys.transpose(2, 3)

        num_tokens_q = queries.shape[-2]
        num_tokens_k = keys.shape[-2]
        device = queries.device

        # 3. Calculate Absolute Positions for Masking
        #
        # We need the absolute index of the tokens in the full text sequence (0, 1, 2,
        # ... 500) to ensure the mask works correctly.
        #
        # The 'right edge' of our cache is the total number of tokens processed so far.
        current_absolute_end = self.ptr_current_pos + num_tokens

        # The 'left edge' is simply the end minus the current cache size.
        # This gives us the Absolute Position of keys[:, :, 0, :].
        k_start_absolute = current_absolute_end - num_tokens_k

        # The queries start wherever the previous batch ended.
        q_start_absolute = self.ptr_current_pos
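        # For example, with sliding_window_size=256: after 300 tokens have been
        # processed and one new token arrives, ptr_current_pos=300 and
        # num_tokens=1, so current_absolute_end=301; the cache holds 256 keys,
        # giving k_start_absolute = 301 - 256 = 45 and q_start_absolute = 300.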
        q_positions = torch.arange(
            q_start_absolute,
            q_start_absolute + num_tokens_q,
            device=device,
            dtype=torch.long,
        )

        k_positions = torch.arange(
            k_start_absolute,
            k_start_absolute + num_tokens_k,
            device=device,
            dtype=torch.long,
        )

        # 4. Create and Apply Mask
        diff = q_positions.unsqueeze(-1) - k_positions.unsqueeze(0)

        # (diff < 0) -> Causal Mask (prevent looking at future)
        #
        # (diff >= window) -> Window Mask (prevent looking too far back)
        mask = (diff < 0) | (diff >= self.sliding_window_size)
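        # For example, with sliding_window_size=3, a query at absolute position 5
        # against keys at positions [2, 3, 4, 5, 6] gives diff = [3, 2, 1, 0, -1]:
        # position 2 is masked (diff >= 3, too far back), position 6 is masked
        # (diff < 0, in the future), and positions 3-5 stay visible.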
        self.ptr_current_pos += num_tokens_q

        attn_scores = attn_scores.masked_fill(mask, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec

    def reset_cache(self) -> None:
        self.cache_k, self.cache_v = None, None
        self.ptr_current_pos = 0
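A minimal sketch of the rolling cache in isolation, assuming toy sizes chosen purely for illustration: feeding one token per call, the cache depth should never exceed the window.

# Illustrative only; the sizes below are not taken from the script.
swa = MultiHeadAttentionWithSWA(
    d_in=32, d_out=32, dropout=0.0, num_heads=4, sliding_window_size=8
)
swa.eval()
with torch.no_grad():
    for _ in range(20):
        _ = swa(torch.randn(1, 1, 32))   # one token per call
        assert swa.cache_k.size(2) <= 8  # bounded by the sliding window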
class TransformerBlock(nn.Module):
    def __init__(self, cfg: dict[str, typing.Any]) -> None:
        super().__init__()
        self.att = MultiHeadAttentionWithSWA(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"],
            sliding_window_size=cfg["sliding_window_size"],
        )
        self.ff = attention_helpers.FeedForward(cfg)
        self.norm1 = attention_helpers.LayerNorm(cfg["emb_dim"])
        self.norm2 = attention_helpers.LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)
        x = self.drop_shortcut(x)
        x = x + shortcut

        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut

        return x


class GPTModel(nn.Module):
    def __init__(self, cfg: dict[str, typing.Any]) -> None:
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )

        self.current_pos = 0

        self.final_norm = attention_helpers.LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx: torch.Tensor) -> torch.Tensor:
        _, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
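        # current_pos counts the tokens consumed so far, so cached single-token
        # calls still receive the correct absolute positional embedding.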
        pos_ids = torch.arange(
            self.current_pos,
            self.current_pos + seq_len,
            device=in_idx.device,
            dtype=torch.long,
        )
        self.current_pos += seq_len
        pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)

        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)

        for blk in self.trf_blocks:
            x = blk(x)

        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits

    def reset_kv_cache(self) -> None:
        for blk in self.trf_blocks:
            blk.att.reset_cache()
        self.current_pos = 0


def generate_text_simple_cached(
    model: GPTModel,
    idx: torch.Tensor,
    max_new_tokens: int,
    context_size: typing.Optional[int] = None,
) -> torch.Tensor:
    model.eval()
    ctx_len = context_size or model.pos_emb.num_embeddings

    with torch.no_grad():
        model.reset_kv_cache()
        logits = model(idx[:, -ctx_len:])
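        # After the prefill above, only the newly sampled token is fed back in;
        # the per-layer KV caches (and current_pos) already carry the context.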
        for _ in range(max_new_tokens):
            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
            idx = torch.cat([idx, next_idx], dim=1)
            logits = model(next_idx)

    return idx


def main() -> None:
    start_context = "Hello, I am"
    tokenizer = tiktoken.get_encoding("gpt2")
    encoded = tokenizer.encode(start_context)

    GPT_CONFIG_124M = {
        "vocab_size": 50257,  # Vocabulary size
        "context_length": 1024,  # Context length
        "emb_dim": 768,  # Embedding dimension
        "n_heads": 12,  # Number of attention heads
        "n_layers": 12,  # Number of layers
        "drop_rate": 0.0,  # Dropout rate
        "qkv_bias": False,  # Query-Key-Value bias
        "sliding_window_size": 256,  # SWA window size W
    }
    torch.manual_seed(251221)
    model = GPTModel(GPT_CONFIG_124M)
    device = torch.device("cuda" if cuda.is_available() else "cpu")
    model.to(device, dtype=torch.bfloat16)
    model.eval()

    encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
    print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
    print("\nInput text:", start_context)
    print("Encoded input text:", encoded)
    print("encoded_tensor.shape:", encoded_tensor.shape)

    if cuda.is_available():
        cuda.synchronize()
    start = time.time()

    token_ids = generate_text_simple_cached(
        model=model, idx=encoded_tensor, max_new_tokens=200
    )

    if cuda.is_available():
        cuda.synchronize()
    total_time = time.time() - start

    decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist())

    print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
    print("\nOutput:", token_ids)
    print("Output length:", len(token_ids[0]))
    print("Output text:", decoded_text)

    print(f"\nTime: {total_time:.2f} sec")
    print(f"{int(len(token_ids[0])/total_time)} tokens/sec")
    if cuda.is_available():
        max_mem_bytes = cuda.max_memory_allocated()
        max_mem_gb = max_mem_bytes / (1024**3)
        print(f"Max memory allocated: {max_mem_gb:.2f} GB")


if __name__ == "__main__":
    main()
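The sliding window also bounds KV-cache memory: each layer caches keys and values of shape (batch, n_heads, min(seq_len, W), head_dim). A back-of-the-envelope estimate for the config above, using a purely illustrative helper and 2 bytes per bfloat16 element:

def kv_cache_bytes(n_layers=12, n_heads=12, head_dim=64, window=256,
                   batch=1, bytes_per_elem=2):
    # Two tensors (K and V) per layer, each capped at `window` positions.
    return 2 * n_layers * batch * n_heads * window * head_dim * bytes_per_elem

print(kv_cache_bytes() / 1024**2, "MiB")  # 9.0 MiB, independent of sequence length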
