-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
37 lines (30 loc) · 1.27 KB
/
inference.py
File metadata and controls
37 lines (30 loc) · 1.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""Greedy-decoding inference script for the trained small language model.

Loads the tokenizer, config, and weights saved under ./trained_slm,
then extends a hard-coded prompt by 50 greedily-chosen tokens.
"""
from transformers import AutoTokenizer
import torch, json
from train_slm import SmallLanguageModel
from config_slm import SLMConfig

# Load tokenizer (the one saved after training)
tokenizer = AutoTokenizer.from_pretrained("./trained_slm")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🧠 Using device: {device}")

# Load model config + weights
with open("./trained_slm/config.json") as f:
    model_cfg = json.load(f)

cfg = SLMConfig()
cfg.model.vocab_size = model_cfg["vocab_size"]
cfg.model.d_model = model_cfg["d_model"]
cfg.model.n_layers = model_cfg["n_layers"]
cfg.model.n_heads_per_layer = model_cfg["n_heads_per_layer"]
cfg.model.block_size = model_cfg["block_size"]
cfg.model.dropout = model_cfg["dropout"]

model = SmallLanguageModel(cfg)
state_dict = torch.load("./trained_slm/pytorch_model.bin", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# BUG FIX: the model was never moved to `device` — map_location only places the
# state-dict tensors; load_state_dict copies them into the (CPU) parameters.
model.to(device)
model.eval()

# Generate: 50 steps of greedy (argmax) decoding.
text = "I am going home and I found"
input_ids = tokenizer(text, return_tensors="pt")["input_ids"].to(device)
with torch.no_grad():  # no autograd graph needed at inference time
    for _ in range(50):
        # Crop the context to the model's block size so positional
        # embeddings are never indexed out of range on long sequences.
        idx_cond = input_ids[:, -cfg.model.block_size:]
        # NOTE(review): model is assumed to return (logits, loss) — matches the
        # original call site; confirm against SmallLanguageModel.forward.
        logits, _ = model(idx_cond)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))