Skip to content

Commit 37687c0

Browse files
committed
Move dataset loading and transforming to new class DataBunch
1 parent 768ce00 commit 37687c0

File tree

1 file changed

+52
-1
lines changed

1 file changed

+52
-1
lines changed

sota_extractor2/models/structure/__init__.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import re
12
import numpy as np
3+
import pandas as pd
24
from ...helpers.training import set_seed
3-
5+
from ... import config
46

57
def split_by_cell_content(df, seed=42, split_column="cell_content"):
68
set_seed(seed, "val_split")
@@ -12,3 +14,52 @@ def split_by_cell_content(df, seed=42, split_column="cell_content"):
1214
train_df = df[~split]
1315
len(train_df), len(valid_df)
1416
return train_df, valid_df
17+
18+
19+
label_map_4 = {
20+
"model-paper": 1,
21+
"model-best": 1,
22+
"model-competing": 2,
23+
"dataset": 3,
24+
"dataset-sub": 3,
25+
"dataset-task": 3,
26+
}
27+
28+
29+
label_map_3 = {
30+
"model-paper": 1,
31+
"model-best": 1,
32+
"model-competing": 2,
33+
}
34+
35+
label_map_2 = {
36+
"model-paper": 1,
37+
"model-best": 1,
38+
"model-competing": 1,
39+
}
40+
41+
42+
class DataBunch:
43+
def __init__(self, train_name, test_name, label_map):
44+
self.label_map = label_map
45+
self.train_df = pd.read_csv(config.datasets_structure/train_name)
46+
self.test_df = pd.read_csv(config.datasets_structure/test_name)
47+
self.transform(self.normalize)
48+
self.transform(self.label)
49+
50+
def transform(self, fun):
51+
self.train_df = fun(self.train_df)
52+
self.test_df = fun(self.test_df)
53+
54+
def normalize(self, df):
55+
df = df.drop_duplicates(["text", "cell_content", "cell_type"]).fillna("")
56+
df = df.replace(re.compile(r"(xxref|xxanchor)-[\w\d-]*"), "\\1 ")
57+
df = df.replace(re.compile(r"(^|[ ])\d+\.\d+\b"), " xxnum ")
58+
df = df.replace(re.compile(r"(^|[ ])\d\b"), " xxnum ")
59+
df = df.replace(re.compile(r"\bdata set\b"), " dataset ")
60+
return df
61+
62+
def label(self, df):
63+
df["label"] = df["cell_type"].apply(lambda x: self.label_map.get(x, 0))
64+
df["label"] = pd.Categorical(df["label"])
65+
return df

0 commit comments

Comments
 (0)