@@ -121,8 +121,7 @@ def create_patches(self, add_frame=False, center=True):
121121
122122
123123
124-
125- def generate_patch_set (x_batch , patch_size = (8 , 8 ), max_patches = 50 , center = True , seed = 1234 , vis_mode = False ): ## scikit
124+ def generate_patch_set (x_batch , patch_size = (8 , 8 ), max_patches = 50 , center = True , seed = 1234 , vis_mode = False , step_size = (1 , 1 )): ## scikit
126125 """
127126 Generates a set of patches from an array/list of image arrays (via
128127 random sampling with replacement). This uses scikit-learn's patch creation
@@ -149,7 +148,7 @@ def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True,
149148 for s in range (_x_batch .shape [0 ]):
150149 xs = _x_batch [s , :]
151150 xs = xs .reshape (px , py )
152- patches = extract_patches_2d (xs , patch_size , max_patches = max_patches , random_state = seed )#, random_state=69)
151+ patches = extract_patches_2d (xs , patch_size , max_patches = max_patches , random_state = seed , extraction_step = step_size )#, random_state=69)
153152 patches = np .reshape (patches , (len (patches ), - 1 )) # flatten each patch in set
154153 if s > 0 :
155154 p_batch = np .concatenate ((p_batch ,patches ),axis = 0 )
@@ -201,3 +200,4 @@ def generate_pacthify_patch_set(x_batch_, patch_size=(5, 5), center=True): ## pa
201200 mu = np .mean (patchBatch , axis = 1 ,keepdims = True )
202201 patchBatch = patchBatch - mu
203202 return patchBatch
203+
0 commit comments