66import torch
77import numpy as np
88import random
9+ import dill
10+
911
1012# Credit - Bahar
1113class MapReader :
@@ -36,11 +38,12 @@ def get_char(self, layer):
3638# Modifications - Ishaan
3739
3840class MapsDataset (Dataset ):
39- def __init__ (self , window_size , step_size , converter ):
41+ def __init__ (self , window_size , step_size , sample_group_size , converter ):
4042 self .char_size = converter .char_size
4143 self .converter = converter
4244 self .window_size = window_size
4345 self .step_size = step_size
46+ self .sample_group_size = sample_group_size
4447 self .samples = []
4548 self .block_size = self .window_size [0 ] * self .window_size [1 ] - 1
4649
@@ -62,6 +65,25 @@ def add(self, mapReader):
6265 for y in range (self .window_size [1 ])]
6366 for x in range (self .window_size [0 ])])
6467
68+ #Generate image patches and write to data/output directory
69+ def generate_patches (self , mapReader , image_groups = 3 ):
70+ img_group_number = 0
71+ for i in range (0 , mapReader .size [0 ] - self .window_size [0 ] + 1 , self .step_size ):
72+ for j in range (0 , mapReader .size [1 ] - self .window_size [1 ] + 1 , self .step_size ):
73+ self .samples .append ([[
74+ (self .converter .get_char (mapReader .data [i + x ][j + y ]) / (len (self .converter .char_groups ) - 1 )) * - 2 + 1
75+ for y in range (self .window_size [1 ])]
76+ for x in range (self .window_size [0 ])])
77+
78+ if len (self .samples )== self .sample_group_size :
79+ with open ("data/output/SF_group" + str (img_group_number )+ ".dill" , 'wb+' ) as f :
80+ dill .dump (self .samples , f )
81+ f .close ()
82+ print ("Image group {} saved in data/output/ !" .format (img_group_number ))
83+ self .samples .clear ()
84+ img_group_number += 1
85+
86+
6587 def shuffle (self ):
6688 random .shuffle (self .samples )
6789
0 commit comments