-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
68 lines (64 loc) · 2.7 KB
/
train.py
File metadata and controls
68 lines (64 loc) · 2.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/python
import os
from argparse import ArgumentParser
from sklearn.externals import joblib
from tictacs import from_recipe
from pan import ProfilingDataset, createDocProfiles, create_target_prof_trainset
# import dill
#import cPickle as pickle
# from sklearn.neighbors import KNeighborsClassifier
# from sklearn.metrics import accuracy_score, confusion_matrix
def describe_pipeline(pipeline):
    """Return a one-line, human-readable summary of a pipeline's steps.

    Every entry of ``pipeline.steps`` is read as a ``(name, estimator)``
    pair.  The step named ``"features"`` is expanded into one entry per
    item of its ``transformer_list``, each rendered as
    ``"<name> with Params:[<get_params()>]"``; any other step contributes
    only its name.  Entries are joined with ``"+"`` and the result is
    terminated with a newline.

    Args:
        pipeline: object exposing a ``steps`` sequence (here, whatever
            ``from_recipe`` returns).

    Returns:
        The summary string; just ``"\\n"`` when there is nothing to list.
    """
    parts = []
    for step in pipeline.steps:
        step_name, estimator = step[0], step[1]
        if step_name == "features":
            for tf in estimator.transformer_list:
                parts.append(
                    tf[0] + " with Params:[" + str(tf[1].get_params()) + "]")
        else:
            parts.append(step_name)
    return "+".join(parts) + "\n"


def main():
    """Train one model per configured task and dump them all to one file.

    Command-line arguments:
        -i/--input:  path to the folder holding the pan dataset for a language.
        -o/--output: path to the folder where the model file is written
            (created if it does not exist yet).

    The trained models are collected in a ``{task: fitted pipeline}`` dict
    and written with ``joblib.dump`` to ``<output>/<lang>_fin.bin``.
    """
    parser = ArgumentParser(description='Train pan model on pan dataset')
    parser.add_argument('-i', '--input', type=str,
                        required=True, dest='infolder',
                        help='path to folder with pan dataset for a language')
    parser.add_argument('-o', '--output', type=str,
                        required=True, dest='outfolder',
                        help='path to folder where model should be written')
    args = parser.parse_args()
    infolder = args.infolder
    outfolder = args.outfolder

    # Make sure the destination exists *before* training: otherwise a bad
    # output path is only discovered at dump time, after all the fitting
    # work has been done, and the models are lost.
    if not os.path.isdir(outfolder):
        os.makedirs(outfolder)

    print('Loading dataset->Grouping User texts.\n')
    dataset = ProfilingDataset(infolder)
    print('Loaded {} users...\n'.format(len(dataset.entries)))
    # The dataset carries its own configuration: the task list and one
    # pipeline recipe per task.
    config = dataset.config
    tasks = config.tasks
    print('\n--------------- Thy time of Running ---------------')
    all_models = {}
    for task in tasks:
        print('Learning to judge %s..' % task)
        # NOTE(review): createDocProfiles(dataset) does not receive the task,
        # so it looks hoistable out of this loop -- confirm that it (and
        # create_target_prof_trainset) do not mutate shared state first.
        docs = createDocProfiles(dataset)
        X, y = create_target_prof_trainset(docs, task)
        tictac = from_recipe(config.recipes[task])
        print('Task:{}, Pipeline:{}'.format(task, describe_pipeline(tictac)))
        all_models[task] = tictac.fit(X, y)

    modelfile = os.path.join(outfolder, '%s_fin.bin' % dataset.lang)
    print('Writing model to {}'.format(modelfile))
    joblib.dump(all_models, modelfile, compress=3)


if __name__ == '__main__':
    main()