1
- from .experiment import Experiment
1
+ from .experiment import Experiment , label_map_ext
2
2
from .nbsvm import preds_for_cell_content , preds_for_cell_content_max , preds_for_cell_content_multi
3
3
import dataclasses
4
4
from dataclasses import dataclass
5
5
from typing import Tuple
6
6
from sota_extractor2 .helpers .training import set_seed
7
7
from fastai .text import *
8
+ from fastai .text .learner import _model_meta
8
9
import numpy as np
9
10
from pathlib import Path
10
11
import json
@@ -24,6 +25,9 @@ class ULMFiTExperiment(Experiment):
24
25
dataset : str = None
25
26
train_on_easy : bool = True
26
27
BS : int = 64
28
+ valid_split : str = 'speech_rec'
29
+ test_split : str = 'img_class'
30
+ n_layers : int = 3
27
31
28
32
has_predictions : bool = False # similar to has_model, but to avoid storing pretrained models we only keep predictions
29
33
# that can later be used by the CRF
@@ -64,7 +68,10 @@ def _add_phase(self, state):
64
68
# TODO: make it compatible with Experiment
65
69
def train_model (self , data_clas ):
66
70
set_seed (self .seed , "clas" )
67
- clas = text_classifier_learner (data_clas , AWD_LSTM , drop_mult = self .drop_mult )
71
+ cfg = _model_meta [AWD_LSTM ]['config_clas' ].copy ()
72
+ cfg ['n_layers' ] = self .n_layers
73
+
74
+ clas = text_classifier_learner (data_clas , AWD_LSTM , config = cfg , drop_mult = self .drop_mult )
68
75
clas .load_encoder (self .pretrained_lm )
69
76
if self .fp16 :
70
77
clas = clas .to_fp16 ()
@@ -124,5 +131,6 @@ def evaluate(self, model, train_df, valid_df, test_df):
124
131
true_y = vote_results ["true" ]
125
132
else :
126
133
true_y = tdf ["label" ]
127
- self ._set_results (prefix , preds , true_y )
134
+ true_y_ext = tdf ["cell_type" ].apply (lambda x : label_map_ext .get (x , 0 ))
135
+ self ._set_results (prefix , preds , true_y , true_y_ext )
128
136
self ._preds .append (probs )
0 commit comments