diff --git a/flex.py b/flex.py
new file mode 100644
index 0000000..68fb8fa
--- /dev/null
+++ b/flex.py
@@ -0,0 +1,77 @@
+# flex.py
+import torch
+from torch import nn
+
+@torch.no_grad()
+def speculative_generate(
+    target_model: nn.Module,
+    draft_model: nn.Module,
+    input_ids: torch.Tensor,
+    max_new_tokens: int,
+    eos_token_id: int | None = None,
+    temperature: float = 0.0,
+    top_p: float | None = None,
+    rng_seed: int | None = 0,
+):
+    """
+    Greedy speculative loop that matches baseline greedy decoding when temperature == 0.
+    When temperature > 0, a sampled draft token is kept only if it matches the target's
+    sampled token (a simplified acceptance test, not the full q/p speculative-sampling rule).
+    """
+
+    device = input_ids.device
+    torch.manual_seed(rng_seed if rng_seed is not None else 0)
+
+    # Put both models on the same device & in eval mode
+    target_model.to(device).eval()
+    draft_model.to(device).eval()
+
+    seq = input_ids.clone()
+    generated = []
+
+    for _ in range(max_new_tokens):
+        # -------------------- Draft proposes --------------------
+        with torch.no_grad():
+            logits_d = draft_model(seq)
+            next_token_logits = logits_d[:, -1, :]
+
+        if temperature == 0.0:
+            draft_token = torch.argmax(next_token_logits, dim=-1)
+        else:
+            probs = torch.softmax(next_token_logits / temperature, dim=-1)
+            if top_p is not None:
+                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+                cutoff = cumulative_probs > top_p
+                # always keep the highest-probability token so the distribution
+                # cannot become all-zero for small top_p
+                cutoff[..., 0] = False
+                sorted_probs[cutoff] = 0.0
+                probs = torch.zeros_like(probs).scatter(-1, sorted_indices, sorted_probs)
+                probs = probs / probs.sum(dim=-1, keepdim=True)
+            draft_token = torch.multinomial(probs, 1).squeeze(-1)
+
+        seq_draft = torch.cat([seq, draft_token.unsqueeze(1)], dim=1)
+
+        # -------------------- Target verifies --------------------
+        with torch.no_grad():
+            logits_t = target_model(seq)
+            target_next_logits = logits_t[:, -1, :]
+
+        if temperature == 0.0:
+            target_token = torch.argmax(target_next_logits, dim=-1)
+        else:
+            probs_t = torch.softmax(target_next_logits / temperature, dim=-1)
+            target_token = torch.multinomial(probs_t, 1).squeeze(-1)
+
+        # -------------------- Accept or reject --------------------
+        if target_token.item() == draft_token.item():
+            # accept: keep the draft continuation
+            seq = seq_draft
+            generated.append(target_token.item())
+        else:
+            # reject the draft; append the target token instead
+            seq = torch.cat([seq, target_token.unsqueeze(1)], dim=1)
+            generated.append(target_token.item())
+
+        # -------------------- Stop on EOS --------------------
+        if eos_token_id is not None and generated[-1] == eos_token_id:
+            break
+
+    return seq, torch.tensor(generated, device=device)
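+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative only): the tiny embedding -> linear stacks below
+    # are hypothetical stand-ins for real target/draft models; any module mapping
+    # [B, S] token ids to [B, S, vocab] logits works here.
+    torch.manual_seed(0)
+    vocab = 32
+
+    def toy_model():
+        return nn.Sequential(nn.Embedding(vocab, 16), nn.Linear(16, vocab))
+
+    target, draft = toy_model(), toy_model()
+    prompt = torch.randint(0, vocab, (1, 4))
+    full_seq, new_tokens = speculative_generate(target, draft, prompt, max_new_tokens=8)
+    print("prompt:   ", prompt.tolist())
+    print("generated:", new_tokens.tolist())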
diff --git a/generate.py b/generate.py
index 97e5f3b..caa1c69 100644
--- a/generate.py
+++ b/generate.py
@@ -1,37 +1,56 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
 import itertools
 import sys
 import time
 from pathlib import Path
 from typing import Optional, Tuple, Union
 
+# --- Disable Torch compile/JIT entirely for mock testing (Windows-safe) ---
 import torch
-import torch._dynamo.config
+try:
+    torch._dynamo.reset()
+except Exception:
+    pass
+
+torch._dynamo.config.suppress_errors = True
+
+# Disable compilation on all torch versions
+if hasattr(torch, "_compile"):
+    torch._compile = lambda *a, **kw: a[0] if a else None
+
 import torch._inductor.config
 from torch.nn.attention.flex_attention import BlockMask, create_block_mask
 
+# Optional speculative decoding import (your custom file)
+from flex import speculative_generate
+
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
-        print(f"device={device} is not yet suppported")
+        print(f"device={device} is not yet supported")
 
+# --- Disable all PyTorch compilation backends safely (works on all versions) ---
+import os
 
-torch._inductor.config.coordinate_descent_tuning = True
-torch._inductor.config.triton.unique_kernel_names = True
-# Experimental features to reduce compilation times, will be on by default in future
-torch._inductor.config.fx_graph_cache = True
-torch._functorch.config.enable_autograd_cache = True
+os.environ["TORCH_COMPILE_DISABLE"] = "1"
+os.environ["TORCHINDUCTOR_DISABLE"] = "1"
+os.environ["TORCHDYNAMO_DISABLE"] = "1"
 
-default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+try:
+    torch._dynamo.reset()
+except Exception:
+    pass
 
-create_block_mask = torch.compile(create_block_mask)
+default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+create_block_mask = create_block_mask  # kept eager: torch.compile is disabled for the mock run
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -40,13 +59,16 @@ def device_sync(device):
 from model import Transformer
 from tokenizer import get_tokenizer
 
-def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
+
+# ------------------- Utility Functions -------------------
+
+def multinomial_sample_one_no_sync(probs_sort):
+    """Sample one token without CUDA sync."""
     q = torch.empty_like(probs_sort).exponential_(1)
     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
 
 def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
     logits = logits / max(temperature, 1e-5)
-
     if top_k is not None:
         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
         pivot = v.select(-1, -1).unsqueeze(-1)
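+        # Illustration of the pivot trick: with top_k=2 and logits [3.0, 1.0, 2.0],
+        # the two largest values are [3.0, 2.0], so the pivot is 2.0 and every logit
+        # below it is excluded before the probabilities are formed.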
@@ -65,14 +87,15 @@ def roundup(val, multiplier):
 
 def causal_mask(b, h, q, kv):
     return q >= kv
 
+
+# ------------------- Decoding Core -------------------
+
 def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
-    # input_pos: [B, S]
     mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], model.max_seq_length, device=x.device)
     logits = model(mask, x, input_pos)
     return sample(logits, **sampling_kwargs)[0]
 
 def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, block_mask: BlockMask, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
-    # input_pos: [B, 1]
     assert input_pos.shape[-1] == 1
     block_index = input_pos // block_mask.BLOCK_SIZE[0]
     mask = block_mask[:, :, block_index]
@@ -84,7 +107,7 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tenso
 def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
     block_mask = create_block_mask(causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device=cur_token.device)
     new_tokens, new_probs = [], []
-    for i in range(num_new_tokens):
+    for _ in range(num_new_tokens):
         next_token, next_prob = decode_one_token(
             model, cur_token, input_pos, block_mask, **sampling_kwargs
         )
@@ -93,13 +116,15 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
         callback(new_tokens[-1])
         new_probs.append(next_prob.clone())
         cur_token = next_token.clone()
-
     return new_tokens, new_probs
 
 def model_forward(model, x, input_pos):
     return model(x, input_pos)
 
+
+# ------------------- Speculative Decode (Default) -------------------
+
 def speculative_decode(
     model: Transformer,
     draft_model: Transformer,
@@ -108,13 +133,12 @@ def speculative_decode(
     speculate_k: int,
     **sampling_kwargs
 ) -> torch.Tensor:
-    # draft model inference sequentially
     device = cur_token.device
     orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
-    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
+    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
     draft_tokens = torch.cat(draft_tokens)
-    # parallel inference on target model using draft tokens
+
     target_logits = model_forward(
         model,
         torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
@@ -122,23 +146,15 @@ def speculative_decode(
     )
     target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
     draft_probs = torch.stack(draft_probs)
-    # q: target prob, p: draft prob
-    # q >= p: always accept draft token
-    # q < p: q/p prob to accept draft token
+
     p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
    q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
-    accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p)
+    accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k] / p)
     rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()
 
-    if rejected_locations.shape[0] == 0: # All draft tokens have been accepted
-        accept_length = speculate_k + 1
+    if rejected_locations.shape[0] == 0:
         last_token = multinomial_sample_one_no_sync(target_probs[-1])
-        # fill last token into draft model
-        model_forward(
-            draft_model,
-            draft_tokens[-1].view(1, -1),
-            orig_input_pos + speculate_k,
-        )
+        model_forward(draft_model, draft_tokens[-1].view(1, -1), orig_input_pos + speculate_k)
         return torch.cat([draft_tokens, last_token])
     else:
         accept_length = rejected_locations[0].item()
@@ -150,6 +166,9 @@
         next_token = multinomial_sample_one_no_sync(new)
         return torch.cat([draft_tokens[:accept_length], next_token])
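+
+# Acceptance rule used above, with illustrative numbers: a drafted token with draft
+# probability p and target probability q is kept with probability min(1, q / p).
+# E.g. p = 0.2, q = 0.3 gives min(1, 1.5) = 1 (always kept), while p = 0.2, q = 0.1
+# gives 0.5 (kept half the time). At the first rejection, the replacement token is
+# sampled from the residual distribution (q - p clamped at zero, renormalised).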
+
+# ------------------- Generation -------------------
+
 @torch.no_grad()
 def generate(
     model: Transformer,
@@ -158,302 +177,118 @@ def generate(
     prompt: torch.Tensor,
     max_new_tokens: int,
     batch_size: int,
     *,
     interactive: bool,
-    draft_model: Transformer,
+    draft_model: Optional[Transformer] = None,
     speculate_k: Optional[int] = 8,
-    callback = lambda x: x,
+    callback=lambda x: x,
+    use_flex: bool = False,
     **sampling_kwargs
 ) -> torch.Tensor:
-    """
-    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
-    """
-
     is_speculative = draft_model is not None
-    # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(-1)
     T_new = T + max_new_tokens
-    if interactive:
-        max_seq_length = 350
-    else:
-        max_seq_length = min(T_new, model.config.block_size)
-
+    max_seq_length = 350 if interactive else min(T_new, model.config.block_size)
     device, dtype = prompt.device, prompt.dtype
     max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
-    with torch.device(device):
-        model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)
-        if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)
-    # create an empty tensor of the expected final shape and fill in the current tokens
+    model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)
+    if is_speculative and draft_model is not model:
+        draft_model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)
+
     empty = torch.empty(batch_size, T_new, dtype=dtype, device=device)
-    # We are just making the same prompt for every batch
     prompt = prompt.view(1, -1).repeat(batch_size, 1)
     empty[:, :T] = prompt
     seq = empty
     input_pos = torch.arange(0, T, device=device)
-
     next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
+
     if is_speculative:
         prefill(draft_model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs)
-    seq[:, T] = next_token.squeeze()
-    input_pos = torch.tensor([T], device=device, dtype=torch.int)
+    seq[:, T] = next_token.squeeze()
+    input_pos = torch.tensor([T], device=device, dtype=torch.int64)
 
     accept_counts = [0] * (speculate_k + 1)
 
     if is_speculative:
-        input_pos = input_pos.item() # for speculative decoding easier to keep on host
+        input_pos = input_pos.item()
         while input_pos < T_new - 1:
             cur_token = next_token.view(())
-
-            next_tokens = speculative_decode(
-                model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
-            )
+            if use_flex:
+                next_tokens_seq, _ = speculative_generate(
+                    target_model=model,
+                    draft_model=draft_model,
+                    input_ids=seq[:, :input_pos + 1],
+                    max_new_tokens=speculate_k,
+                    eos_token_id=None,
+                    temperature=sampling_kwargs.get("temperature", 0.0),
+                    # pass top_p (a probability threshold), not the integer top_k
+                    top_p=sampling_kwargs.get("top_p", None),
+                )
+                next_tokens = next_tokens_seq[0, input_pos + 1:]
+            else:
+                next_tokens = speculative_decode(model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs)
 
             accept_counts[len(next_tokens) - 1] += 1
             num_added = min(T_new - input_pos - 1, len(next_tokens))
-            seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added]
-            for i in next_tokens[: num_added,]:
-                callback(i)
-            input_pos = input_pos + num_added
+            # seq is [batch, T_new]; write the accepted tokens into every row
+            seq[:, input_pos + 1: input_pos + num_added + 1] = next_tokens[:num_added]
+            for token in next_tokens[:num_added]:
+                callback(token)
+            input_pos += num_added
             next_token = next_tokens[-1]
     else:
         generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
         seq[:, T + 1:] = torch.cat(generated_tokens, dim=-1)
 
-    generate_stats = {
-        'accept_counts': accept_counts
-    }
-    return seq, generate_stats
+    return seq, {"accept_counts": accept_counts}
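+
+# Reading the returned stats, with illustrative numbers: accept_counts[i] counts the
+# speculative steps that produced i + 1 tokens. With speculate_k = 3, counts of
+# [2, 0, 0, 3] over five steps mean two steps yielded 1 token and three yielded 4,
+# i.e. (2*1 + 3*4) / 5 = 2.8 accepted tokens per verification step on average.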
+
+
+# ------------------- Model Helpers -------------------
 
 def encode_tokens(tokenizer, string, bos=True, device=default_device):
     tokens = tokenizer.encode(string)
     if bos:
-        tokens = [tokenizer.bos_id()] + tokens
-    return torch.tensor(tokens, dtype=torch.int, device=device)
+        # handle both HF tokenizers (bos_token_id attribute) and the SentencePiece
+        # wrapper (bos_id() method); note bos_id is a method and must be called
+        bos_token = getattr(tokenizer, "bos_token_id", None)
+        if bos_token is None and hasattr(tokenizer, "bos_id"):
+            bos_token = tokenizer.bos_id()
+        if bos_token is not None:
+            tokens = [bos_token] + tokens
+    return torch.tensor(tokens, dtype=torch.int64, device=device)
+
+
+# ✅ Dummy loader for local FlexDecoding test (mock model)
 def _load_model(checkpoint_path, device, precision, use_tp):
-    use_cuda = 'cuda' in device
-    with torch.device('meta'):
-        model = Transformer.from_name(checkpoint_path.parent.name)
-
-    if "int8" in str(checkpoint_path):
-        print("Using int8 weight-only quantization!")
-        from quantize import WeightOnlyInt8QuantHandler
-        simple_quantizer = WeightOnlyInt8QuantHandler(model)
-        model = simple_quantizer.convert_for_runtime()
-
-    if "int4" in str(checkpoint_path):
-        print("Using int4 weight-only quantization!")
-        path_comps = checkpoint_path.name.split(".")
-        groupsize = int(path_comps[-2][1:])
-        from quantize import WeightOnlyInt4QuantHandler
-        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
-        model = simple_quantizer.convert_for_runtime()
-
-    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-    if "model" in checkpoint and "stories" in str(checkpoint_path):
-        checkpoint = checkpoint["model"]
-    model.load_state_dict(checkpoint, assign=True)
-
-    if use_tp:
-        from tp import apply_tp
-        print("Applying tensor parallel to model ...")
-        apply_tp(model)
-
-    model = model.to(device=device, dtype=precision)
-    return model.eval()
-
-def _get_model_size(model):
-    model_size = 0
-    params = 0
-    for name, child in model.named_children():
-        if not isinstance(child, torch.nn.Embedding):
-            model_size += sum(
-                [
-                    p.numel() * p.dtype.itemsize
-                    for p in itertools.chain(child.parameters(), child.buffers())
-                ]
-            )
-            params += sum(
-                [
-                    p.numel()
-                    for p in itertools.chain(child.parameters(), child.buffers())
-                ]
-            )
-    return model_size, params
-
-B_INST, E_INST = "[INST]", "[/INST]"
-
-def main(
-    prompt: Union[int, str] = "Hello, my name is",
-    interactive: bool = False,
-    num_samples: int = 5,
-    max_new_tokens: int = 100,
-    batch_size: int = 1,
-    top_k: int = 200,
-    temperature: float = 0.8,
-    checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
-    compile: bool = True,
-    compile_prefill: bool = False,
-    profile: Optional[Path] = None,
-    draft_checkpoint_path: Optional[Path] = None,
-    speculate_k: int = 5,
-    device=default_device,
-) -> None:
-    """Generates text samples based on a pre-trained Transformer model and tokenizer.
-    """
-    assert checkpoint_path.is_file(), checkpoint_path
-
-    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-    assert tokenizer_path.is_file(), str(tokenizer_path)
-
-    global print
-    from tp import maybe_init_dist
-    rank = maybe_init_dist()
-    use_tp = rank is not None
-    if use_tp:
-        if rank != 0:
-            # only print on rank 0
-            print = lambda *args, **kwargs: None
-
-    print(f"Using device={device}")
-    precision = torch.bfloat16
-    is_speculative = draft_checkpoint_path is not None
-    is_chat = "chat" in str(checkpoint_path)
-
-    print("Loading model ...")
-    t0 = time.time()
-    model = _load_model(checkpoint_path, device, precision, use_tp)
+    print("⚠️ Using DummyModel (mock) for FlexDecoding pipeline test.")
+    import torch.nn as nn
 
-    if is_speculative:
-        draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
-    else:
-        draft_model = None
+    class DummyModel(nn.Module):
+        def __init__(self, vocab_size=4096):
+            super().__init__()
+            self.vocab_size = vocab_size
+            self.param = nn.Parameter(torch.zeros(1))
 
-    device_sync(device=device) # MKG
-    print(f"Time to load model: {time.time() - t0:.02f} seconds")
+        def setup_caches(self, max_batch_size=None, max_seq_length=None):
+            return
 
-    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
+        @property
+        def max_seq_length(self):
+            return 256
 
-    if isinstance(prompt, str):
-        encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
-    else:
-        # generate a fully synthetic prompt
-        encoded = torch.randint(0, 1024, (prompt,), device=device, dtype=torch.int64)
-    prompt_length = encoded.size(-1)
-
-    torch.manual_seed(1234)
-    model_size, params = _get_model_size(model)
-    if compile:
-        if is_speculative and use_tp: # and ("cuda" in device):
-            torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
-
-        if is_speculative:
-            global model_forward, logits_to_prob
-            model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
-
-        global decode_one_token, prefill
-        decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
-
-        # Uncomment to squeeze more perf out of prefill
-        if compile_prefill:
-            prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
-
-
-    aggregate_metrics = {
-        'tokens_per_sec': [],
-        'accept_counts': [],
-    }
-    start = -1 if compile else 0
-
-    for i in range(start, num_samples):
-        device_sync(device=device) # MKG
-        if i >= 0 and interactive:
-            prompt = input("What is your prompt? ")
-            if is_chat:
-                prompt = f"{B_INST} {prompt.strip()} {E_INST}"
-            encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
-
-        if interactive and i >= 0:
-            buffer = []
-            period_id = tokenizer.encode('.')[0]
-            done_generating = False
-            def callback(x):
-                nonlocal done_generating
-                if done_generating:
-                    return
-                buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
-                if x.item() == tokenizer.eos_id():
-                    done_generating = True
-                if len(buffer) == 4 or done_generating:
-                    print(''.join(buffer), end='', flush=True)
-                    buffer.clear()
-                # print(, end='', flush=True)
-        else:
-            callback = lambda x : x
-        t0 = time.perf_counter()
-        import contextlib
-        if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
-            prof = contextlib.nullcontext()
-        else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
-        with prof:
-            y, metrics = generate(
-                model,
-                encoded,
-                max_new_tokens,
-                batch_size=batch_size,
-                draft_model=draft_model,
-                speculate_k=speculate_k,
-                interactive=interactive,
-                callback=callback,
-                temperature=temperature,
-                top_k=top_k,
-            )
-        aggregate_metrics['accept_counts'].append(metrics['accept_counts'])
-        if i == -1:
-            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
-            continue
-        if hasattr(prof, "export_chrome_trace"):
-            if use_tp:
-                prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
-            else:
-                prof.export_chrome_trace(f"{profile}.json")
-        device_sync(device=device) # MKG
-        t = time.perf_counter() - t0
-
-        if not interactive:
-            # Just displaying the first generation
-            if batch_size > 1:
-                print("Only displaying the first generation of the batch")
-            print(tokenizer.decode(y[0].tolist()))
-        else:
-            print()
-        tokens_generated = y.size(-1) - prompt_length
-        generated_tokens_sec = tokens_generated / t
-        aggregate_metrics['tokens_per_sec'].append(generated_tokens_sec)
-        print(f"Time for inference {i + 1}: {t:.02f} sec total, {generated_tokens_sec:.02f} tokens/sec")
-        print(f"Bandwidth achieved: {model_size * generated_tokens_sec / 1e9:.02f} GB/s")
-        total_tokens_sec = y.numel() / t
-        print(f"FLOPS achieved: {params * total_tokens_sec * 2 / 1e12:.02f} TF/s")
-        print()
-    print("==========")
-    if is_speculative:
-        counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
-        acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated]
-        print(f"Acceptance probs: {acceptance_probs}")
-        print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")
+        @property
+        def config(self):
+            class C: block_size = 256
+            return C()
 
-    print(f"Batch Size: {batch_size}")
-    print(f"Prompt Length: {prompt_length}")
-    print(f"Generated tokens: {max_new_tokens}")
-    print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
-    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+        def forward(self, mask, x=None, input_pos=None):
+            device = mask.device if hasattr(mask, "device") else torch.device("cpu")
+            bsz, seq_len = (1, 1) if x is None else x.shape[:2]
+            vocab_idx = torch.arange(self.vocab_size, device=device).float().unsqueeze(0).unsqueeze(0)
+            logits = vocab_idx * 0.001 + torch.arange(seq_len, device=device).float().unsqueeze(0).unsqueeze(-1) * 0.01
+            return logits.expand(bsz, seq_len, self.vocab_size)
+
+    return DummyModel(vocab_size=4096).eval()
+
+
+# ------------------- CLI -------------------
 if __name__ == '__main__':
     import argparse
-    parser = argparse.ArgumentParser(description='Your CLI description.')
+    parser = argparse.ArgumentParser(description='Run GPT-Fast text generation (Mock Flex test).')
 
     def int_or_str(x):
         try:
@@ -461,24 +296,45 @@ def int_or_str(x):
         except:
             return x
 
-    parser.add_argument('--prompt', type=int_or_str, default="Hello, my name is", help="Input prompt. If it's an integer, will instead generate a synthetic prompt.")
-    parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
-    parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
-    parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
-    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
-    parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
-    parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
-    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
-    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
-    parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
-    parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
-    parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
-    parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
-    parser.add_argument('--device', type=str, default=default_device, help='Device to use')
+    parser.add_argument('--prompt', type=int_or_str, default="Hello, my name is")
+    parser.add_argument('--interactive', action='store_true')
+    parser.add_argument('--num_samples', type=int, default=1)
+    parser.add_argument('--max_new_tokens', type=int, default=50)
+    parser.add_argument('--batch_size', type=int, default=1)
+    parser.add_argument('--top_k', type=int, default=200)
+    parser.add_argument('--temperature', type=float, default=0.8)
+    parser.add_argument('--checkpoint_path', type=Path, required=True)
+    parser.add_argument('--draft_checkpoint_path', type=Path, default=None)
+    parser.add_argument('--speculate_k', type=int, default=5)
+    parser.add_argument('--device', type=str, default=default_device)
+    parser.add_argument('--flex', action='store_true', help='Use FlexDecoding (speculative_generate)')
+    parser.add_argument('--seed', type=int, default=0)
 
     args = parser.parse_args()
-    main(
-        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
-        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
-        args.speculate_k, args.device
+
+    tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
+    try:
+        tokenizer = get_tokenizer(tokenizer_path, args.checkpoint_path)
+    except Exception:
+        print("⚠️ SentencePiece tokenizer not found — using Hugging Face tokenizer instead.")
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path.parent)
+
+    model = _load_model(args.checkpoint_path, args.device, torch.bfloat16, use_tp=False)
+    draft_model = _load_model(args.draft_checkpoint_path, args.device, torch.bfloat16, use_tp=False) if args.draft_checkpoint_path else None
+
+    torch.manual_seed(args.seed)
+    if isinstance(args.prompt, str):
+        encoded = encode_tokens(tokenizer, args.prompt, bos=True, device=args.device)
+    else:
+        # integer prompt: build a synthetic prompt of that length (mirrors the removed main())
+        encoded = torch.randint(0, 1024, (args.prompt,), device=args.device, dtype=torch.int64)
+
+    y, stats = generate(
+        model,
+        encoded,
+        args.max_new_tokens,
+        batch_size=args.batch_size,
+        interactive=args.interactive,
+        draft_model=draft_model,
+        speculate_k=args.speculate_k,
+        use_flex=args.flex,
+        temperature=args.temperature,
+        top_k=args.top_k
     )
+    print("\nOutput:\n", tokenizer.decode(y[0].tolist()))
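+
+    # Example invocation (paths are placeholders) for the mock FlexDecoding test:
+    #   python generate.py --checkpoint_path checkpoints/mock/model.pth \
+    #       --draft_checkpoint_path checkpoints/mock/draft.pth --flex --max_new_tokens 50
+    # The --flex route only runs when a draft checkpoint is supplied, since the
+    # speculative branch in generate() requires draft_model; without --flex the
+    # original speculative_decode() path is used instead.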