|
| 1 | +import time |
| 2 | +import typing |
| 3 | + |
| 4 | +import tiktoken |
| 5 | +import torch |
| 6 | +import torch.nn as nn |
| 7 | +from torch import cuda |
| 8 | + |
| 9 | +import attention_helpers |
| 10 | + |
| 11 | + |
| 12 | +class MultiHeadAttentionWithSWA(nn.Module): |
| 13 | + def __init__( |
| 14 | + self, |
| 15 | + d_in: int, |
| 16 | + d_out: int, |
| 17 | + dropout: float, |
| 18 | + num_heads: int, |
| 19 | + sliding_window_size: int, |
| 20 | + dtype: typing.Optional[torch.dtype] = None, |
| 21 | + qkv_bias: bool = False, |
| 22 | + ) -> None: |
| 23 | + super().__init__() |
| 24 | + assert d_out % num_heads == 0, "d_out must be divisible by num_heads" |
| 25 | + assert sliding_window_size > 0, "sliding_window_size must be positive" |
| 26 | + |
| 27 | + self.d_out = d_out |
| 28 | + self.num_heads = num_heads |
| 29 | + self.head_dim = d_out // num_heads |
| 30 | + self.sliding_window_size = int(sliding_window_size) |
| 31 | + |
| 32 | + self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias, dtype=dtype) |
| 33 | + self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias, dtype=dtype) |
| 34 | + self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias, dtype=dtype) |
| 35 | + self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype) |
| 36 | + self.dropout = nn.Dropout(dropout) |
| 37 | + |
| 38 | + self.register_buffer("cache_k", None, persistent=False) |
| 39 | + self.register_buffer("cache_v", None, persistent=False) |
| 40 | + self.ptr_current_pos = 0 |
| 41 | + |
| 42 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 43 | + b, num_tokens, _ = x.shape |
| 44 | + |
| 45 | + keys_new = self.W_key(x) |
| 46 | + values_new = self.W_value(x) |
| 47 | + queries = self.W_query(x) |
| 48 | + |
| 49 | + queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose( |
| 50 | + 1, 2 |
| 51 | + ) |
| 52 | + keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim).transpose( |
| 53 | + 1, 2 |
| 54 | + ) |
| 55 | + values_new = values_new.view( |
| 56 | + b, num_tokens, self.num_heads, self.head_dim |
| 57 | + ).transpose(1, 2) |
| 58 | + |
| 59 | + # 1. Update the Cache |
| 60 | + if self.cache_k is None: |
| 61 | + self.cache_k, self.cache_v = keys_new, values_new |
| 62 | + else: |
| 63 | + self.cache_k = torch.cat([self.cache_k, keys_new], dim=2) |
| 64 | + self.cache_v = torch.cat([self.cache_v, values_new], dim=2) |
| 65 | + |
| 66 | + # 2. Apply Sliding Window (Truncate) |
| 67 | + # |
| 68 | + # We check the current size (after adding new tokens). |
| 69 | + # |
| 70 | + # If it exceeds the window, we physically delete the oldest tokens. |
| 71 | + if self.cache_k.size(2) > self.sliding_window_size: |
| 72 | + self.cache_k = self.cache_k[:, :, -self.sliding_window_size :, :] |
| 73 | + self.cache_v = self.cache_v[:, :, -self.sliding_window_size :, :] |
| 74 | + |
| 75 | + keys, values = self.cache_k, self.cache_v |
| 76 | + |
| 77 | + attn_scores = queries @ keys.transpose(2, 3) |
| 78 | + |
| 79 | + num_tokens_q = queries.shape[-2] |
| 80 | + num_tokens_k = keys.shape[-2] |
| 81 | + device = queries.device |
| 82 | + |
| 83 | + # 3. Calculate Absolute Positions for Masking |
| 84 | + # |
| 85 | + # We need the absolute index of the tokens in the full text sequence (0, 1, 2, |
| 86 | + # ... 500) to ensure the mask works correctly. |
| 87 | + # |
| 88 | + # The 'right edge' of our cache is the total number of tokens processed so far. |
| 89 | + current_absolute_end = self.ptr_current_pos + num_tokens |
| 90 | + |
| 91 | + # The 'left edge' is simply the end minus the current cache size. |
| 92 | + # This gives us the Absolute Position of keys[:, :, 0, :]. |
| 93 | + k_start_absolute = current_absolute_end - num_tokens_k |
| 94 | + |
| 95 | + # The queries start wherever the previous batch ended. |
| 96 | + q_start_absolute = self.ptr_current_pos |
| 97 | + |
| 98 | + q_positions = torch.arange( |
| 99 | + q_start_absolute, |
| 100 | + q_start_absolute + num_tokens_q, |
| 101 | + device=device, |
| 102 | + dtype=torch.long, |
| 103 | + ) |
| 104 | + |
| 105 | + k_positions = torch.arange( |
| 106 | + k_start_absolute, |
| 107 | + k_start_absolute + num_tokens_k, |
| 108 | + device=device, |
| 109 | + dtype=torch.long, |
| 110 | + ) |
| 111 | + |
| 112 | + # 4. Create and Apply Mask |
| 113 | + diff = q_positions.unsqueeze(-1) - k_positions.unsqueeze(0) |
| 114 | + |
| 115 | + # (diff < 0) -> Causal Mask (prevent looking at future) |
| 116 | + # |
| 117 | + # (diff >= window) -> Window Mask (prevent looking too far back) |
| 118 | + mask = (diff < 0) | (diff >= self.sliding_window_size) |
| 119 | + |
| 120 | + self.ptr_current_pos += num_tokens_q |
| 121 | + |
| 122 | + attn_scores = attn_scores.masked_fill(mask, -torch.inf) |
| 123 | + |
| 124 | + attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1) |
| 125 | + attn_weights = self.dropout(attn_weights) |
| 126 | + |
| 127 | + context_vec = (attn_weights @ values).transpose(1, 2) |
| 128 | + context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) |
| 129 | + context_vec = self.out_proj(context_vec) |
| 130 | + return context_vec |
| 131 | + |
| 132 | + def reset_cache(self) -> None: |
| 133 | + self.cache_k, self.cache_v = None, None |
| 134 | + self.ptr_current_pos = 0 |
| 135 | + |
| 136 | + |
| 137 | +class TransformerBlock(nn.Module): |
| 138 | + def __init__(self, cfg: dict[str, typing.Any]) -> None: |
| 139 | + super().__init__() |
| 140 | + self.att = MultiHeadAttentionWithSWA( |
| 141 | + d_in=cfg["emb_dim"], |
| 142 | + d_out=cfg["emb_dim"], |
| 143 | + num_heads=cfg["n_heads"], |
| 144 | + dropout=cfg["drop_rate"], |
| 145 | + qkv_bias=cfg["qkv_bias"], |
| 146 | + sliding_window_size=cfg["sliding_window_size"], |
| 147 | + ) |
| 148 | + self.ff = attention_helpers.FeedForward(cfg) |
| 149 | + self.norm1 = attention_helpers.LayerNorm(cfg["emb_dim"]) |
| 150 | + self.norm2 = attention_helpers.LayerNorm(cfg["emb_dim"]) |
| 151 | + self.drop_shortcut = nn.Dropout(cfg["drop_rate"]) |
| 152 | + |
| 153 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 154 | + shortcut = x |
| 155 | + x = self.norm1(x) |
| 156 | + x = self.att(x) |
| 157 | + x = self.drop_shortcut(x) |
| 158 | + x = x + shortcut |
| 159 | + |
| 160 | + shortcut = x |
| 161 | + x = self.norm2(x) |
| 162 | + x = self.ff(x) |
| 163 | + x = self.drop_shortcut(x) |
| 164 | + x = x + shortcut |
| 165 | + |
| 166 | + return x |
| 167 | + |
| 168 | + |
| 169 | +class GPTModel(nn.Module): |
| 170 | + def __init__(self, cfg: dict[str, typing.Any]) -> None: |
| 171 | + super().__init__() |
| 172 | + self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"]) |
| 173 | + self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"]) |
| 174 | + self.drop_emb = nn.Dropout(cfg["drop_rate"]) |
| 175 | + |
| 176 | + self.trf_blocks = nn.ModuleList( |
| 177 | + [TransformerBlock(cfg) for _ in range(cfg["n_layers"])] |
| 178 | + ) |
| 179 | + |
| 180 | + self.current_pos = 0 |
| 181 | + |
| 182 | + self.final_norm = attention_helpers.LayerNorm(cfg["emb_dim"]) |
| 183 | + self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False) |
| 184 | + |
| 185 | + def forward(self, in_idx: torch.Tensor) -> torch.Tensor: |
| 186 | + _, seq_len = in_idx.shape |
| 187 | + tok_embeds = self.tok_emb(in_idx) |
| 188 | + |
| 189 | + pos_ids = torch.arange( |
| 190 | + self.current_pos, |
| 191 | + self.current_pos + seq_len, |
| 192 | + device=in_idx.device, |
| 193 | + dtype=torch.long, |
| 194 | + ) |
| 195 | + self.current_pos += seq_len |
| 196 | + pos_embeds = self.pos_emb(pos_ids).unsqueeze(0) |
| 197 | + |
| 198 | + x = tok_embeds + pos_embeds |
| 199 | + x = self.drop_emb(x) |
| 200 | + |
| 201 | + for blk in self.trf_blocks: |
| 202 | + x = blk(x) |
| 203 | + |
| 204 | + x = self.final_norm(x) |
| 205 | + logits = self.out_head(x) |
| 206 | + return logits |
| 207 | + |
| 208 | + def reset_kv_cache(self) -> None: |
| 209 | + for blk in self.trf_blocks: |
| 210 | + blk.att.reset_cache() |
| 211 | + self.current_pos = 0 |
| 212 | + |
| 213 | + |
| 214 | +def generate_text_simple_cached( |
| 215 | + model: GPTModel, |
| 216 | + idx: torch.Tensor, |
| 217 | + max_new_tokens: int, |
| 218 | + context_size: typing.Optional[int] = None, |
| 219 | +) -> torch.Tensor: |
| 220 | + model.eval() |
| 221 | + ctx_len = context_size or model.pos_emb.num_embeddings |
| 222 | + |
| 223 | + with torch.no_grad(): |
| 224 | + model.reset_kv_cache() |
| 225 | + logits = model(idx[:, -ctx_len:]) |
| 226 | + |
| 227 | + for _ in range(max_new_tokens): |
| 228 | + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) |
| 229 | + idx = torch.cat([idx, next_idx], dim=1) |
| 230 | + logits = model(next_idx) |
| 231 | + |
| 232 | + return idx |
| 233 | + |
| 234 | + |
| 235 | +def main() -> None: |
| 236 | + start_context = "Hello, I am" |
| 237 | + tokenizer = tiktoken.get_encoding("gpt2") |
| 238 | + encoded = tokenizer.encode(start_context) |
| 239 | + |
| 240 | + GPT_CONFIG_124M = { |
| 241 | + "vocab_size": 50257, # Vocabulary size |
| 242 | + "context_length": 1024, # Context length |
| 243 | + "emb_dim": 768, # Embedding dimension |
| 244 | + "n_heads": 12, # Number of attention heads |
| 245 | + "n_layers": 12, # Number of layers |
| 246 | + "drop_rate": 0.0, # Dropout rate |
| 247 | + "qkv_bias": False, # Query-Key-Value bias |
| 248 | + "sliding_window_size": 256, # SWA window size W |
| 249 | + } |
| 250 | + torch.manual_seed(251221) |
| 251 | + model = GPTModel(GPT_CONFIG_124M) |
| 252 | + device = torch.device("cuda" if cuda.is_available() else "cpu") |
| 253 | + model.to(device, dtype=torch.bfloat16) |
| 254 | + model.eval() |
| 255 | + |
| 256 | + encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0) |
| 257 | + print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}") |
| 258 | + print("\nInput text:", start_context) |
| 259 | + print("Encoded input text:", encoded) |
| 260 | + print("encoded_tensor.shape:", encoded_tensor.shape) |
| 261 | + |
| 262 | + if cuda.is_available(): |
| 263 | + cuda.synchronize() |
| 264 | + start = time.time() |
| 265 | + |
| 266 | + token_ids = generate_text_simple_cached( |
| 267 | + model=model, idx=encoded_tensor, max_new_tokens=200 |
| 268 | + ) |
| 269 | + |
| 270 | + if cuda.is_available(): |
| 271 | + cuda.synchronize() |
| 272 | + total_time = time.time() - start |
| 273 | + |
| 274 | + decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist()) |
| 275 | + |
| 276 | + print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}") |
| 277 | + print("\nOutput:", token_ids) |
| 278 | + print("Output length:", len(token_ids[0])) |
| 279 | + print("Output text:", decoded_text) |
| 280 | + |
| 281 | + print(f"\nTime: {total_time:.2f} sec") |
| 282 | + print(f"{int(len(token_ids[0])/total_time)} tokens/sec") |
| 283 | + if cuda.is_available(): |
| 284 | + max_mem_bytes = cuda.max_memory_allocated() |
| 285 | + max_mem_gb = max_mem_bytes / (1024**3) |
| 286 | + print(f"Max memory allocated: {max_mem_gb:.2f} GB") |
| 287 | + |
| 288 | + |
| 289 | +if __name__ == "__main__": |
| 290 | + main() |
0 commit comments