forked from aldolipani/TABME
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
executable file
·86 lines (75 loc) · 2.05 KB
/
predict.py
File metadata and controls
executable file
·86 lines (75 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/env python
import argparse
from pathlib import Path
import pandas as pd
from utils.config import config
from utils.dataset import TABME
from utils.evaluation import predict
from utils.train import train
def build_parser():
    """Build the command-line parser for the prediction script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--data``, ``--model``,
        ``--csv``, ``--cache``, ``--ablation`` and ``--num_hidden_layers``.
    """
    parser = argparse.ArgumentParser(
        description='Get prediction from the model'
    )
    # NOTE(review): --data, --model and --csv are not marked required, so an
    # omitted option is forwarded to predict() as None -- confirm whether
    # predict() tolerates that before tightening these to required=True.
    parser.add_argument(
        '-d',
        '--data',
        type=str,
        help='path to output folder, should be ./data/test',
    )
    parser.add_argument(
        '-m',
        '--model',
        help='path to model folder, should be inside ./output/',
        type=str,
    )
    parser.add_argument(
        '--csv',
        type=str,
        help='path to csv specifying folders, should be inside ./predictions',
    )
    parser.add_argument(
        '-c',
        '--cache',
        type=str,
        required=False,
        default='./cache',
        help="path to where to save preprocessed cache, default='./cache'",
    )
    parser.add_argument(
        '-a',
        '--ablation',
        type=str,
        required=False,
        help="optional, specify ablation choose from 'resnet' or 'layoutlm',"
             " default: no ablation",
    )
    parser.add_argument(
        '-n',
        '--num_hidden_layers',
        type=int,
        required=False,
        help='optional, specify num_hidden_layers in integer',
    )
    return parser


def hidden_feature_sizes(num_hidden_layers, upper=1280, lower=2):
    """Return evenly spaced hidden-layer widths, widest first.

    The gap between ``lower`` and ``upper`` is cut into
    ``num_hidden_layers + 1`` equal integer steps and one width is emitted
    per step, counting down towards ``lower``.

    Args:
        num_hidden_layers: requested number of hidden layers. ``None`` and
            ``0`` both mean "not specified".
        upper: integer bound the widths count down from.
            NOTE(review): the default 1280 is presumably the model's input
            feature size -- confirm against the model definition.
        lower: integer bound the widths count down towards.
            NOTE(review): the default 2 is presumably the number of model
            outputs -- confirm against the model definition.

    Returns:
        A list of ``num_hidden_layers`` ints in descending order, or ``None``
        when ``num_hidden_layers`` is ``None`` or ``0``.

    Raises:
        ValueError: if ``num_hidden_layers`` is negative.
    """
    if not num_hidden_layers:
        return None
    if num_hidden_layers < 0:
        # Without this guard -1 divides by zero and any other negative value
        # silently yields an empty list.
        raise ValueError(
            'num_hidden_layers must be a non-negative integer, '
            f'got {num_hidden_layers}'
        )
    step = (upper - lower) // (num_hidden_layers + 1)
    return [lower + step * k for k in range(num_hidden_layers, 0, -1)]


def main(argv=None):
    """Parse the command line and run ``predict`` with the given options.

    Args:
        argv: argument list to parse; ``None`` makes argparse read the
            process command line.

    Returns:
        Whatever ``predict`` returns.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    try:
        num_hidden_features = hidden_feature_sizes(args.num_hidden_layers)
    except ValueError as exc:
        # Report bad input as a normal usage error instead of a traceback.
        parser.error(str(exc))
    return predict(
        args.data,
        args.model,
        args.csv,
        batch_size=64,
        path_cache_folder=args.cache,
        num_hidden_features=num_hidden_features,
        ablation=args.ablation,
    )


if __name__ == '__main__':
    # The result is bound to ``df`` as in the original script; nothing in
    # this file uses it afterwards.
    df = main()