This repository was archived by the owner on Nov 26, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathstitch_images.py
More file actions
181 lines (158 loc) · 5.72 KB
/
stitch_images.py
File metadata and controls
181 lines (158 loc) · 5.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# -*- coding: utf-8 -*-
# Description: stitches together item captures into one image
# Example usage:
# python stitch_images.py ../data/ ../img/items/ ../img/ 100 10 10 default 50 20 3 30000
# python stitch_images.py ../data/ ../img/items/ ../img/ 100 10 10 centuries 50 20 3 30000
# python stitch_images.py ../data/ ../img/items/ ../img/ 100 10 10 collections 50 20 3 30000
# python stitch_images.py ../data/ ../img/items/ ../img/ 100 10 10 colors 50 20 3 30000
# python stitch_images.py ../data/ ../img/items/ ../img/ 100 10 10 genres 50 20 3 30000
# python stitch_images.py ../data/ ../img/items/ ../img/ 100 80 80 colors 50 20 3 8000
from PIL import Image
import json
import math
import os
import sys
# input
# Eleven positional arguments are required, so sys.argv must hold twelve
# entries (argv[0] is the script name). Checking "< 11" let a call with only
# ten arguments through, which then crashed on sys.argv[11].
if len(sys.argv) < 12:
    print("Usage: %s <inputdir of data> <inputdir of images> <outputdir for image> <images per row> <image cell width> <image cell height> <data group> <group item threshold> <group threshold> <min group rows> <max image height>" % sys.argv[0])
    sys.exit(1)
INPUT_DATA_DIR = sys.argv[1]      # holds captures.json and the <group>.json / item_<group>.json files
INPUT_IMAGE_DIR = sys.argv[2]     # holds one <captureId>.<imageExt> file per capture
OUTPUT_IMAGE_DIR = sys.argv[3]    # where the stitched image is written
ITEMS_PER_ROW = int(sys.argv[4])
ITEM_W = int(sys.argv[5])         # cell width in pixels
ITEM_H = int(sys.argv[6])         # cell height in pixels
DATA_GROUP = sys.argv[7]          # base name of the group JSON files, e.g. "colors"
GROUP_ITEM_THRESHOLD = int(sys.argv[8])
GROUP_THRESHOLD = int(sys.argv[9])
MIN_GROUP_ROWS = int(sys.argv[10])
MAX_IMAGE_HEIGHT = int(sys.argv[11])
# config
imageExt = "jpg"
# init captures: one entry per item, indexed by item id
captures = []
# os.path.join works whether or not the directory argument ends with a slash
with open(os.path.join(INPUT_DATA_DIR, "captures.json")) as data_file:
    captures = json.load(data_file)
itemCount = len(captures)
print("Loaded " + str(itemCount) + " captures...")
def getItemsIds(the_group, the_items):
    """Return the indices of the items that belong to ``the_group``.

    ``the_items`` is a per-item list in one of two shapes, detected from its
    first element:

    * plain group indices, e.g. ``[0, 2, 0]`` -- matching item indices are
      returned in their original order;
    * ``[group_index, score]`` lists -- matching item indices are returned
      sorted by score, highest first (equal scores keep their original order).

    :param the_group: group dict; only its ``'index'`` key is read.
    :param the_items: per-item group assignments as described above.
    :returns: list of item indices (positions in ``the_items``); ``[]`` when
        ``the_items`` is empty (previously this raised ``IndexError``).
    """
    if not the_items:
        return []
    group_index = the_group['index']
    if isinstance(the_items[0], list):
        scored = [(item_i, group_value[1])
                  for item_i, group_value in enumerate(the_items)
                  if group_value[0] == group_index]
        # Stable sort: reverse=True still keeps ties in original order
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [item_i for item_i, _score in scored]
    return [item_i for item_i, group_i in enumerate(the_items) if group_i == group_index]
