Skip to content

Commit 0146e6c

Browse files
committed
add gist generation utils to library
1 parent 12be67c commit 0146e6c

18 files changed

+288
-208
lines changed

cache.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def add_cache_arguments(parser: argparse.ArgumentParser):
2020
help="Cache size per layer. If len < n layers, the values are tiled. Must have len divisible by n layers. \
2121
If 0 < x <= 1, it is percent of |prompt| + max new tokens. Otherwise, if > 1, its the maximum size.",
2222
)
23-
strategies = ["full", "random", "window", "scissor", "l2", "fastgen"]
23+
strategies = ["full", "random", "window", "scissor", "l2", "fastgen", "gist"]
2424
debug_strategies = [f"debug_{strategy}" for strategy in strategies]
2525
strategies.extend(debug_strategies)
2626

@@ -105,11 +105,14 @@ def add_cache_arguments(parser: argparse.ArgumentParser):
105105

106106

107107
def cache_compatibility(args):
108-
if args.cache_strategy == "full":
108+
if args.cache_strategy in ("full", "gist"):
109109
# Full implies no compression, which means --max_cache_length = [1.0] (same size as prompt + max_new_tokens)
110110
assert all(
111111
[l == 1.0 for l in args.max_cache_length]
112112
), "Full cache strategy only supports max_cache_length=1.0."
113+
114+
if args.cache_strategy == "gist":
115+
assert "gist" in str(args.checkpoint_path)
113116

114117
# Attention-based eviction policies must use an attention-based prompt compressor
115118
if args.cache_strategy in {"scissor"}:
@@ -461,6 +464,68 @@ def mark_global_tokens(self, num_total_insertions: int) -> bool:
461464
self.pos[:, :, :num_to_mark] = self.max_seq_length
462465
return num_to_mark == self.global_tokens
463466

467+
class KVCacheGist(KVCache):
468+
relevant_kwargs = [
469+
'gist_token_id',
470+
'max_cache_length'
471+
]
472+
def __init__(
473+
self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, **kwargs
474+
):
475+
# Never any prompt compression for full cache
476+
self.prompt_compression_strategy = None
477+
self.global_tokens = 0 # No global tokens for full cache (they are all global)
478+
479+
assert 'gist_token_id' in kwargs, "You must provide a gist token id for the gist cache."
480+
self.gist_token_id = kwargs.pop('gist_token_id')
481+
super().__init__(
482+
max_batch_size, n_heads, head_dim, dtype, head_specific=False, **kwargs
483+
)
484+
self.prefill_attn_callback = {
485+
"func": self.profile_and_update,
486+
"kwargs": {},
487+
}
488+
self.register_buffer(
489+
"ids", # Track ids to keep track of the original ids of each item in cache. required to determine gist mask in case of multi-batch inputs
490+
torch.full(
491+
(
492+
max_batch_size,
493+
self.max_cache_length,
494+
),
495+
-1,
496+
dtype=torch.int,
497+
),
498+
)
499+
500+
501+
def _update(self, input_pos, k_val, v_val, input_ids=None):
502+
# input_pos: [S], k_val: [B, H, S, D], input_ids: [B, S]
503+
504+
self.fill_contiguous(input_pos, k_val, v_val)
505+
self.ids[:, self.cache_cts[0]:self.cache_cts[0]+input_ids.shape[-1]] = input_ids
506+
return input_pos.shape[-1]
507+
508+
def profile_and_update(self, input_pos, input_ids, k_val, v_val, attn):
509+
assert self.is_prefill(), "Should only be profiling during prefill stage."
510+
gist_pos = torch.where(input_ids == self.gist_token_id)[-1].min().cpu().item()
511+
seq_len = input_pos.shape[-1]
512+
input_pos = input_pos[gist_pos:]
513+
input_ids = input_ids[:, gist_pos:]
514+
k_val = k_val[:, :, gist_pos:, :]
515+
v_val = v_val[:, :, gist_pos:, :]
516+
517+
self.fill_contiguous(input_pos, k_val, v_val)
518+
self.ids[:, self.cache_cts[0]:self.cache_cts[0]+input_ids.shape[-1]] = input_ids
519+
self.cache_cts[0] = input_pos.shape[-1]
520+
521+
def return_kv_cache(self):
522+
k, v, mask = super().return_kv_cache()
523+
mask_shape = (k.shape[0], k.shape[1], 1, k.shape[-2])
524+
causal_mask = torch.ones(mask_shape, dtype=torch.bool).to(k.device)
525+
gist_token_positions = torch.stack(torch.where(self.ids == self.gist_token_id)).T
526+
for position in gist_token_positions:
527+
causal_mask[position[0], :, :, :position[1]] = False
528+
return k, v, causal_mask
464529

