-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathexpand_to_square.py
More file actions
82 lines (65 loc) · 2.78 KB
/
expand_to_square.py
File metadata and controls
82 lines (65 loc) · 2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import argparse
import json
import os
import re
from PIL import Image
from tqdm import tqdm
# Command-line interface: three mandatory string options, each naming a directory.
parser = argparse.ArgumentParser()
for _flag, _help_text in (
    ('--data_dir', 'The directory of unzipped data.zip'),
    ('--image_dir', 'The directory of images'),
    ('--output_dir', 'The output directory'),
):
    parser.add_argument(_flag, required=True, type=str, help=_help_text)
def expand_to_square(box, w, h):
    """Translate a pixel-space box to account for padding a w x h image to a square.

    The shorter axis receives an offset of half the side difference (integer
    division); the longer axis is left untouched. This matches centring the
    image on a square canvas whose side is the longer dimension.

    Args:
        box: Four coordinates ``(x1, y1, x2, y2)`` in pixels.
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        ``box`` itself when the image is already square, otherwise a new
        4-tuple ``(x1, y1, x2, y2)`` with the offset applied.
    """
    # Already square: nothing to shift, hand back the caller's object as-is.
    if w == h:
        return box

    left, top, right, bottom = box
    if w > h:
        # Landscape image: padding goes above/below, so only y moves.
        offset = (w - h) // 2
        return left, top + offset, right, bottom + offset

    # Portrait image: padding goes left/right, so only x moves.
    offset = (h - w) // 2
    return left + offset, top, right + offset, bottom
def normalize_bbox(box, w, h):
    """Scale box coordinates by the longer image side and round to 3 decimals.

    Args:
        box: Iterable of numeric coordinates (pixels).
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        A list with each coordinate divided by ``w`` when ``w > h`` and by
        ``h`` otherwise, rounded to three decimal places.
    """
    longest_side = w if w > h else h
    return [round(coord / longest_side, 3) for coord in box]
# Bounding boxes appear inside conversation text as "[x1, y1, x2, y2]".
# Compiled once at module level instead of once per conversation turn.
_BBOX_PATTERN = re.compile(r'\[[0-9.]+, [0-9.]+, [0-9.]+, [0-9.]+\]')


def _rewrite_bboxes(text, width, height):
    """Return ``text`` with every bounding box re-expressed for the squared image.

    Each box is read as fractions of the image width (x) and height (y),
    converted to pixels, shifted by ``expand_to_square`` and re-normalised by
    ``normalize_bbox``.

    The substitution is done in a single pass with ``re.sub`` so that every
    occurrence is converted exactly once. Chained ``str.replace`` calls would
    re-convert a box whose new text happens to equal another, not yet
    processed, box in the same string.

    Args:
        text: Conversation text that may contain bounding boxes.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        The text with all parseable boxes replaced; unparseable ones are kept.
    """
    def _convert(match):
        raw = match.group(0)
        try:
            x1, y1, x2, y2 = (float(part) for part in raw.strip('[]').split(','))
        except ValueError:
            # The pattern also accepts malformed numbers such as "1.2.3";
            # leave those spans untouched rather than aborting the whole run.
            return raw
        original_bbox = (
            round(x1 * width),
            round(y1 * height),
            round(x2 * width),
            round(y2 * height),
        )
        expanded_bbox = expand_to_square(original_bbox, width, height)
        new_x1, new_y1, new_x2, new_y2 = normalize_bbox(expanded_bbox, width, height)
        return f'[{new_x1}, {new_y1}, {new_x2}, {new_y2}]'

    return _BBOX_PATTERN.sub(_convert, text)


if __name__ == "__main__":
    args = parser.parse_args()
    input_files = ['ref3rec.json', 'ref3reg.json', 'shikra.json', 'svit.json']

    # Create the destination up front so a missing directory does not surface
    # only after the first (potentially long) file has been processed.
    os.makedirs(args.output_dir, exist_ok=True)

    # Cache of image file name -> (width, height); shared across input files
    # so each image is opened at most once.
    image_sizes = {}

    for input_file in input_files:
        input_file_path = os.path.join(args.data_dir, input_file)
        print(input_file_path)
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(input_file_path, 'r', encoding='utf-8') as fin:
            data = json.load(fin)

        for item in tqdm(data):
            image_file_name = item['image']
            if image_file_name not in image_sizes:
                image_path = os.path.join(args.image_dir, image_file_name)
                # Only the size is needed; close the file handle right away.
                with Image.open(image_path) as image:
                    image_sizes[image_file_name] = image.size
            width, height = image_sizes[image_file_name]

            for qa in item['conversations']:
                qa['value'] = _rewrite_bboxes(qa['value'], width, height)

        # Each input file is written under the same name in the output directory.
        output_file = os.path.join(args.output_dir, input_file)
        with open(output_file, 'w', encoding='utf-8') as fout:
            json.dump(data, fout, indent=4)