-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathT3.inference.py
More file actions
39 lines (31 loc) · 898 Bytes
/
T3.inference.py
File metadata and controls
39 lines (31 loc) · 898 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# -*- coding: utf-8 -*-
"""
Sample from a trained model
"""
import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken
from T1_model import Model
# Hyperparameters
# Prefer GPU when one is present; everything below honors this choice.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
TORCH_SEED = 1337
torch.manual_seed(TORCH_SEED)
if torch.cuda.is_available():
    # Only seed CUDA RNGs when a GPU actually exists.
    torch.cuda.manual_seed(TORCH_SEED)
encoding = tiktoken.get_encoding("cl100k_base")
# Initiate from trained model
model = Model()
# map_location=device: a checkpoint saved on GPU would otherwise fail to
# load on a CPU-only machine, defeating the CPU fallback chosen above.
model.load_state_dict(torch.load('model/model-scifi.pt', map_location=device))
model.eval()
model.to(device)
# Prompt text to condition generation on.
# start = 'Write a short story about Sam Altman.'
start = 'Sam Altman was born in'
start_ids = encoding.encode(start)
# [None, ...] adds a batch dimension: shape (1, len(start_ids)).
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
# run generation (no gradients needed for inference)
with torch.no_grad():
    y = model.generate(x, max_new_tokens=500)
print('---------------')
print(encoding.decode(y[0].tolist()))
print('---------------')