Commit 80184c1

bugfix

Author: james
1 parent 1defaa1 commit 80184c1

File tree

8 files changed (+25, -28 lines)

README.md

Lines changed: 1 addition & 1 deletion

@@ -94,7 +94,7 @@ CUDA_VISIBLE_DEVICES=0 python generate.py \
   --interactive
 ```
 
-4. Run sparse inference (`scripts/run.sh`)!:
+4. Run sparse inference! (`scripts/run.sh`):
 ```bash
 CUDA_VISIBLE_DEVICES=0 python generate.py \
   --compile \

gpt-fast/scripts/speculate_34B_bf16.sh

Lines changed: 0 additions & 4 deletions
This file was deleted.

gpt-fast/scripts/speculate_70B_int4.sh

Lines changed: 0 additions & 4 deletions
This file was deleted.

gpt-fast/scripts/speculate_7B_int4.sh

Lines changed: 0 additions & 3 deletions
This file was deleted.

gpt-fast/scripts/speculate_tp_70B_bf16.sh

Lines changed: 0 additions & 3 deletions
This file was deleted.

gpt-fast/scripts/tp_run.sh

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+time torchrun --standalone --nproc_per_node=4 generate.py --compile --checkpoint_path $OUTPUT_PATH/meta-llama/Llama-2-7b-hf/model.pth --hist_path ../models/Llama-2-70B/histograms --sparsity 0.5 --prompt "Hello, my name is "
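This script launches gpt-fast's tensor-parallel path with one process per GPU. As a rough sketch of the bootstrap any torchrun-launched script performs (not gpt-fast's actual initialization code), each rank reads its identity from the environment and joins a NCCL process group:

```python
import os
import torch
import torch.distributed as dist

def init_tensor_parallel() -> int:
    # Under `torchrun --nproc_per_node=4`, each process gets RANK and
    # WORLD_SIZE environment variables set by the launcher.
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size > 1:
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
        # Single-node convention: map rank i to GPU i.
        torch.cuda.set_device(rank % torch.cuda.device_count())
    return rank
```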

kernels/sparse_gemv.py

Lines changed: 16 additions & 9 deletions
@@ -14,24 +14,32 @@ def init_func(nargs):
 # NOTE: will need to warm up kernels each time, triton autotune caching isn't a thing right now
 
 configs=[
+    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=2, pre_hook=init_to_zero("Y")),
+
+    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 8, "BLOCK_N": 128}, num_warps=2, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")),
-    triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
-    #triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")),
-    triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
-    #triton.Config({"BLOCK_M": 32, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")),
-    triton.Config({"BLOCK_M": 64, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
-    #triton.Config({"BLOCK_M": 64, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 128, "BLOCK_N": 16}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")),
-    # triton.Config({"BLOCK_M": 128, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
-    #triton.Config({"BLOCK_M": 128, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
+
+    triton.Config({"BLOCK_M": 128, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
+    triton.Config({"BLOCK_M": 64, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
+    triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
+    triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
+
+
+    # Llama 3 variants can use BLOCK_N >= 1024
+    # triton.Config({"BLOCK_M": 128, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
+    # triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
+    # triton.Config({"BLOCK_M": 64, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
+    # triton.Config({"BLOCK_M": 32, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
+    # triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
 ]
 
 @triton.autotune(

@@ -287,7 +295,6 @@ def forward(
     sparsity_bin: int,
     kv_size: int
 ) -> torch.Tensor:
-    return torch.matmul(x, weight.T)
     return qkv_gemv(x, weight, threshold_q, threshold_k, threshold_v, sparsity_bin, kv_size) if x.shape[1] == 1 else torch.matmul(x, weight.T)
 
 # for testing purposes, to see if overhead at 0% is really due to strengthening torch.matmul (seems like it is)
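The surviving one-liner in `forward` dispatches on sequence length: single-token decode steps go through the Triton `qkv_gemv` kernel, while multi-token prefill falls back to dense `torch.matmul`. Below is a minimal sketch of that dispatch pattern; `sparse_gemv_ref` is a hypothetical pure-PyTorch stand-in for the autotuned kernel, and the single `sparsity` fraction is a simplification (the real kernel takes per-projection thresholds `threshold_q`/`threshold_k`/`threshold_v` plus a `sparsity_bin`):

```python
import torch

def sparse_gemv_ref(x: torch.Tensor, weight: torch.Tensor, sparsity: float) -> torch.Tensor:
    # Hypothetical stand-in for the Triton qkv_gemv kernel: drop the
    # lowest-magnitude activations, then do a dense matmul.
    k = int(x.numel() * sparsity)
    if k > 0:
        threshold = x.abs().flatten().kthvalue(k).values
        x = torch.where(x.abs() > threshold, x, torch.zeros_like(x))
    return torch.matmul(x, weight.T)

def forward(x: torch.Tensor, weight: torch.Tensor, sparsity: float = 0.5) -> torch.Tensor:
    # Decode steps carry a single query token (x.shape[1] == 1), where the
    # memory-bound GEMV can skip weight columns for zeroed activations;
    # prefill batches stay on the dense path.
    if x.shape[1] == 1:
        return sparse_gemv_ref(x, weight, sparsity)
    return torch.matmul(x, weight.T)

x_prefill = torch.randn(1, 8, 4096)   # (bsz, seq_len, hidden)
x_decode = torch.randn(1, 1, 4096)
weight = torch.randn(1024, 4096)      # (out_features, hidden)
print(forward(x_prefill, weight).shape)  # torch.Size([1, 8, 1024])
print(forward(x_decode, weight).shape)   # torch.Size([1, 1, 1024])
```

The `init_to_zero("Y")` pre-hooks in the config list zero the output buffer before each autotune trial, presumably because the kernel accumulates partial sums into `Y`; per the NOTE in the hunk, the kernels must also be re-warmed on every run since Triton does not cache autotune results across processes.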

teal/grab_acts.py

Lines changed: 7 additions & 4 deletions
@@ -52,13 +52,16 @@
 for sample in tqdm(dataset):
     text += sample["text"] + "\n\n"
 
-
+print(len(text))
 bsz, seq_len = 10, 2048
 
+input_ids = []
+for i in range(0, len(text), seq_len):
+    ttext = text[i:i+seq_len]
+    encodings = tokenizer(ttext, truncation=True, return_tensors="pt", max_length=seq_len, return_overflowing_tokens=True, padding="max_length")
+    input_ids.append(encodings.input_ids)
 
-encodings = tokenizer(text, truncation=True, return_tensors="pt", max_length=seq_len, return_overflowing_tokens=True, padding="max_length")
-
-input_ids = encodings.input_ids[:bsz,:].to(device="cuda:0")
+input_ids = torch.cat(input_ids, dim=0)[:bsz,:].to(device="cuda:0")
 print(input_ids.shape)
 
 hidden_states = model.model.embed_tokens(input_ids)
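The new loop tokenizes the corpus in fixed-size character windows and concatenates the per-window `input_ids` into a single `(N, seq_len)` batch before slicing off `bsz` rows. A standalone sketch of the same pattern, assuming a generic Hugging Face tokenizer (`gpt2` here is a placeholder, not the script's actual model, and `return_overflowing_tokens` is omitted since each character window already fits in one row):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

text = "lorem ipsum dolor sit amet " * 10_000
bsz, seq_len = 10, 2048

# Tokenize fixed-size character windows so every chunk becomes exactly one
# (1, seq_len) row (truncated or padded), then stack the rows into a batch.
input_ids = []
for i in range(0, len(text), seq_len):
    chunk = text[i:i + seq_len]
    enc = tokenizer(chunk, truncation=True, max_length=seq_len,
                    padding="max_length", return_tensors="pt")
    input_ids.append(enc.input_ids)

input_ids = torch.cat(input_ids, dim=0)[:bsz, :]
print(input_ids.shape)  # torch.Size([10, 2048])
```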