# init groups
# Every entry appended to `groups` has an 'items' list (indices into
# `captures`) and a 'count'; groups read from JSON also keep their other keys.
groups = []
item_groups = []
# os.path.join works whether or not INPUT_DATA_DIR ends with a slash
group_filename = os.path.join(INPUT_DATA_DIR, DATA_GROUP + ".json")
items_group_filename = os.path.join(INPUT_DATA_DIR, "item_" + DATA_GROUP + ".json")
if os.path.isfile(group_filename) and os.path.isfile(items_group_filename):
    _groups = []
    with open(group_filename) as data_file:
        _groups = json.load(data_file)
    with open(items_group_filename) as data_file:
        item_groups = json.load(data_file)
    # Take out the unknown group: the first group whose 'value' is falsy.
    # It is appended last; any further falsy-value groups are not drawn.
    unknown = next((g for g in _groups if not g['value']), False)
    # Catch-all for groups too small to get their own block
    other = {
        'count': 0,
        'items': []
    }
    # Add items to appropriate groups
    for g in _groups:
        if g['value']:
            item_ids = getItemsIds(g, item_groups)
            # This group is too small (and there are more than GROUP_THRESHOLD
            # groups overall); add it to the "other" group
            if g['count'] < GROUP_ITEM_THRESHOLD and len(_groups) > GROUP_THRESHOLD:
                other['items'].extend(item_ids)
                other['count'] += g['count']
            else:
                g['items'] = item_ids
                groups.append(g)
    # Add "other" group
    if other['count']:
        groups.append(other)
    # Add "unknown" group
    if unknown:
        unknown['items'] = getItemsIds(unknown, item_groups)
        groups.append(unknown)
else:
    # No group data for DATA_GROUP: put everything in one big group
    groups.append({
        'items': range(itemCount),
        'count': itemCount
    })
# init
x = 0           # pixel x of the next cell
y = 0           # pixel y of the current row
failCount = 0   # images that could not be read
skipCount = 0   # items with a falsy capture id
count = 0       # items processed so far (drives the progress display)
# calculate height
rows = int(math.ceil(1.0 * itemCount / ITEMS_PER_ROW))
imageW = ITEM_W * ITEMS_PER_ROW
imageH = rows * ITEM_H
if len(groups) > 1:
    # Each group starts on its own row and is given at least MIN_GROUP_ROWS rows
    rows = 0
    for g in groups:
        group_rows = int(math.ceil(1.0 * g['count'] / ITEMS_PER_ROW))
        if group_rows < MIN_GROUP_ROWS:
            group_rows = MIN_GROUP_ROWS
        rows += group_rows
    imageH = rows * ITEM_H
# Ensure under max height
imageH = min(imageH, MAX_IMAGE_HEIGHT)
# Create blank image
print("Creating blank image at (" + str(imageW) + " x " + str(imageH) + ")")
imageBase = Image.new("RGB", (imageW, imageH), "black")
for g in groups:
    items = g['items']
    # Blank rows needed to pad this group up to MIN_GROUP_ROWS
    extra_rows = max(MIN_GROUP_ROWS - int(math.ceil(1.0 * g['count'] / ITEMS_PER_ROW)), 0)
    for itemId in items:
        captureId = captures[itemId]
        # Determine x/y: wrap to the next row once the current one is full
        if x >= imageW:
            x = 0
            y += ITEM_H
        # Out of vertical space (the image may have been capped at MAX_IMAGE_HEIGHT)
        if y >= imageH:
            break
        # Try to paste image; a falsy capture id leaves the cell black
        if captureId:
            fileName = os.path.join(INPUT_IMAGE_DIR, captureId + "." + imageExt)
            try:
                # `with` releases the file handle even if thumbnail/paste fails.
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
                with Image.open(fileName) as im:
                    im.thumbnail((ITEM_W, ITEM_H), Image.LANCZOS)
                    imageBase.paste(im, (x, y))
            except IOError:
                # Missing or unreadable image: count it and leave the cell black.
                # Any other exception propagates and aborts the run.
                failCount += 1
            else:
                sys.stdout.write('\r')
                sys.stdout.write(str(round(1.0 * count / itemCount * 100, 3)) + '%')
                sys.stdout.flush()
        else:
            skipCount += 1
        x += ITEM_W
        count += 1
    # Add extra rows for small groups
    y += extra_rows * ITEM_H
    # Go to the next line for the next group
    x = 0
    y += ITEM_H
    if y >= imageH:
        break
# End the carriage-return progress line so later messages start on a new line
sys.stdout.write('\n')
sys.stdout.flush()
# Save image
print("Saving stiched image...")
# File name encodes the layout parameters, e.g. "colors_100_10_10.jpg".
# os.path.join works whether or not OUTPUT_IMAGE_DIR ends with a slash.
outputfile = os.path.join(
    OUTPUT_IMAGE_DIR,
    "%s_%d_%d_%d.%s" % (DATA_GROUP, ITEMS_PER_ROW, ITEM_W, ITEM_H, imageExt),
)
imageBase.save(outputfile)
print("Saved image: " + outputfile)
print("Failed to add " + str(failCount) + " images.")
print("Skipped " + str(skipCount) + " images.")