diff --git a/gpt_oss/generate.py b/gpt_oss/generate.py
index dfaaa6f..fb0e519 100644
--- a/gpt_oss/generate.py
+++ b/gpt_oss/generate.py
@@ -4,11 +4,18 @@
 # torchrun --nproc-per-node=4 -m gpt_oss.generate -p "why did the chicken cross the road?" model/

 import argparse
+from pathlib import Path

 from gpt_oss.tokenizer import get_tokenizer


-def main(args):
+def main(args: argparse.Namespace) -> None:
+    # Validate checkpoint path exists for backends that need local files
+    if args.backend in ["torch", "triton"]:
+        checkpoint_path = Path(args.checkpoint)
+        if not checkpoint_path.exists():
+            raise FileNotFoundError(f"Checkpoint path does not exist: {args.checkpoint}")
+
     match args.backend:
         case "torch":
             from gpt_oss.torch.utils import init_distributed
diff --git a/gpt_oss/tokenizer.py b/gpt_oss/tokenizer.py
index 866077f..05c1080 100644
--- a/gpt_oss/tokenizer.py
+++ b/gpt_oss/tokenizer.py
@@ -1,6 +1,7 @@
 import tiktoken

-def get_tokenizer():
+
+def get_tokenizer() -> tiktoken.Encoding:
     o200k_base = tiktoken.get_encoding("o200k_base")
     tokenizer = tiktoken.Encoding(
         name="o200k_harmony",
diff --git a/gpt_oss/torch/utils.py b/gpt_oss/torch/utils.py
index ce87a85..680f4af 100644
--- a/gpt_oss/torch/utils.py
+++ b/gpt_oss/torch/utils.py
@@ -3,7 +3,7 @@
 import torch.distributed as dist


-def suppress_output(rank):
+def suppress_output(rank: int) -> None:
     """Suppress printing on the current device. Force printing with `force=True`."""
     import builtins as __builtin__
     builtin_print = __builtin__.print