From 5a2ccfa15e0df0ac40b774965c71b89acee2a395 Mon Sep 17 00:00:00 2001
From: nima nikoonazar
Date: Thu, 21 Aug 2025 12:48:53 +0330
Subject: [PATCH] feat(checkpoint-loading): ensure models are moved to GPU
 before loading state_dict

Move `prefill_model` and `decode_model` to the target device before
calling `load_state_dict`, and load both checkpoints with
`map_location=device`, so that PyTorch performs no redundant
CPU-to-GPU tensor copies while loading.
---
 gpu/generate.py | 33 +++++++++++++++++++++++++--------
 1 file changed, 25 insertions(+), 8 deletions(-)

diff --git a/gpu/generate.py b/gpu/generate.py
index 638ed7b3..f6b78dc7 100755
--- a/gpu/generate.py
+++ b/gpu/generate.py
@@ -55,21 +55,38 @@ def build(
     model_args_prefill = fast.ModelArgs(use_kernel=False)
     model_args_decode = fast.ModelArgs(use_kernel=True)
 
-    tokenizer = Tokenizer("./tokenizer.model")
+    # Load the tokenizer (the provided path, or the default local tokenizer.model)
+    tokenizer = Tokenizer(tokenizer_path or "./tokenizer.model")
+
+    # Set the default device and dtype globally for subsequent PyTorch ops
     torch.set_default_device(device)
     torch.set_default_dtype(torch.bfloat16)
 
-    prefill_model = fast.Transformer(model_args_prefill)
-    decode_model = fast.Transformer(model_args_decode)
-
-    fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
-    fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu")
-    int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
-    int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu")
+    # Initialize the models directly on the target device (avoids extra CPU → GPU transfers later)
+    prefill_model = fast.Transformer(model_args_prefill).to(device)
+    decode_model = fast.Transformer(model_args_decode).to(device)
+
+    # Checkpoint paths
+    fp16_ckpt_path = Path(ckpt_dir) / "model_state_fp16.pt"
+    int2_ckpt_path = Path(ckpt_dir) / "model_state_int2.pt"
+
+    # Load the checkpoints directly onto the target device.
+    # Prefer weights_only=True where supported: it restricts unpickling to plain
+    # tensor data (safer); fall back to torch.load without it on older versions.
+    try:
+        fp16_checkpoint = torch.load(fp16_ckpt_path, map_location=device, weights_only=True)
+        int2_checkpoint = torch.load(int2_ckpt_path, map_location=device, weights_only=True)
+    except TypeError:
+        fp16_checkpoint = torch.load(fp16_ckpt_path, map_location=device)
+        int2_checkpoint = torch.load(int2_ckpt_path, map_location=device)
+
+    # Load the state dicts into the models. Since everything is already on
+    # the GPU, this step performs no extra device transfers.
 
     prefill_model.load_state_dict(fp16_checkpoint, strict=True)
     decode_model.load_state_dict(int2_checkpoint, strict=True)
 
+    # Synchronize so all pending GPU work completes before the timing print
     torch.cuda.synchronize()
     print(f"loaded model in {time.time() - start_time:.2f} seconds")
     start_time = time.time()
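
Aside on the loading pattern: the try/except fallback in the hunk above
generalizes to a small helper. A minimal sketch, assuming only PyTorch;
the `load_checkpoint` name is illustrative and not part of this patch:

    import torch

    def load_checkpoint(path, device):
        # Prefer weights_only=True: it restricts unpickling to plain tensor
        # data, which is safer than running arbitrary pickle code.
        try:
            return torch.load(path, map_location=device, weights_only=True)
        except TypeError:
            # Older PyTorch: torch.load does not accept weights_only.
            return torch.load(path, map_location=device)

    # Example usage (hypothetical path):
    #   checkpoint = load_checkpoint("model_state_fp16.pt", "cuda")
    #   model.load_state_dict(checkpoint, strict=True)

Loading with map_location=device deserializes the tensors straight onto
the GPU, so the subsequent load_state_dict into a model that already
lives there copies device-to-device rather than host-to-device.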