forked from Tencent/NeuralNLP-NeuralClassifier
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredictor.py
More file actions
96 lines (86 loc) · 4.16 KB
/
predictor.py
File metadata and controls
96 lines (86 loc) · 4.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import codecs
import math
import numpy as np
import os
import sys
import json
import torch
from torch.utils.data import DataLoader
from config import Config
from dataset.classification_dataset import ClassificationDataset
from dataset.collator import ClassificationCollator
from dataset.collator import ClassificationType
from dataset.collator import FastTextCollator
from model.classification.drnn import DRNN
from model.classification.fasttext import FastText
from model.classification.textcnn import TextCNN
from model.classification.textvdcnn import TextVDCNN
from model.classification.textrnn import TextRNN
from model.classification.textrcnn import TextRCNN
from model.classification.transformer import Transformer
from model.classification.dpcnn import DPCNN
from model.classification.attentive_convolution import AttentiveConvNet
from model.classification.region_embedding import RegionEmbedding
from model.model_util import get_optimizer, get_hierar_relations
# The dataset, collator and model classes below are never referenced as
# identifiers elsewhere in this module: Predictor looks them up by name via
# globals()[...]. Presumably this bare expression exists so linters /
# import-cleanup tools do not flag or drop those imports as unused.
(
    ClassificationDataset,
    ClassificationCollator,
    FastTextCollator,
    FastText,
    TextCNN,
    TextRNN,
    TextRCNN,
    DRNN,
    TextVDCNN,
    Transformer,
    DPCNN,
    AttentiveConvNet,
    RegionEmbedding,
)
class Predictor(object):
    """Run inference with a trained NeuralClassifier model.

    The dataset, collator and model classes are resolved by name through this
    module's ``globals()``, so every candidate class must be imported at
    module level.
    """

    def __init__(self, profile):
        """Build dataset, collator and model from a config file and load weights.

        Args:
            profile: path of the config file, passed to ``Config(config_file=...)``.
        """
        self.config = Config(config_file=profile)
        self.model_name = self.config.model_name
        self.use_cuda = self.config.device.startswith("cuda")
        self.dataset_name = "ClassificationDataset"
        # FastText gets its own collator; every other model shares one.
        self.collate_name = "FastTextCollator" if self.model_name == "FastText" \
            else "ClassificationCollator"
        # The second argument is an empty list (presumably the data files --
        # none are needed for inference). In this class the dataset is used for
        # label_map, _get_vocab_id_list and id_to_label_map, and is handed to
        # the model constructor.
        self.dataset = globals()[self.dataset_name](self.config, [], mode="infer")
        self.collate_fn = globals()[self.collate_name](self.config, len(self.dataset.label_map))
        self.model = Predictor._get_classification_model(self.model_name, self.dataset, self.config)
        Predictor._load_checkpoint(self.config.eval.model_dir, self.model, self.use_cuda)
        self.model.eval()
        self.batch_size = self.config.eval.batch_size

    @staticmethod
    def _get_classification_model(model_name, dataset, conf):
        """Instantiate the model class named ``model_name``.

        The model is moved to ``conf.device`` when that device string starts
        with ``"cuda"``; otherwise it is returned as constructed.
        """
        model = globals()[model_name](dataset, conf)
        model = model.cuda(conf.device) if conf.device.startswith("cuda") else model
        return model

    @staticmethod
    def _load_checkpoint(file_name, model, use_cuda):
        """Load ``checkpoint["state_dict"]`` from ``file_name`` into ``model``.

        When ``use_cuda`` is False, every storage is kept on the CPU via
        ``map_location`` so that a checkpoint saved on a GPU can still be
        loaded on a CPU-only machine.

        NOTE: ``torch.load`` unpickles the file -- only load trusted checkpoints.
        """
        if use_cuda:
            checkpoint = torch.load(file_name)
        else:
            checkpoint = torch.load(file_name, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint["state_dict"])

    def predict(self, texts):
        """Return label probabilities for one batch of JSON-encoded samples.

        Args:
            texts: iterable of JSON strings. Each one is decoded with
                ``json.loads`` and converted by the dataset's
                ``_get_vocab_id_list``.

        Returns:
            ``np.ndarray`` of probabilities, presumably shaped
            ``(len(texts), num_labels)`` (softmax is taken over dim 1). Values
            are a softmax for single-label tasks and an element-wise sigmoid
            for multi-label tasks.
        """
        with torch.no_grad():
            batch_texts = [self.dataset._get_vocab_id_list(json.loads(text)) for text in texts]
            batch_texts = self.collate_fn(batch_texts)
            logits = self.model(batch_texts)
            if self.config.task_info.label_type != ClassificationType.MULTI_LABEL:
                probs = torch.softmax(logits, dim=1)
            else:
                probs = torch.sigmoid(logits)
            probs = probs.cpu().tolist()
            return np.array(probs)

    @staticmethod
    def _select_label_ids(predict_prob, is_multi, top_k=None, threshold=None):
        """Pick label ids from one row of probabilities.

        Args:
            predict_prob: 1-D numpy array of per-label probabilities.
            is_multi: True for multi-label tasks.
            top_k: maximum number of ids to consider (multi-label only).
            threshold: an id is kept only if its probability is strictly
                greater than this value (multi-label only).

        Returns:
            Single-label: a one-element list holding the arg-max id.
            Multi-label: the ids among the ``top_k`` most probable ones whose
            probability exceeds ``threshold``, most probable first.
        """
        if not is_multi:
            return [predict_prob.argmax()]
        # Slicing (rather than indexing range(top_k)) clamps top_k to the number
        # of labels, so a top_k larger than the label set no longer raises
        # IndexError. max(..., 0) keeps top_k <= 0 selecting nothing.
        ranked_ids = np.argsort(-predict_prob)[:max(top_k, 0)]
        return [label_id for label_id in ranked_ids if predict_prob[label_id] > threshold]

    def predict_batch(self, input_texts, *, flatten=True):
        """Predict label names for ``input_texts`` in chunks of ``self.batch_size``.

        Args:
            input_texts: sliceable sequence of JSON strings (see ``predict``).
            flatten: if True (the default, and the historical behavior), the
                label names of all samples are concatenated into one flat list.
                For single-label tasks that list has exactly one entry per
                input. For multi-label tasks a sample may contribute zero or
                several names, so the flat list cannot be aligned with the
                inputs -- pass ``flatten=False`` to get one list of label names
                per input instead.

        Returns:
            ``list`` of label names when ``flatten`` is True, otherwise a
            ``list`` (one per input) of ``list`` of label names.
        """
        is_multi = self.config.task_info.label_type == ClassificationType.MULTI_LABEL
        # top_k / threshold are only consulted for multi-label tasks.
        if is_multi:
            top_k, threshold = self.config.eval.top_k, self.config.eval.threshold
        else:
            top_k = threshold = None

        predict_probs = []
        num_batches = math.ceil(len(input_texts) / self.batch_size)
        for i in range(num_batches):
            batch_texts = input_texts[i * self.batch_size:(i + 1) * self.batch_size]
            predict_probs.extend(self.predict(batch_texts))

        per_sample_names = []
        for predict_prob in predict_probs:
            label_ids = Predictor._select_label_ids(predict_prob, is_multi, top_k, threshold)
            per_sample_names.append(
                [self.dataset.id_to_label_map[label_id] for label_id in label_ids])

        if flatten:
            return [name for names in per_sample_names for name in names]
        return per_sample_names