Commit e3e1b25

Author: Faisal Ladhak

Squashed commit of the following:

commit 3fb32cf
Merge: 47dc1fa e7baa14
Author: Faisal Ladhak <[email protected]>
Date:   Mon Jun 24 21:17:21 2024 +0000

    Merge branch 'main' into mcqa

commit 47dc1fa
Author: Faisal Ladhak <[email protected]>
Date:   Thu Jun 20 22:19:02 2024 +0000

    Refactor code to move generation logic to generation_utils.

commit 00d582f
Author: Faisal Ladhak <[email protected]>
Date:   Thu Jun 20 21:25:59 2024 +0000

    Added code for selecting answer based on logits for MCQ, along with code for TruthfulQA.
1 parent e7baa14 commit e3e1b25

File tree

6 files changed (+460 additions, -329 deletions)


eval.py

Lines changed: 12 additions & 4 deletions
@@ -35,7 +35,7 @@ def device_sync(device):
 sys.path.append(str(wd))
 
 from tokenizer import get_tokenizer
-from generate import _load_model, generate, encode_tokens
+from generation_utils import _load_model, generate, encode_tokens
 from task import TASK_MAPPING, AutoTask
 
 
@@ -99,6 +99,7 @@ def main(
             "accept_counts": [],
         }
         predictions = []
+        all_probs = []
         for row in tqdm(task.get_test()):
             prompt = row["prompt"]
             if is_chat:
@@ -119,7 +120,7 @@ def main(
                 torch.profiler._utils._init_for_cuda_graphs()
                 prof = torch.profiler.profile()
             with prof:
-                y, metrics = generate(
+                y, metrics, probs = generate(
                     model,
                     encoded,
                     max_new_tokens=task.max_tokens,
@@ -144,6 +145,10 @@ def main(
                 tokenizer.decode(encoded.tolist())
             )[1]
             predictions.append(pred)
+            if task.requires_logits:
+                all_probs.append(
+                    {k: v for k, v in zip(tokenizer.get_vocab(), probs[0].tolist())}
+                )
             tokens_generated = y.size(0) - prompt_length
             tokens_sec = tokens_generated / t
             aggregate_metrics["tokens_per_sec"].append(tokens_sec)
@@ -155,7 +160,10 @@ def main(
         task_metrics[task_name]["tokens_per_sec"] = torch.mean(
             torch.tensor(aggregate_metrics["tokens_per_sec"])
         ).item()
-        task_metrics[task_name]["task_metrics"] = task.test_metrics(predictions)
+        if task.requires_logits:
+            task_metrics[task_name]["task_metrics"] = task.test_metrics(all_probs)
+        else:
+            task_metrics[task_name]["task_metrics"] = task.test_metrics(predictions)
         print(task_metrics[task_name]["task_metrics"])
 
 
@@ -170,7 +178,7 @@ def main(
         "--tasks",
         type=str,
         nargs="+",
-        default=["squality"],
+        default=["truthfulqa"],
         choices=list(TASK_MAPPING.keys()),
         help="List of tasks to be evaluated.",
     )

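The TruthfulQA/MCQ scoring invoked through task.test_metrics(all_probs) lives in the task files, which are among the 6 changed files but are not shown in this excerpt. As a rough, assumption-only sketch of what "selecting answer based on logits for MCQ" usually means, the snippet below picks the option whose answer token receives the highest probability in a token-to-probability dict like the one appended to all_probs above; the function name and single-token option format are hypothetical, not code from this commit.

# Hypothetical sketch (not from this commit): score each multiple-choice
# option by the probability the model assigns to its answer token.
from typing import Dict, List


def select_mcq_answer(token_probs: Dict[str, float], options: List[str]) -> int:
    # token_probs mirrors the per-example dict built in eval.py:
    # {vocab_token: next_token_probability}
    scores = [token_probs.get(option, 0.0) for option in options]
    # The index of the highest-probability option wins.
    return max(range(len(options)), key=lambda i: scores[i])


# Example with made-up probabilities over answer letters:
example = {"A": 0.02, "B": 0.61, "C": 0.30, "D": 0.07}
assert select_mcq_answer(example, ["A", "B", "C", "D"]) == 1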
generate.py

Lines changed: 2 additions & 323 deletions
@@ -33,329 +33,8 @@ def device_sync(device):
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from model import Transformer, find_multiple
 from tokenizer import get_tokenizer
-
-
-def multinomial_sample_one_no_sync(
-    probs_sort,
-):  # Does multinomial sampling without a cuda synchronization
-    q = torch.empty_like(probs_sort).exponential_(1)
-    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
-
-
-def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    logits = logits / max(temperature, 1e-5)
-
-    if top_k is not None:
-        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
-        logits = torch.where(logits < pivot, -float("Inf"), logits)
-    probs = torch.nn.functional.softmax(logits, dim=-1)
-    return probs
-
-
-def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    probs = logits_to_probs(logits[0, -1], temperature, top_k)
-    idx_next = multinomial_sample_one_no_sync(probs)
-    return idx_next, probs
-
-
-def prefill(
-    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
-) -> torch.Tensor:
-    # input_pos: [B, S]
-    causal_mask = (
-        torch.tril(torch.ones(len(input_pos), len(input_pos), dtype=torch.bool))
-        .unsqueeze(0)
-        .unsqueeze(0)
-        .to(x.device)
-    )
-    logits = model(x, input_pos, mask=causal_mask)
-    return sample(logits, **sampling_kwargs)[0]
-
-
-def decode_one_token(
-    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # input_pos: [B, 1]
-    assert input_pos.shape[-1] == 1
-    logits = model(x, input_pos)
-    return sample(logits, **sampling_kwargs)
-
-
-def decode_n_tokens(
-    model: Transformer,
-    cur_token: torch.Tensor,
-    input_pos: torch.Tensor,
-    num_new_tokens: int,
-    terminator_ids: Optional[list] = None,
-    callback=lambda _: _,
-    **sampling_kwargs,
-):
-    new_tokens, new_probs = [], []
-    for i in range(num_new_tokens):
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=False, enable_mem_efficient=False, enable_math=True
-        ):  # Actually better for Inductor to codegen attention here
-            next_token, next_prob = decode_one_token(
-                model, cur_token, input_pos, **sampling_kwargs
-            )
-
-            if terminator_ids and next_token in terminator_ids:
-                break
-
-            input_pos += 1
-            new_tokens.append(next_token.clone())
-            callback(new_tokens[-1])
-            new_probs.append(next_prob.clone())
-            cur_token = next_token.view(1, -1)
-
-    return new_tokens, new_probs
-
-
-def model_forward(model, x, input_pos):
-    return model(x, input_pos)
-
-
-def speculative_decode(
-    model: Transformer,
-    draft_model: Transformer,
-    cur_token: torch.Tensor,
-    input_pos: int,
-    speculate_k: int,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    # draft model inference sequentially
-    device = cur_token.device
-    orig_input_pos = torch.tensor(
-        [input_pos], dtype=torch.int64, device=cur_token.device
-    )
-    draft_tokens, draft_probs = decode_n_tokens(
-        draft_model,
-        cur_token.view(1, -1),
-        orig_input_pos.clone(),
-        speculate_k,
-        **sampling_kwargs,
-    )
-
-    draft_tokens = torch.cat(draft_tokens)
-    # parallel inference on target model using draft tokens
-    target_logits = model_forward(
-        model,
-        torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
-        torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device),
-    )
-    target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
-    draft_probs = torch.stack(draft_probs)
-    # q: target prob, p: draft prob
-    # q >= p: always accept draft token
-    # q < p: q/p prob to accept draft token
-    p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
-    q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
-    accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k] / p)
-    rejected_locations = (
-        torch.rand_like(accept_draft_prob) > accept_draft_prob
-    ).nonzero()
-
-    if rejected_locations.shape[0] == 0:  # All draft tokens have been accepted
-        accept_length = speculate_k + 1
-        last_token = multinomial_sample_one_no_sync(target_probs[-1])
-        # fill last token into draft model
-        model_forward(
-            draft_model,
-            draft_tokens[-1].view(1, -1),
-            orig_input_pos + speculate_k,
-        )
-        return torch.cat([draft_tokens, last_token])
-    else:
-        accept_length = rejected_locations[0].item()
-        p = draft_probs[accept_length]
-        q = target_probs[accept_length]
-        new = q - p
-        new = torch.where(new > 0, new, 0.0)
-        new = new / new.sum()
-        next_token = multinomial_sample_one_no_sync(new)
-        return torch.cat([draft_tokens[:accept_length], next_token])
-
-
-def normalize_cache_length(
-    max_cache_length: float, max_seq_length: int, multiple_of: int = 8
-) -> int:
-    """
-    Computes the absolute cache length given the max_cache_length and max_seq_length.
-    """
-    if 0 < max_cache_length <= 1:
-        max_cache_length = round(max_seq_length * max_cache_length)
-    else:
-        assert int(max_cache_length) == max_cache_length
-        max_cache_length = int(max_cache_length)
-        if max_cache_length > max_seq_length:
-            print(
-                f"Warning: max_cache_length ({max_cache_length}) is greater than max_seq_length ({max_seq_length}). Setting to {max_seq_length}"
-            )
-            max_cache_length = max_seq_length
-    return min(find_multiple(max_cache_length, multiple_of), max_seq_length)
-
-
-@torch.no_grad()
-def generate(
-    model: Transformer,
-    prompt: torch.Tensor,
-    max_new_tokens: int,
-    *,
-    interactive: bool,
-    draft_model: Transformer,
-    speculate_k: Optional[int] = 8,
-    callback=lambda x: x,
-    terminator_ids: Optional[list] = None,
-    cache_kwargs: dict = None,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    """
-    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
-    """
-
-    is_speculative = draft_model is not None
-    # create an empty tensor of the expected final shape and fill in the current tokens
-    T = prompt.size(0)
-    max_seq_length = min(T + max_new_tokens, model.config.block_size)
-    if interactive:
-        max_seq_length = 350
-    print(f"Maximum context length of {max_seq_length} tokens.")
-
-    max_new_tokens = max_seq_length - T
-
-    device, dtype = prompt.device, prompt.dtype
-    max_seq_length = (
-        max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
-    )
-
-    # Normalize max_cache_length to absolute cache length if provided as a fraction of the max seq sequence length
-    cache_kwargs["max_cache_length"] = list(
-        map(
-            lambda l: normalize_cache_length(l, max_seq_length),
-            cache_kwargs["max_cache_length"],
-        )
-    )
-    assert (
-        model.config.n_layer % len(cache_kwargs["max_cache_length"]) == 0
-    ), f'max_cache_length ({len(cache_kwargs["max_cache_length"])}) must be a factor of {model.config.n_layer} layers.'
-
-    tile_size = model.config.n_layer // len(cache_kwargs["max_cache_length"])
-    cache_kwargs["max_cache_length"] = [
-        item for item in cache_kwargs["max_cache_length"] for _ in range(tile_size)
-    ]
-
-    # Gets called twice when model is wrapped in torch.compile which causes an error without the if statement
-    if type(cache_kwargs["drop_amount"]) != list:
-        cache_kwargs["drop_amount"] = [
-            max(int(cache_kwargs["drop_amount"] * l), 1)
-            for l in cache_kwargs["max_cache_length"]
-        ]
-
-    assert cache_kwargs["global_tokens"] <= min(
-        cache_kwargs["max_cache_length"]
-    ), "Global tokens must be less than max_cache_length."
-
-    with torch.device(device):
-        model.setup_caches(max_batch_size=1, **cache_kwargs)
-        if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, **cache_kwargs)
-
-    # create an empty tensor (all -1) of the expected final shape and fill in the current tokens
-    # GPT-Fast had this as empty but the values of empty are non-deterministic
-    seq = torch.full((max_seq_length,), -1, dtype=dtype, device=device)
-    seq[:T] = prompt
-    input_pos = torch.arange(0, T, device=device)
-
-    next_token = prefill(
-        model, prompt.view(1, -1), input_pos, **sampling_kwargs
-    ).clone()
-    if is_speculative:
-        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
-    seq[T] = next_token
-
-    input_pos = torch.tensor([T], device=device, dtype=torch.int)
-    accept_counts = [0] * (speculate_k + 1)
-
-    if is_speculative:
-        input_pos = input_pos.item()  # for speculative decoding easier to keep on host
-        while input_pos < max_seq_length - 1:
-            cur_token = next_token.view(())
-
-            next_tokens = speculative_decode(
-                model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
-            )
-
-            accept_counts[len(next_tokens) - 1] += 1
-            num_added = min(max_seq_length - input_pos - 1, len(next_tokens))
-            seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[:num_added]
-            for i in next_tokens[:num_added,]:
-                callback(i)
-            input_pos = input_pos + num_added
-            next_token = next_tokens[-1]
-    else:
-        generated_tokens, _ = decode_n_tokens(
-            model,
-            next_token.view(1, -1),
-            input_pos,
-            max_new_tokens - 1,
-            callback=callback,
-            terminator_ids=terminator_ids,
-            **sampling_kwargs,
-        )
-        if len(generated_tokens) > 0:
-            seq[T + 1 : T + 1 + len(generated_tokens)] = torch.cat(generated_tokens)
-
-    # Truncate seq to first instance of -1 if -1 is present
-    if -1 in seq:
-        seq = seq[: torch.where(seq == -1)[0][0]]
-
-    generate_stats = {"accept_counts": accept_counts}
-    return seq, generate_stats
-
-
-def encode_tokens(tokenizer, string, bos=True, device=default_device):
-    tokens = tokenizer.encode(string)
-    if bos:
-        tokens = [tokenizer.bos_id()] + tokens
-    return torch.tensor(tokens, dtype=torch.int, device=device)
-
-
-def _load_model(checkpoint_path, device, precision, use_tp):
-    use_cuda = "cuda" in device
-    with torch.device("meta"):
-        model = Transformer.from_name(checkpoint_path.parent.name)
-
-    if "int8" in str(checkpoint_path):
-        print("Using int8 weight-only quantization!")
-        from quantize import WeightOnlyInt8QuantHandler
-
-        simple_quantizer = WeightOnlyInt8QuantHandler(model)
-        model = simple_quantizer.convert_for_runtime()
-
-    if "int4" in str(checkpoint_path):
-        print("Using int4 weight-only quantization!")
-        path_comps = checkpoint_path.name.split(".")
-        groupsize = int(path_comps[-2][1:])
-        from quantize import WeightOnlyInt4QuantHandler
-
-        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
-        model = simple_quantizer.convert_for_runtime()
-
-    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-    if "model" in checkpoint and "stories" in str(checkpoint_path):
-        checkpoint = checkpoint["model"]
-    model.load_state_dict(checkpoint, assign=True)
-    if use_tp:
-        from tp import apply_tp
-
-        print("Applying tensor parallel to model ...")
-        apply_tp(model)
-
-    model = model.to(device=device, dtype=precision)
-    return model.eval()
+from generation_utils import generate, encode_tokens, _load_model
 
 
 def _get_model_size(model):
@@ -513,7 +192,7 @@ def callback(x):
             torch.profiler._utils._init_for_cuda_graphs()
             prof = torch.profiler.profile()
         with prof:
-            y, metrics = generate(
+            y, metrics, _ = generate(
                 model,
                 encoded,
                 max_new_tokens,

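generation_utils.py itself is one of the 6 changed files but does not appear in this excerpt; the call sites above now unpack y, metrics, probs, so the refactored generate() presumably returns a probability tensor as a third value (eval.py indexes it as probs[0]). The removed prefill() above discards the distribution produced by sample() via [0], so one plausible shape for the refactor, shown purely as an assumption, is to keep and return that first-step distribution:

# Assumption-only sketch of how the moved code might surface probabilities;
# the real generation_utils.py is not shown in this diff.
from typing import Optional, Tuple

import torch


def logits_to_probs(
    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None
) -> torch.Tensor:
    # Same computation as the logits_to_probs removed from generate.py above.
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    return torch.nn.functional.softmax(logits, dim=-1)


def prefill_with_probs(
    model, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Unlike the removed prefill(), which dropped the distribution with [0],
    # keep it so callers can build the {token: prob} map used in eval.py.
    logits = model(x, input_pos)
    probs = logits_to_probs(logits[0, -1], **sampling_kwargs)
    next_token = torch.argmax(probs, dim=-1, keepdim=True)  # greedy pick for brevity
    return next_token, probs.unsqueeze(0)  # caller reads probs[0]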