
Commit ac9aa36

generate
1 parent 384ae56 commit ac9aa36

4 files changed: +17 lines, -4 lines

data/eassy.txt

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

examples/generation.py

Lines changed: 2 additions & 1 deletion
@@ -11,6 +11,7 @@
 parser.add_argument('--M', type=int, default=8192, help='max length')
 parser.add_argument('--D', type=int, default=1, help='dec length')
 parser.add_argument('--G', type=int, default=256, help='generation length')
+parser.add_argument('--t', type=float, default=0.6, help='temperature')
 parser.add_argument('--K', type=int, default=10, help='K')
 parser.add_argument('--L', type=int, default=150, help='K')
 parser.add_argument('--data', type=str, default="../data/story.txt", help='source data file')
@@ -35,7 +36,7 @@
 input_ids = input_ids.to(DEVICE)
 PREFIX_LEN = input_ids.shape[1]
 position_ids = torch.arange(MAX_LEN, device=DEVICE).unsqueeze(0)
-generated = llm.generate(input_ids, max_tokens=args.G)
+generated = llm.generate(input_ids, max_tokens=args.G, verbose=True, temperature=args.t)
 text = tokenizer.decode(generated, skip_special_tokens=True)
 print("\033[32m" + text + "\033[0m")
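
The updated call, annotated (a readability sketch only; the argument meanings are inferred from the diff and the argparse help strings above, and the comments are not part of the commit):

generated = llm.generate(
    input_ids,            # tokenized prompt already moved to DEVICE
    max_tokens=args.G,    # --G: number of tokens to generate (default 256)
    temperature=args.t,   # --t: sampling temperature (default 0.6)
    verbose=True,         # print prefill size, generated token count, and ms/token latency
)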

models/llama.py

Lines changed: 12 additions & 2 deletions
@@ -6,6 +6,7 @@
 from .utils import apply_rotary_pos_emb, layer_norm, topp_temperature_decode
 import flashinfer
 from .attnserver import LSHSparseAttnServer, AttnServer
+import time
 class LLMLayer:
     def __init__(self, layer_idx) -> None:
 
@@ -328,13 +329,16 @@ def generate(self,
                  input_ids: torch.LongTensor,
                  max_tokens: int = 128,
                  temperature: float = 0.6,
-                 topp: float = 0.9):
+                 topp: float = 0.9,
+                 verbose: bool = False):
 
         generated = []
         prefix_len = input_ids.shape[1]
         position_ids = torch.arange(prefix_len + max_tokens, device=self.device).unsqueeze(0)
         logits = self.prefill(input_ids=input_ids)
-
+        torch.cuda.synchronize()
+        if verbose:
+            t1 = time.time()
         for k in range(max_tokens):
             if temperature < 0.1:
                 input_ids = logits.argmax(dim=-1)
@@ -344,6 +348,12 @@ def generate(self,
             generated.append(input_ids[0].item())
             if input_ids[0].item() in self.eos_tokens:
                 break
+        if verbose:
+            torch.cuda.synchronize()
+            t2 = time.time()
+            print("\033[94m[INFO] Prefill {} tokens\033[0m".format(prefix_len))
+            print("\033[94m[INFO] Generate {} tokens\033[0m".format(len(generated)))
+            print("\033[94m[INFO] Decoding Latency {:.2f} ms/token\033[0m".format(1000 * (t2 - t1)/len(generated)))
         self.attention_server.clear()
         self.k_cache.zero_()
         self.v_cache.zero_()
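
The commit times the decode loop with a synchronize/timestamp pattern. A minimal, self-contained sketch of that pattern (using a stand-in nn.Linear rather than the repo's LLM class; only the measurement structure matches the commit):

import time
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4096, 4096).to(device)
x = torch.randn(1, 4096, device=device)

steps = 32
if device == "cuda":
    torch.cuda.synchronize()   # drain queued kernels so t1 marks the true start
t1 = time.time()
for _ in range(steps):         # stands in for the per-token decode loop
    x = model(x)
if device == "cuda":
    torch.cuda.synchronize()   # GPU work is asynchronous; wait before reading the clock
t2 = time.time()
print("{:.2f} ms/step".format(1000 * (t2 - t1) / steps))

Without the second synchronize the loop mostly measures kernel launch time, which is why generate() calls torch.cuda.synchronize() before both timestamps.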

models/template.py

Lines changed: 2 additions & 1 deletion
@@ -14,5 +14,6 @@
 
 Templates = {
     'meta-llama2': "[INST] {} [/INST]",
-    'meta-llama3': "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n",
+    'meta-llama3': "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n",
+    'None': "{}",
 }
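
Hypothetical usage of the Templates dict (the import path and prompt text are assumptions; the format strings themselves come from the diff):

from models.template import Templates

prompt = "Summarize the story in one paragraph."
chat_input = Templates['meta-llama3'].format(prompt)  # wrap the prompt in Llama-3 chat headers
raw_input = Templates['None'].format(prompt)          # new 'None' entry: pass the prompt through unchanged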
