-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
38 lines (31 loc) · 1.16 KB
/
inference.py
File metadata and controls
38 lines (31 loc) · 1.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import typing
import numpy as np
from model.model import Encoder, Decoder, Seq2Seq
from config import MODEL_CONF
from utils import embedding
class ModelInference:
    """Load a trained Seq2Seq network and expose streamline-embedding inference.

    The encoder/decoder hyper-parameters come from ``MODEL_CONF``; the trained
    weights are read from a state-dict checkpoint on construction.
    """

    def __init__(self, weights_file: str, device: str = "cpu") -> None:
        """Build the network from ``MODEL_CONF`` and load ``weights_file``.

        Args:
            weights_file: Path to a checkpoint saved with ``torch.save`` that
                contains a ``Seq2Seq`` state dict.
            device: Torch device string the weights are mapped onto
                (e.g. ``"cpu"`` or ``"cuda:0"``).
        """
        self.weights_file = weights_file
        self.device = device
        self.encoder = Encoder(
            MODEL_CONF["encoder"]["input_dim"],
            MODEL_CONF["encoder"]["hidden_dim"],
            MODEL_CONF["encoder"]["n_layer"],
            MODEL_CONF["encoder"]["dropout"],
        )
        self.decoder = Decoder(
            MODEL_CONF["decoder"]["input_dim"],
            MODEL_CONF["decoder"]["output_dim"],
            MODEL_CONF["decoder"]["hidden_dim"],
            MODEL_CONF["decoder"]["n_layer"],
            MODEL_CONF["decoder"]["dropout"],
        )
        self.net = Seq2Seq(self.encoder, self.decoder, self.device)
        # Map the checkpoint onto the requested device at load time.
        # NOTE(review): torch.load unpickles arbitrary data — only load
        # trusted checkpoints (or pass weights_only=True on torch>=2.0).
        self.net.load_state_dict(
            torch.load(self.weights_file, map_location=self.device)
        )
        # NOTE(review): the config declares dropout layers, but eval() is never
        # called, so dropout stays active during inference — TODO confirm this
        # is intentional; otherwise add self.net.eval() here.
        print("weights loaded.")

    def tck2emb(self, streamlines: typing.List) -> np.ndarray:
        """Return the embedding produced by the loaded network for *streamlines*.

        Args:
            streamlines: Sequence of streamlines accepted by ``utils.embedding``
                (element type not visible from this file — presumably arrays of
                points; verify against the caller).

        Returns:
            The embedding array computed by ``embedding(self.net, streamlines)``.
        """
        return embedding(self.net, streamlines)