1111
1212# Credit - Bahar
1313class MapReader :
14- def __init__ (self , filename ):
14+ def __init__ (self , filename , mapName ):
15+ self .mapName = mapName
1516 f = open (filename , 'r' )
1617 data = f .read ().split ()
1718 self .size = (int (data [0 ]), int (data [1 ]))
@@ -38,7 +39,8 @@ def get_char(self, layer):
3839# Modifications - Ishaan
3940
4041class MapsDataset (Dataset ):
41- def __init__ (self , window_size , step_size , sample_group_size , converter ):
42+ def __init__ (self , window_size , step_size , sample_group_size , converter , outputDir = "../data/output" ):
43+ self .outputDir = outputDir
4244 self .char_size = converter .char_size
4345 self .converter = converter
4446 self .window_size = window_size
@@ -47,6 +49,9 @@ def __init__(self, window_size, step_size, sample_group_size, converter):
4749 self .samples = []
4850 self .block_size = self .window_size [0 ] * self .window_size [1 ] - 1
4951
52+ os .makedirs (self .outputDir , exist_ok = True )
53+
54+
5055 def __len__ (self ):
5156 return len (self .samples )
5257
@@ -67,6 +72,14 @@ def add(self, mapReader):
6772
6873 #Generate image patches and write to data/output directory
6974 def generate_patches (self , mapReader , image_groups = 3 ):
75+ """_summary_
76+
77+ Args:
78+ mapReader (MapReader): reader for a single big map!
79+ image_groups (int, optional): _description_. Defaults to 3.
80+ """
81+ outDirectory = os .path .join (self .outputDir , mapReader .mapName , str (self .window_size ))
82+ os .makedirs (outDirectory , exist_ok = True )
7083 img_group_number = 0
7184 for i in range (0 , mapReader .size [0 ] - self .window_size [0 ] + 1 , self .step_size ):
7285 for j in range (0 , mapReader .size [1 ] - self .window_size [1 ] + 1 , self .step_size ):
@@ -76,7 +89,8 @@ def generate_patches(self, mapReader, image_groups=3):
7689 for x in range (self .window_size [0 ])])
7790
7891 if len (self .samples )== self .sample_group_size :
79- with open ("data/output/SF_group" + str (img_group_number )+ ".dill" , 'wb+' ) as f :
92+ path = os .path .join (outDirectory , str (img_group_number ) + ".dill" )
93+ with open (path , 'wb+' ) as f :
8094 dill .dump (self .samples , f )
8195 f .close ()
8296 print ("Image group {} saved in data/output/ !" .format (img_group_number ))
0 commit comments