465530
class KVCacheFull(KVCache):
466531
def __init__(
@@ -1258,6 +1323,8 @@ def get_cache_constructor(cache_strategy):
12581323
cls = KVCacheScissorhands
12591324
elif cache_strategy == "fastgen":
12601325
cls = KVCacheFastGen
1326+
elif cache_strategy == "gist":
1327+
cls = KVCacheGist
12611328
elif cache_strategy.startswith("debug"):
12621329
cache_strategy = re.sub(r"debug_+", "", cache_strategy).strip()
12631330
relevant_kwargs = get_cache_constructor(cache_strategy)[1]

generate.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77
import time
88
import contextlib
9+
import json
910
from pathlib import Path
1011
from typing import Optional
1112

@@ -84,6 +85,8 @@ def main(
8485

8586
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path, is_chat=is_chat)
8687

88+
gist_token_id = tokenizer.gist_token_id() if hasattr(tokenizer, "gist_token_id") else None
89+
8790
inputs = [encode(tokenizer, prompt, device=device, is_chat=is_chat)]
8891

8992
terminator_ids = tokenizer.get_terminator_ids()
@@ -124,6 +127,7 @@ def main(
124127
inputs[0],
125128
max_new_tokens=max_new_tokens,
126129
terminator_ids=terminator_ids,
130+
gist_token_id=gist_token_id,
127131
feed_long_prompts=feed_long_prompts,
128132
)
129133
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
@@ -159,8 +163,8 @@ def main(
159163
parser.add_argument(
160164
"--prompt",
161165
type=str,
162-
default="long_prompt_short_output.txt",
163-
help="Input prompt. If it ends in .txt, we will load the prompt from the ./prompts dir.",
166+
default="long_prompt_short_output.json",
167+
help="Input prompt. If it ends in .json, we will load the prompt from the ./prompts dir.",
164168
)
165169
parser.add_argument(
166170
"--max_new_tokens", type=int, default=512, help="Maximum number of new tokens."
@@ -171,10 +175,10 @@ def main(
171175

172176
args = parser.parse_args()
173177

174-
if args.prompt.endswith(".txt"):
178+
if args.prompt.endswith(".json"):
175179
prompt_fn = Path(__file__).resolve().parent / "prompts" / args.prompt
176180
with open(prompt_fn) as fd:
177-
args.prompt = fd.read().strip()
181+
args.prompt = json.load(fd)
178182

179183
cache_compatibility(args)
180184

@@ -188,3 +192,4 @@ def main(
188192
args.device,
189193
cache_kwargs=vars(args),
190194
)
195+

generation_utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
7979
def sample(
8080
logits: torch.Tensor,
8181
next_token: torch.Tensor = None,
82-
temperature: float = 1.0,
83-
top_k: Optional[int] = None,
82+
temperature: float = 0.4,
83+
top_k: Optional[int] = 100,
8484
):
8585
probs = logits_to_probs(logits[0, -1], temperature, top_k)
8686
if next_token is None:
@@ -104,6 +104,7 @@ def prefill(
104104
x: torch.Tensor,
105105
input_pos: torch.Tensor,
106106
next_token: torch.Tensor = None,
107+
gist_token_id: int = -1,
107108
**sampling_kwargs,
108109
) -> torch.Tensor:
109110
# input_pos: [B, S]
@@ -113,8 +114,12 @@ def prefill(
113114
.unsqueeze(0)
114115
.to(x.device)
115116
)
117+
gist_token_positions = torch.stack(torch.where(x == gist_token_id)).T
118+
for position in gist_token_positions:
119+
causal_mask[position[0], :, position[1] + 1:, :position[1]] = False
120+
116121
logits = model(x, input_pos, mask=causal_mask)
117-
return greedy(logits, next_token)
122+
return sample(logits, next_token)
118123

119124

120125
def decode_one_token(
@@ -127,7 +132,7 @@ def decode_one_token(
127132
# input_pos: [B, 1]
128133
assert input_pos.shape[-1] == 1
129134
logits = model(x, input_pos)
130-
return greedy(logits, next_token=next_token)
135+
return sample(logits, next_token=next_token)
131136

132137

133138
def decode_n_tokens(
@@ -240,6 +245,9 @@ def setup_caches(
240245
"punctuation": tokenizer.punctuation_ids(),
241246
}
242247

248+
if "gist" in cache_kwargs["cache_strategy"]:
249+
cache_kwargs["gist_token_id"] = tokenizer.gist_token_id()
250+
243251
with torch.device(device):
244252
model.setup_caches(max_batch_size=1, **cache_kwargs)
245253

@@ -258,6 +266,7 @@ def generate(
258266
prompt: torch.Tensor,
259267
max_new_tokens: int,
260268
terminator_ids: Optional[list] = None,
269+
gist_token_id: int = -1,
261270
feed_long_prompts: bool = False,
262271
**sampling_kwargs,
263272
) -> torch.Tensor:
@@ -295,6 +304,7 @@ def generate(
295304
prompt.view(1, -1),
296305
input_pos,
297306
next_token=None if prefix is None else prefix[0].view(1),
307+
gist_token_id=gist_token_id,
298308
**sampling_kwargs,
299309
)
300310
next_token = ret[0].clone()

model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,16 @@ def from_name(cls, name: str):
145145
norm_eps=1e-6,
146146
max_length=32768,
147147
),
148+
"Meta-Llama-3-8B-gist-finetune-input-only": dict(
149+
block_size=8192,
150+
n_layer=32,
151+
n_head=32,
152+
n_local_heads=8,
153+
dim=4096,
154+
intermediate_size=14336,
155+
vocab_size=128257,
156+
rope_base=500000
157+
),
148158
}
149159

150160

prompts/long_prompt_long_output.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"instruction": "You are an architect tasked with drawing up plans for a modern residential house.\n\nArchitectural Plan Creation Instructions\n\nObjective:\nCreate a comprehensive set of architectural plans for a modern residential house. The plans should include detailed layouts, elevations, sections, and necessary annotations to guide the construction process. The design should focus on functionality, aesthetics, sustainability, and compliance with local building codes.\n\nRequirements:\n\nGeneral Layout:\n\nTotal area: Approximately 2,500 square feet.\nNumber of floors: Two.\nNumber of bedrooms: Four (including a master suite).\nNumber of bathrooms: Three full bathrooms and one half bathroom.\nCommon areas: Open-plan kitchen, dining area, living room, and a study/office.\nAdditional spaces: Laundry room, garage (for two cars), storage rooms, and a small basement.\nSite Plan:\n\nInclude property boundaries, adjacent streets, and any existing structures.\nShow the placement of the house, driveway, pathways, garden, and outdoor living spaces (e.g., patio, deck).\nInclude landscaping elements like trees, shrubs, and lawn areas.\nFloor Plans:\n\nGround Floor: Include entryway, living spaces, kitchen, one bedroom (guest room), one full bathroom, and access to the garage.\nSecond Floor: Include master suite with attached bathroom and walk-in closet, two additional bedrooms, one full bathroom, and a study/office.\nIndicate all door and window placements, furniture layouts, and circulation paths.\nElevations:\n\nProvide front, rear, and side elevations.\nShow the external appearance, including the roof design, facade materials, window and door placements, and any architectural features (e.g., balconies, porches).\nSections:\n\nInclude at least two sections (one longitudinal and one cross-sectional) showing internal details.\nHighlight the relationship between different floors and ceiling heights.\nShow structural elements like beams, columns, and floor slabs.\nRoof Plan:\n\nIndicate the roof slope, materials, drainage system, and any roof features (e.g., skylights, chimneys).\nElectrical and Plumbing Plans:\n\nShow the layout of electrical outlets, switches, lighting fixtures, and major appliances.\nInclude the plumbing layout for water supply and drainage, showing the location of pipes, fixtures, and connections.\nMaterials and Finishes:\n\nSpecify the materials for walls, floors, ceilings, and roofs.\nInclude details on interior and exterior finishes (e.g., paint, tiles, cladding).\nSustainability Features:\n\nIncorporate energy-efficient systems (e.g., HVAC, solar panels).\nUse sustainable building materials.\nPlan for natural lighting and ventilation.\nInclude rainwater harvesting and greywater recycling systems if possible.\nCompliance:\n\nEnsure the design complies with local building codes and regulations.\nInclude necessary annotations and notes for construction guidelines.\n\n",
3+
"input": "You must return the following:\n- Include a detailed list of materials and specifications.\n- Add a cover sheet with project title, address, date, and designer's name.\n- Add a sheet for each component with detailed plans.\n- Ensure all documents are clearly labeled and organized."
4+
}

prompts/long_prompt_long_output.txt

Lines changed: 0 additions & 63 deletions
This file was deleted.

0 commit comments

Comments
 (0)