Skip to content

Commit 2e2dd5e

Browse files
authored
Refactor patch utility functions and add doc strings (#136)
1 parent b0a4193 commit 2e2dd5e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

ngclearn/utils/patch_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,7 @@ def create_patches(self, add_frame=False, center=True):
121121

122122

123123

124-
125-
def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True, seed=1234, vis_mode=False): ## scikit
124+
def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True, seed=1234, vis_mode=False, step_size=(1, 1)): ## scikit
126125
"""
127126
Generates a set of patches from an array/list of image arrays (via
128127
random sampling with replacement). This uses scikit-learn's patch creation
@@ -149,7 +148,7 @@ def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True,
149148
for s in range(_x_batch.shape[0]):
150149
xs = _x_batch[s, :]
151150
xs = xs.reshape(px, py)
152-
patches = extract_patches_2d(xs, patch_size, max_patches=max_patches, random_state=seed)#, random_state=69)
151+
patches = extract_patches_2d(xs, patch_size, max_patches=max_patches, random_state=seed, extraction_step=step_size)#, random_state=69)
153152
patches = np.reshape(patches, (len(patches), -1)) # flatten each patch in set
154153
if s > 0:
155154
p_batch = np.concatenate((p_batch,patches),axis=0)
@@ -201,3 +200,4 @@ def generate_pacthify_patch_set(x_batch_, patch_size=(5, 5), center=True): ## pa
201200
mu = np.mean(patchBatch, axis=1,keepdims=True)
202201
patchBatch = patchBatch - mu
203202
return patchBatch
203+

0 commit comments

Comments
 (0)