|
| 1 | +import time |
| 2 | +import sys |
| 3 | +import os |
| 4 | +import pathlib |
| 5 | +import importlib |
| 6 | +import traceback |
| 7 | +import torch |
| 8 | +import torch.nn as nn |
| 9 | +import torch.nn.functional as F |
| 10 | +import math |
| 11 | + |
| 12 | + |
| 13 | +######################################################## |
| 14 | +# Baseline |
| 15 | +######################################################## |
class Model(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """

    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
        super().__init__()
        assert n_embd % n_head == 0
        # Fused projection that produces query, key and value in one matmul.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # Final projection back to the embedding dimension.
        self.c_proj = nn.Linear(n_embd, n_embd)
        # Regularization on attention weights and on the residual output.
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        # Lower-triangular causal mask, shaped to broadcast over (batch, head).
        causal = torch.tril(torch.ones(max_seqlen, max_seqlen))
        self.register_buffer("bias", causal.view(1, 1, max_seqlen, max_seqlen))
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x):
        batch, seqlen, dim = x.size()  # batch, sequence length, embedding dim
        head_dim = dim // self.n_head
        # Project once, split into q/k/v, and move the head axis next to batch:
        # each becomes (batch, n_head, seqlen, head_dim).
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(batch, seqlen, self.n_head, head_dim).transpose(1, 2)
        k = k.view(batch, seqlen, self.n_head, head_dim).transpose(1, 2)
        v = v.view(batch, seqlen, self.n_head, head_dim).transpose(1, 2)
        # Scaled dot-product scores (batch, n_head, seqlen, seqlen), masked so
        # each position can only attend to itself and positions to its left.
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        scores = scores.masked_fill(self.bias[:, :, :seqlen, :seqlen] == 0, float("-inf"))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))
        # Weighted sum of values, then re-assemble the heads side by side.
        out = weights @ v  # (batch, n_head, seqlen, head_dim)
        out = out.transpose(1, 2).contiguous().view(batch, seqlen, dim)
        # Output projection with residual dropout.
        return self.resid_dropout(self.c_proj(out))
| 54 | + |
| 55 | + |
| 56 | +######################################################## |
| 57 | +# Weco Solution |
| 58 | +######################################################## |
def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
    """Import a Python module directly from a file path.

    Args:
        module_path: Path to a ``.py`` file; the file's stem becomes the
            module name.
        add_to_sys_modules: If True, also register the module in
            ``sys.modules`` so later ``import <stem>`` statements resolve
            to this module.

    Returns:
        The fully executed module object.

    Raises:
        ImportError: If no import spec can be built for ``module_path``.
    """
    # ``importlib.util`` is a submodule; a bare ``import importlib`` does not
    # guarantee it is bound, so import it explicitly here.
    import importlib.util

    module_path = pathlib.Path(module_path)
    name = module_path.stem
    spec = importlib.util.spec_from_file_location(name, module_path)
    # Fail with a clear error instead of an AttributeError on ``None``.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot build import spec for {module_path}")
    mod = importlib.util.module_from_spec(spec)
    if add_to_sys_modules:
        sys.modules[name] = mod
    spec.loader.exec_module(mod)
    return mod
| 69 | + |
| 70 | + |
| 71 | +######################################################## |
| 72 | +# Benchmark |
| 73 | +######################################################## |
| 74 | +os.environ["MAX_JOBS"] = "1" # number of workers for building with ninja |
| 75 | + |
| 76 | + |
def get_inputs(batch_size, seq_len, n_embd, device):
    """Draw a random fp32 batch of shape (batch_size, seq_len, n_embd) on ``device``."""
    shape = (batch_size, seq_len, n_embd)
    return torch.randn(*shape, dtype=torch.float32, device=device)
| 79 | + |
| 80 | + |
def bench(f, inputs, n_warmup, n_rep):
    """Return the average wall-clock latency of ``f(inputs)`` in milliseconds.

    Args:
        f: Callable taking a single argument (e.g. an ``nn.Module``).
        inputs: The argument passed to ``f`` on every invocation.
        n_warmup: Number of untimed warmup invocations.
        n_rep: Number of timed repetitions to average over.

    Returns:
        Mean latency per call, in milliseconds.
    """
    # Guard CUDA-only calls so the helper also runs on CPU-only hosts.
    use_cuda = torch.cuda.is_available()
    with torch.no_grad():
        # Warmup, then drain any pending asynchronous kernels so the first
        # timed repetition is not billed for leftover warmup work.
        for _ in range(n_warmup):
            f(inputs)  # noqa
        if use_cuda:
            torch.cuda.synchronize()

        t_total = 0.0
        for _ in range(n_rep):
            if use_cuda:
                torch.cuda.empty_cache()  # clear allocator cache before timing
            # perf_counter is monotonic and high-resolution, unlike time.time.
            start_time = time.perf_counter()
            f(inputs)
            if use_cuda:
                torch.cuda.synchronize()  # wait for all computations to complete
            t_total += time.perf_counter() - start_time
    # Average over repetitions and convert seconds -> milliseconds.
    return t_total / n_rep * 1e3
| 97 | + |
| 98 | + |
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--solution-path", type=str, required=True)
    args = parser.parse_args()

    # benchmarking parameters
    n_correctness_trials = 10
    n_warmup = 1000
    n_rep = 5000

    # init parameters
    max_seqlen = 512
    seq_len = 256
    n_embd = 768
    n_head = 8
    # turn off dropout so baseline and candidate outputs are deterministic and comparable
    attn_pdrop = 0.0
    resid_pdrop = 0.0

    # input parameters
    batch_size = 32

    # load the candidate solution module and instantiate its Model
    try:
        torch.manual_seed(0)
        solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False)
        solution_model = solution_module.Model(
            n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
        ).to("cuda")
        assert isinstance(solution_model, nn.Module)
    except Exception:
        print(f"Candidate module initialization failed: {traceback.format_exc()}")
        # sys.exit is the reliable exit; builtin exit() comes from the site
        # module and is not guaranteed to exist.
        sys.exit(1)

    # reseed so the baseline draws the same init RNG stream as the candidate
    torch.manual_seed(0)
    baseline_model = Model(
        n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
    ).to("cuda")

    # measure correctness: average (over trials) of the max absolute element difference
    max_diff_avg = 0.0
    for _ in range(n_correctness_trials):
        inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
        with torch.no_grad():
            baseline_output = baseline_model(inputs)
            optimized_output = solution_model(inputs)
        # .item() converts the 0-dim CUDA tensor to a float so the report
        # prints a plain number instead of "tensor(..., device='cuda:0')"
        max_diff_avg += torch.max(torch.abs(optimized_output - baseline_output)).item()
    max_diff_avg /= n_correctness_trials
    print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")

    # measure performance
    inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
    t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep)
    print(f"baseline time: {t_avg_baseline:.2f}ms")
    t_avg_optimized = bench(solution_model, inputs, n_warmup, n_rep)
    print(f"optimized time: {t_avg_optimized:.2f}ms")
    print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")
0 commit comments