Skip to content

Commit cdb3191

Browse files
committed
Make number of layers adjustable
1 parent 3c816ed commit cdb3191

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

sota_extractor2/models/structure/ulmfit_experiment.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,11 @@
1-
from .experiment import Experiment
1+
from .experiment import Experiment, label_map_ext
22
from .nbsvm import preds_for_cell_content, preds_for_cell_content_max, preds_for_cell_content_multi
33
import dataclasses
44
from dataclasses import dataclass
55
from typing import Tuple
66
from sota_extractor2.helpers.training import set_seed
77
from fastai.text import *
8+
from fastai.text.learner import _model_meta
89
import numpy as np
910
from pathlib import Path
1011
import json
@@ -24,6 +25,9 @@ class ULMFiTExperiment(Experiment):
2425
dataset: str = None
2526
train_on_easy: bool = True
2627
BS: int = 64
28+
valid_split: str = 'speech_rec'
29+
test_split: str = 'img_class'
30+
n_layers: int = 3
2731

2832
has_predictions: bool = False # similar to has_model, but to avoid storing pretrained models we only keep predictions
2933
# that can be later used by CRF
@@ -64,7 +68,10 @@ def _add_phase(self, state):
6468
# todo: make it compatible with Experiment
6569
def train_model(self, data_clas):
6670
set_seed(self.seed, "clas")
67-
clas = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=self.drop_mult)
71+
cfg = _model_meta[AWD_LSTM]['config_clas'].copy()
72+
cfg['n_layers'] = self.n_layers
73+
74+
clas = text_classifier_learner(data_clas, AWD_LSTM, config=cfg, drop_mult=self.drop_mult)
6875
clas.load_encoder(self.pretrained_lm)
6976
if self.fp16:
7077
clas = clas.to_fp16()
@@ -124,5 +131,6 @@ def evaluate(self, model, train_df, valid_df, test_df):
124131
true_y = vote_results["true"]
125132
else:
126133
true_y = tdf["label"]
127-
self._set_results(prefix, preds, true_y)
134+
true_y_ext = tdf["cell_type"].apply(lambda x: label_map_ext.get(x, 0))
135+
self._set_results(prefix, preds, true_y, true_y_ext)
128136
self._preds.append(probs)

0 commit comments

Comments
 (0)