@@ -34,9 +34,21 @@ class Labels(Enum):
 # put here to avoid recompiling, used only in _limit_context
 elastic_tag_split_re = re.compile("(<b>.*?</b>)")
 
+# e = Experiment(remove_num=False, drop_duplicates=False, vectorizer='count',
+#                this_paper=True, merge_fragments=True, merge_type='concat',
+#                evidence_source='text_highlited', split_btags=True, fixed_tokenizer=True,
+#                fixed_this_paper=True, mask=False, evidence_limit=None, context_tokens=None,
+#                analyzer='word', lowercase=True, class_weight='balanced', multinomial_type='multinomial',
+#                solver='lbfgs', C=0.1, dual=False, penalty='l2', ngram_range=[1, 3],
+#                min_df=10, max_df=0.9, max_iter=1000, results={}, has_model=False)
+
+# ULMFiT related parameters
+# remove_num, drop_duplicates, this_paper, merge_fragments, merge_type, evidence_source, split_btags
+# fixed_tokenizer?, fixed_this_paper (remove), mask, evidence_limit, context_tokens, lowercase
+# class_weight? (consider adding support)
+
 @dataclass
 class Experiment:
-    vectorizer: str = "tfidf"
     this_paper: bool = False
     merge_fragments: bool = False
     merge_type: str = "concat"  # "concat", "vote_maj", "vote_avg", "vote_max"
@@ -47,23 +59,11 @@ class Experiment:
     mask: bool = False  # if True and evidence_source = "text_highlited", replace <b>...</b> with xxmask
     evidence_limit: int = None  # maximum number of evidences per cell (grouped by (ext_id, this_paper))
     context_tokens: int = None  # max. number of words before <b> and after </b>
-    analyzer: str = "word"  # "char", "word" or "char_wb"
     lowercase: bool = True
     remove_num: bool = True
     drop_duplicates: bool = True
     mark_this_paper: bool = False
 
-    class_weight: str = None
-    multinomial_type: str = "manual"  # "manual", "ovr", "multinomial"
-    solver: str = "liblinear"  # 'lbfgs' - large, liblinear for small datasets
-    C: float = 4.0
-    dual: bool = True
-    penalty: str = "l2"
-    ngram_range: tuple = (1, 2)
-    min_df: int = 3
-    max_df: float = 0.9
-    max_iter: int = 1000
-
     results: dict = dataclasses.field(default_factory=dict)
 
     has_model: bool = False  # either there's already a pretrained model, or it's a saved experiment with a saved model as well
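The hyperparameters deleted above are not gone: they reappear on the `NBSVMExperiment` subclass added in the final hunk. A minimal standalone sketch (with hypothetical `Base`/`Child` names) of why the split is safe: dataclass inheritance keeps the parent's defaulted fields, and `dataclasses.asdict`, which `save` relies on, serializes both.

```python
from dataclasses import dataclass, asdict

@dataclass
class Base:
    lowercase: bool = True  # generic field, stays on the base class

@dataclass
class Child(Base):
    C: float = 4.0  # model-specific field, moved to the subclass

# parent fields come first, subclass fields after
print(asdict(Child()))  # -> {'lowercase': True, 'C': 4.0}
```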
@@ -78,29 +78,39 @@ def _get_next_exp_name(self, dir_path):
             return dir_path / name
         raise Exception("You have too many files in this dir, really!")
 
-    def _save_model(self, path):
+    @staticmethod
+    def _dump_pickle(obj, path):
         with open(path, 'wb') as f:
-            pickle.dump(self._model, f)
+            pickle.dump(obj, f)
 
-    def _load_model(self, path):
+    @staticmethod
+    def _load_pickle(path):
         with open(path, 'rb') as f:
-            self._model = pickle.load(f)
-            return self._model
+            return pickle.load(f)
+
+    def _save_model(self, path):
+        self._dump_pickle(self._model, path)
+
+    def _load_model(self, path):
+        self._model = self._load_pickle(path)
+        return self._model
 
     def load_model(self):
        path = self._path.parent / f"{self._path.stem}.model"
        return self._load_model(path)
 
+    def save_model(self, path):
+        if hasattr(self, "_model"):
+            self._save_model(path)
+
     def save(self, dir_path):
         dir_path = Path(dir_path)
         dir_path.mkdir(exist_ok=True, parents=True)
         filename = self._get_next_exp_name(dir_path)
         j = dataclasses.asdict(self)
         with open(filename, "wt") as f:
             json.dump(j, f)
-        if hasattr(self, "_model"):
-            fn = filename.stem
-            self._save_model(dir_path / f"{fn}.model")
+        self.save_model(dir_path / f"{filename.stem}.model")
         return filename.name
 
     def to_df(self):
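For reference, a round-trip of the new static pickle helpers; the path and payload here are made up, and any picklable object works, not just models.

```python
from pathlib import Path

path = Path("/tmp/example.model")  # hypothetical location
Experiment._dump_pickle({"f1": 0.87}, path)
assert Experiment._load_pickle(path) == {"f1": 0.87}
```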
@@ -119,12 +129,13 @@ def new_experiment(self, **kwargs):
     def update_results(self, **kwargs):
         self.results.update(**kwargs)
 
-    def get_trained_model(self, train_df):
-        nbsvm = NBSVM(experiment=self)
-        nbsvm.fit(train_df["text"], train_df["label"])
-        self._model = nbsvm
+    def train_model(self, train_df, valid_df):
+        raise NotImplementedError("train_model should be implemented in subclass")
+
+    def get_trained_model(self, train_df, valid_df):
+        self._model = self.train_model(train_df, valid_df)
         self.has_model = True
-        return nbsvm
+        return self._model
 
     def _limit_context(self, text):
         parts = elastic_tag_split_re.split(text)
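The `_limit_context` line above depends on `re.split` returning the matched `<b>...</b>` spans along with the surrounding text, which happens only because the pattern wraps them in a capturing group. A standalone illustration:

```python
import re

elastic_tag_split_re = re.compile("(<b>.*?</b>)")

# the capturing group keeps the delimiters in the result,
# so evidence spans survive the split
print(elastic_tag_split_re.split("seen before <b>evidence</b> and after"))
# -> ['seen before ', '<b>evidence</b>', ' and after']
```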
@@ -301,3 +312,23 @@ def experiments_to_df(cls, exps):
         dfs = [e.to_df() for e in exps]
         df = pd.concat(dfs)
         return df
+
+@dataclass
+class NBSVMExperiment(Experiment):
+    vectorizer: str = "tfidf"
+    analyzer: str = "word"  # "char", "word" or "char_wb"
+    class_weight: str = None
+    multinomial_type: str = "manual"  # "manual", "ovr", "multinomial"
+    solver: str = "liblinear"  # 'lbfgs' - large, liblinear for small datasets
+    C: float = 4.0
+    dual: bool = True
+    penalty: str = "l2"
+    ngram_range: tuple = (1, 2)
+    min_df: int = 3
+    max_df: float = 0.9
+    max_iter: int = 1000
+
+    def train_model(self, train_df, valid_df=None):
+        nbsvm = NBSVM(experiment=self)
+        nbsvm.fit(train_df["text"], train_df["label"])
+        return nbsvm
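An illustrative, untested end-to-end flow with the new subclass; the toy DataFrame and parameter choices are invented for the sketch and are far smaller than real training data.

```python
import pandas as pd

# hypothetical training data; real callers pass their own DataFrame
train_df = pd.DataFrame({
    "text": ["<b>BLEU</b> of 35.2 on WMT14", "see Table 3 of [5]"],
    "label": [1, 0],
})

exp = NBSVMExperiment(ngram_range=(1, 2), min_df=1)
model = exp.get_trained_model(train_df, valid_df=None)
exp.save("experiments")  # writes the JSON params plus a .model pickle
```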