Skip to content

Commit 6cb1c7f

Browse files
committed
optimizing generate_patches
1 parent 5a583cc commit 6cb1c7f

File tree

2 files changed

+25
-10
lines changed

2 files changed

+25
-10
lines changed

src/maps.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import random
99
import dill
10+
import logging
1011

1112

1213
# Credit - Bahar
@@ -71,32 +72,46 @@ def add(self, mapReader):
7172
for x in range(self.window_size[0])])
7273

7374
#Generate image patches and write to data/output directory
74-
def generate_patches(self, mapReader, image_groups=3):
75+
def generate_patches(self, mapReader, image_groups=3, outDirectory=None):
7576
"""_summary_
7677
7778
Args:
7879
mapReader (MapReader): reader for a single big map!
7980
image_groups (int, optional): _description_. Defaults to 3.
8081
"""
81-
outDirectory = os.path.join(self.outputDir, mapReader.mapName, str(self.window_size))
82+
if outDirectory is None:
83+
outDirectory = os.path.join(self.outputDir, mapReader.mapName, str(self.window_size))
8284
os.makedirs(outDirectory, exist_ok=True)
85+
8386
img_group_number = 0
8487
for i in range(0, mapReader.size[0] - self.window_size[0] + 1, self.step_size):
8588
for j in range(0, mapReader.size[1] - self.window_size[1] + 1, self.step_size):
86-
self.samples.append([[
87-
(self.converter.get_char(mapReader.data[i + x][j + y]) / (len(self.converter.char_groups) - 1)) * -2 + 1
88-
for y in range(self.window_size[1])]
89-
for x in range(self.window_size[0])])
9089

91-
if len(self.samples)==self.sample_group_size:
90+
sample = self.extractSample(mapReader, topLeft=(i, j))
91+
92+
self.samples.append(sample)
93+
94+
if len(self.samples) == self.sample_group_size:
9295
path = os.path.join(outDirectory, str(img_group_number) + ".dill")
9396
with open(path, 'wb+') as f:
9497
dill.dump(self.samples, f)
9598
f.close()
96-
print("Image group {} saved in data/output/ !".format(img_group_number))
99+
logging.info(f"Image group {img_group_number} saved in {path}")
97100
self.samples.clear()
98101
img_group_number+=1
99-
102+
103+
def extractSample(self, mapReader, topLeft):
104+
i = topLeft[0]
105+
j = topLeft[1]
106+
sample = [
107+
[
108+
(self.converter.get_char(mapReader.data[i + x][j + y]) / (len(self.converter.char_groups) - 1)) * -2 + 1 # TODO this conversion should be done once in the original data instead of patches.
109+
for y in range(self.window_size[1])
110+
]
111+
for x in range(self.window_size[0])
112+
]
113+
return sample
114+
100115

101116
def shuffle(self):
102117
random.shuffle(self.samples)

src/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Run this script to generate data in /output directory.
22
import maps
33

4-
sfMap = maps.MapReader('../data/input/sf_layered.txt', "SF_Layerd")
4+
sfMap = maps.MapReader('../data/input/sf_layered.txt', "SF_Layered")
55
mapsDataset = maps.MapsDataset((32, 32), 2, 1280, maps.single_layer_converter) #Third parameter is the group size
66
mapsDataset.generate_patches(sfMap) #This will generate dill files which contain the saved sample lists.

0 commit comments

Comments
 (0)