-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
executable file
·77 lines (57 loc) · 2.82 KB
/
inference.py
File metadata and controls
executable file
·77 lines (57 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import argparse
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import Pix2StructProcessor
from model import Simplot
from trainer import inference
from dataset import prepare_test_dataset, SimplotDataset, test_collator
def main(args):
    """Run inference with a trained Simplot model over the test dataset.

    Loads the ``google/deplot`` Pix2Struct processor, builds the test dataset
    and a batch-size-1 dataloader, restores the model weights from
    ``args.state_path`` and hands everything to ``trainer.inference``.

    Args:
        args: Parsed command-line namespace. Fields read directly here are
            ``seed``, ``result_path``, ``device``, ``inference_type``,
            ``phase`` and ``state_path``; the whole namespace is also
            forwarded to ``Simplot``, ``prepare_test_dataset`` and
            ``inference``.

            ``args.device`` may be ``"cpu"``, a bare CUDA index such as
            ``"0"``, or a full ``"cuda"`` / ``"cuda:N"`` string. When CUDA is
            not available the run falls back to the CPU.
    """
    # Turn off the tokenizers library's own parallelism for this process.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    processor = Pix2StructProcessor.from_pretrained('google/deplot')
    # NOTE(review): is_vqa presumably makes the image processor render the
    # text prompt as a header on the image -- confirm against the installed
    # transformers version.
    processor.image_processor.is_vqa = True

    np.random.seed(args.seed)

    os.makedirs(args.result_path, exist_ok=True)

    # Honour --device instead of pinning the run to the CPU; fall back to the
    # CPU whenever CUDA is not usable on this machine.
    requested = str(args.device).strip().lower()
    if requested == 'cpu' or not torch.cuda.is_available():
        device = 'cpu'
    elif requested.startswith('cuda'):
        device = requested
    else:
        device = f'cuda:{requested}'

    model = Simplot(args)

    dataset = prepare_test_dataset(args)
    print("Dataset contents:", dataset)

    # QA runs use the phase given on the command line; every other inference
    # type uses a fixed value of 4 as the dataset's third argument.
    if args.inference_type == 'QA':
        test_dataset = SimplotDataset(dataset, processor, args.phase)
    else:
        test_dataset = SimplotDataset(dataset, processor, 4)

    # Deserialize the checkpoint onto the CPU first so loading works without
    # a GPU, then move the model to the selected device.
    checkpoint = torch.load(args.state_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    test_dataloader = DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=1,
        collate_fn=test_collator,
    )

    inference(args, model, dataset, test_dataloader, processor, device)
if __name__ == "__main__":
    # One (flag, type, default) row per command-line option, in the order the
    # options are registered with the parser.
    # Commented-out alternative defaults for the five data paths pointed at
    # ./data/test/{png,tables,gpt_indexes,gpt_columns,annotations}.
    option_table = (
        ('--img_path', str, './data/test1/png'),
        ('--table_path', str, './data/test1/tables'),
        ('--row_path', str, './data/test1/indexes'),
        ('--col_path', str, './data/test1/columns'),
        ('--json_path', str, './data/test1/annotations'),
        ('--device', str, '0'),
        ('--seed', int, 42),
        ('--phase', int, 3),
        ('--state_path', str, './state/phase_2_best_model.pth'),
        ('--result_path', str, './result/'),
        ('--tau', float, 1),
        ('--theta', float, 0.5),
        ('--inference_type', str, 'QA'),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, fallback in option_table:
        parser.add_argument(flag, type=value_type, default=fallback)

    args = parser.parse_args()
    main(args)