gpu/generate.py: 33 changes (25 additions & 8 deletions)
@@ -55,21 +55,38 @@ def build(

     model_args_prefill = fast.ModelArgs(use_kernel=False)
     model_args_decode = fast.ModelArgs(use_kernel=True)
-    tokenizer = Tokenizer("./tokenizer.model")
+
+    # Load tokenizer (either provided path or default local tokenizer.model)
+    tokenizer = Tokenizer(tokenizer_path or "./tokenizer.model")
 
+    # Set default device and dtype globally for PyTorch ops
     torch.set_default_device(device)
     torch.set_default_dtype(torch.bfloat16)
 
-    prefill_model = fast.Transformer(model_args_prefill)
-    decode_model = fast.Transformer(model_args_decode)
-
-    fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
-    fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu")
-    int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
-    int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu")
+    # Initialize models directly on the target device (avoids extra CPU → GPU transfers later)
+    prefill_model = fast.Transformer(model_args_prefill).to(device)
+    decode_model = fast.Transformer(model_args_decode).to(device)
+
+    # Checkpoint paths
+    fp16_ckpt_path = Path(ckpt_dir) / "model_state_fp16.pt"
+    int2_ckpt_path = Path(ckpt_dir) / "model_state_int2.pt"
+
+    # Load checkpoints directly on the target device.
+    # Prefer weights_only=True (PyTorch >= 2.0) so only tensor data is unpickled;
+    # fall back to a plain torch.load for older versions.
+    try:
+        fp16_checkpoint = torch.load(fp16_ckpt_path, map_location=device, weights_only=True)
+        int2_checkpoint = torch.load(int2_ckpt_path, map_location=device, weights_only=True)
+    except TypeError:
+        fp16_checkpoint = torch.load(fp16_ckpt_path, map_location=device)
+        int2_checkpoint = torch.load(int2_ckpt_path, map_location=device)
 
+    # Load state dicts into the models. Since the models are already on the GPU,
+    # this avoids extra device transfers during loading.
     prefill_model.load_state_dict(fp16_checkpoint, strict=True)
     decode_model.load_state_dict(int2_checkpoint, strict=True)
 
+    # Synchronize to ensure all GPU ops are complete before timing
     torch.cuda.synchronize()
     print(f"loaded model in {time.time() - start_time:.2f} seconds")
     start_time = time.time()
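
For review context, a minimal self-contained sketch of the loading pattern this change adopts: map the checkpoint straight to the target device, try the safer weights_only=True first, and fall back for PyTorch versions that predate the keyword. The TinyModel class, checkpoint path, and device string below are illustrative placeholders, not part of this PR.

# Standalone sketch of the device-mapped, weights_only-first loading pattern.
# TinyModel and the checkpoint path are illustrative placeholders.
from pathlib import Path

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


def load_checkpoint(path, device):
    # Map tensors straight to the target device; try weights_only=True first
    # and fall back on older PyTorch versions that reject the keyword.
    try:
        return torch.load(path, map_location=device, weights_only=True)
    except TypeError:
        return torch.load(path, map_location=device)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ckpt_path = Path("tiny_model.pt")

    # Save a checkpoint first so the example runs end to end.
    torch.save(TinyModel().state_dict(), ckpt_path)

    model = TinyModel().to(device)
    model.load_state_dict(load_checkpoint(ckpt_path, device), strict=True)
    if device == "cuda":
        torch.cuda.synchronize()  # ensure GPU work is finished before any timing
    print(f"loaded TinyModel on {device}")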