Skip to content

Commit 2434937

Browse files
authored
Merge pull request #1 from KordingLab/feature/qc-single-patch-orientation
Thank you!
2 parents 73a9cba + 99dd6e6 commit 2434937

File tree

4 files changed

+825
-0
lines changed

4 files changed

+825
-0
lines changed

single_patch_orientation/README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
# Single patch orientation using pre-trained AlexNet

Test a pre-trained AlexNet with various orientation stimuli used in psychophysics to see whether the model shows behaviors similar to those of human observers.

## Usage

The main Python file for this project is `alexnet_to_orientation.py`.
An example usage of this script is:

```
$ python alexnet_to_orientation.py --epochs 10 \
    --save-model --model-name 'alexNet_broadband_multiorinoise_naturaloriprior' \
    --if-more-noise-levels
```

For all options and command-line arguments, please use:

```
$ python alexnet_to_orientation.py -h
```
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import numpy as np
2+
3+
import matplotlib.pyplot as plt
4+
import argparse
5+
6+
import torch
7+
import torch.optim as optim
8+
9+
import orientation_stim
10+
import toy_model_train_orientation
11+
12+
13+
"""
14+
example of use:
15+
python alexnet_to_orientation.py --epochs 10 \
16+
--save-model --model-name 'alexNet_broadband_multiorinoise_naturaloriprior' \
17+
--if-more-noise-levels
18+
"""
19+
20+
def main():
    """Fine-tune a slimmed pre-trained AlexNet on orientation-estimation stimuli.

    Parses command-line options, trains for ``--epochs`` epochs, periodically
    checkpoints the model and saves bias figures at selected epochs, then
    writes a final summary figure (feature cosine similarity + bias) and a
    test-loss-history figure.

    NOTE(review): indentation below is reconstructed from a flattened diff;
    the nesting of the figure-saving sections should be confirmed against the
    original file.
    """
    parser = argparse.ArgumentParser(description='Orientation with pre-trained alexnet')
    parser.add_argument('--input-img-size', type=int, default=224, metavar='N',
                        help='input image size (default: 224)')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epoch-size', type=int, default=100, metavar='N',
                        help='epoch size (# of batches) for generating images online.')
    parser.add_argument('--test-epoch-size', type=int, default=20, metavar='N',
                        help='epoch size (# of batches) for generating images online.')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--model-name', type=str, default='orient_cnn',
                        help='Name of the current Model for saving')
    parser.add_argument('--if-unif', action='store_true', default=False,  # use natural prior by default
                        help='if the training distribution uniform')
    parser.add_argument('--if-more-noise-levels', action='store_true', default=False,
                        help='if multiple orientation noise levels used in training')

    args = parser.parse_args()
    # figures are saved under the same stem as the model checkpoints
    args.vis_name = args.model_name

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    model = toy_model_train_orientation.SlimAlexNet(max_pool_layer_index=1,
                                                    last_layer_num_params=2).to(device)

    # the parameters would be finetuned - should only be conv_pool.5
    params_to_update = model.parameters()
    print("Params to learn:")
    for name, param in model.named_parameters():
        if param.requires_grad == True:
            print("\t", name)

    optimizer_ft = optim.Adam(params_to_update, lr=1e-4)

    test_loss_history = []
    # epochs at which intermediate checkpoints/figures are written
    report_epoch = [1, 2, 4, 8, 16]
    for epoch in range(1, args.epochs + 1):
        toy_model_train_orientation.train(args, model, device,
                                          optimizer_ft, epoch, if_alexNet=True)
        test_loss = toy_model_train_orientation.test(args, model, device,
                                                     if_alexNet=True)
        test_loss_history.append(test_loss)
        # save intermediate result figures
        if epoch in report_epoch:
            if (args.save_model):
                torch.save(model.state_dict(),
                           args.model_name + '_epoch' + str(epoch) + '.pt')
            # save a figure with bias
            # NOTE(review): compare_bias reloads the checkpoint from disk, so
            # this presumably requires --save-model — confirm intended guard.
            model_file = args.model_name + '_epoch' + str(epoch) + '.pt'
            ave_ori, all_ave_bias = toy_model_train_orientation.compare_bias(model_file,
                                                                             img_size=224, if_alexNet=True)
            plt.figure(figsize=(5, 3))
            # bias of columns 1 and 2 relative to the baseline column 0
            plt.plot(ave_ori, all_ave_bias[:, 1] - all_ave_bias[:, 0])
            plt.plot(ave_ori, all_ave_bias[:, 2] - all_ave_bias[:, 0])
            # dashed reference lines at pi/4, pi/2, 3*pi/4 and at zero bias
            plt.plot([np.pi / 4, np.pi / 4], [-5 / 180 * np.pi, 5 / 180 * np.pi], 'k--')
            plt.plot([np.pi / 2, np.pi / 2], [-5 / 180 * np.pi, 5 / 180 * np.pi], 'k--')
            plt.plot([3 * np.pi / 4, 3 * np.pi / 4], [-5 / 180 * np.pi, 5 / 180 * np.pi], 'k--')
            plt.plot([0, np.pi], [0, 0], 'k--')
            plt.xlim([0, np.pi])
            plt.ylim(-5 / 180 * np.pi, 5 / 180 * np.pi)
            plt.savefig(args.vis_name + '_epoch' + str(epoch) + '.pdf')

    # save final report if not has been saved
    if args.epochs not in report_epoch:
        if (args.save_model):
            torch.save(model.state_dict(),
                       args.model_name + '_epoch' + str(epoch) + '.pt')

    # save a final figure with cosine similarity and bias
    feature_cosSim = toy_model_train_orientation.feature_similarity(args.input_img_size,
                                                                    model, feature_layer_ind=3, if_alexNet=True)
    plt.figure(figsize=(8, 3))
    plt.subplot(1, 2, 1)
    plt.plot(np.arange(179), feature_cosSim)
    plt.xlim([0, 180])
    # bias differences
    # NOTE(review): `epoch` here is the leftover loop variable (== args.epochs)
    model_file = args.model_name + '_epoch' + str(epoch) + '.pt'
    ave_ori, all_ave_bias = toy_model_train_orientation.compare_bias(model_file,
                                                                     img_size=224, if_alexNet=True)
    plt.subplot(1, 2, 2)
    plt.plot(ave_ori, all_ave_bias[:, 1] - all_ave_bias[:, 0])
    plt.plot(ave_ori, all_ave_bias[:, 2] - all_ave_bias[:, 0])
    plt.plot([0, np.pi], [0, 0], 'k--')
    plt.plot([np.pi / 4, np.pi / 4], [-5 / 180 * np.pi, 5 / 180 * np.pi], 'k--')
    plt.plot([np.pi / 2, np.pi / 2], [-5 / 180 * np.pi, 5 / 180 * np.pi], 'k--')
    plt.plot([3 * np.pi / 4, 3 * np.pi / 4], [-5 / 180 * np.pi, 5 / 180 * np.pi], 'k--')
    plt.xlim([0, np.pi])
    plt.ylim(-5 / 180 * np.pi, 5 / 180 * np.pi)
    plt.savefig(args.vis_name + '_final_summary' + '.pdf')

    # figure showing the test loss progress
    plt.figure()
    plt.plot(np.arange(1, args.epochs + 1), test_loss_history)
    plt.savefig(args.vis_name + '_loss_history' + '.pdf')
128+
129+
130+
# Standard script entry point: run the full training/evaluation pipeline.
if __name__ == '__main__':
    main()
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import numpy as np
2+
import scipy.ndimage
3+
import matplotlib
4+
5+
# matplotlib.use("TkAgg")
6+
import matplotlib.pyplot as plt
7+
8+
# torch package
9+
import torch
10+
from torch.distributions import normal
11+
12+
13+
def grating(size=500, pixelsPerDegree=200, spatial_freq=3, spatial_phase=0,
            orientation=np.pi/4, contrast=1):
    '''
    Square sinusoidal grating patch; output range is -1 to 1 when contrast is 1.

    size: number of pixels along each side of the (square) patch
    pixelsPerDegree: number of patch pixels per degree of visual angle
    spatial_freq: cycles per degree of visual angle
    spatial_phase: in radians
    orientation: in radians
    contrast: scales the output amplitude
    '''
    # pixel coordinates centered on the patch, in degrees of visual angle
    coords = (np.arange(size) - size / 2.0) / pixelsPerDegree
    xv, yv = np.meshgrid(coords, coords)

    # phase of the carrier at each pixel along the grating's orientation
    phase = spatial_phase + 2 * np.pi * spatial_freq * (
        xv * np.sin(orientation) + yv * np.cos(orientation))
    return contrast * np.cos(phase)
31+
32+
33+
def gabor(size=500, pixelsPerDegree=100, spatial_freq=3, spatial_phase=0,
          orientation=np.pi/4, contrast=1, sigma=.5, spatial_aspect_ratio=1):
    '''
    Gabor patch: a sinusoidal grating under a Gaussian envelope.

    size: number of pixels along each side of the (square) patch
    pixelsPerDegree: number of patch pixels per degree of visual angle
    spatial_freq: cycles per degree of visual angle
    spatial_phase: in radians
    orientation: in radians
    contrast: if 1, the output values span -1 to 1
    sigma: standard deviation of the Gaussian envelope (in degrees of visual angle)
    spatial_aspect_ratio (gamma): if not 1, the envelope's support is elliptical,
        distorted along the orientation axis.
    '''
    # pixel coordinates centered on the patch, in degrees of visual angle
    coords = (np.arange(size) - size / 2.0) / pixelsPerDegree
    xv, yv = np.meshgrid(coords, coords)

    # rotate coordinates into the grating's frame
    rot_x = xv * np.cos(orientation) + yv * np.sin(orientation)
    rot_y = yv * np.cos(orientation) - xv * np.sin(orientation)

    # elliptical Gaussian envelope (minor axis shrunk by the aspect ratio)
    sigma_minor = float(sigma) / spatial_aspect_ratio
    envelope = np.exp(-.5 * (rot_x ** 2 / sigma ** 2 +
                             rot_y ** 2 / sigma_minor ** 2))

    # sinusoidal carrier along the rotated x axis
    carrier = np.cos(2 * np.pi * spatial_freq * rot_x + spatial_phase)
    patch = envelope * carrier

    # rescale so that contrast == 1 spans exactly [-1, 1]
    lo = np.min(patch)
    hi = np.max(patch)
    return (patch - lo) * contrast * 2 / (hi - lo) - 1
72+
73+
74+
def matlab_style_gauss2D(shape=(3, 3), sigma=0.5):
    """
    2D Gaussian kernel giving the same result as MATLAB's
    fspecial('gaussian', [shape], [sigma]); normalized to sum to 1.
    """
    half_rows, half_cols = ((dim - 1.) / 2. for dim in shape)
    yy, xx = np.ogrid[-half_rows:half_rows + 1, -half_cols:half_cols + 1]
    kernel = np.exp(-(xx * xx + yy * yy) / (2. * sigma * sigma))
    # zero out entries too small to matter, as fspecial does
    kernel[kernel < np.finfo(kernel.dtype).eps * kernel.max()] = 0
    total = kernel.sum()
    if total != 0:
        kernel /= total
    return kernel
87+
88+
89+
def circular_mask(size=500, pixelsPerDegree=200, radius=1, polarity_in=1, polarity_out=0,
                  if_filtered=False, filter_size = (15, 15), filter_width = 2):
    '''
    :param size: size of the image patch in pixels (square)
    :param radius: in degrees of visual angle
    :param polarity_in: value inside the circle (polarity_out outside)
    :param filter_size and filter_width are in pixel units
    :return: the mask, optionally smoothed with a Gaussian lowpass filter
    '''
    # pixel coordinates centered on the patch, in degrees of visual angle
    coords = (np.arange(size) - size / 2.0) / pixelsPerDegree
    xv, yv = np.meshgrid(coords, coords)
    dist = np.sqrt(np.power(xv, 2) + np.power(yv, 2))

    # polarity_in inside the circle, polarity_out elsewhere
    mask = np.where(dist < radius, polarity_in, polarity_out).astype(float)

    # optionally soften the edge with a Gaussian lowpass filter
    if if_filtered:
        lowpass = matlab_style_gauss2D(filter_size, filter_width)
        mask = scipy.ndimage.convolve(mask, lowpass, mode='nearest')

    return mask
111+
112+
# --- example grating ---
113+
# img = grating(500, 200, 3, np.pi/4, 1, np.pi/5)
114+
# mask = circular_mask(500, 200, 1,
115+
# if_filtered=True, filter_size=(50, 50), filter_width=10)
116+
# img = np.multiply(img, mask)
117+
# plt.imshow(img, cmap=plt.gray()), plt.show()
118+
119+
120+
def broadband_noise(size=64, contrast=1, if_low_pass=True, center_sf=0, sf_sigma=10,
                    if_band_pass=False, low_sf=1.67, high_sf=10.67,
                    orientation=20/180*np.pi, orient_sigma=10/180*np.pi):
    """
    Broadband noise with either low-pass (if_low_pass) or band-pass
    (if_band_pass) spatial frequency filtering; exactly one of the two flags
    should be True.

    Args:
        size: patch size in pixels; if not a power of 2, the patch is
            generated at the next power of 2 and cropped to `size`
        contrast: if 1, the output values span -1 to 1
        center_sf: center of the Gaussian spatial-frequency band; if 0, low pass
        sf_sigma: SF bandwidth in pixels; if Inf, all spatial frequencies are
            included (but the corners beyond max_sf are excluded)
        low_sf: lower band edge, cycle/image (band-pass mode)
        high_sf: upper band edge, cycle/image (band-pass mode)
        orientation: in radians; the available range of orientation is 0 - pi
        orient_sigma: orientation bandwidth in radians; if Inf, all
            orientations are included

    Returns:
        (size, size) float array of filtered noise, rescaled for `contrast`.

    Raises:
        ValueError: if neither `if_low_pass` nor `if_band_pass` is True.
    """
    # work at the next power of two for the FFT, then crop to the intended size
    size_tmp = np.power(2, np.ceil(np.log2(size))).astype(int)
    input_img = np.random.uniform(0, 1, [size_tmp, size_tmp])
    max_sf = size_tmp / 2
    # (bug fix: removed an unused `img_center` computed via np.matlib.repmat,
    # which raised AttributeError because numpy.matlib was never imported)

    # Fourier transform; the filter below is applied to the full complex spectrum
    f = np.fft.fftshift(np.fft.fft2(input_img))

    # generic frequency-domain coordinates: r is radial spatial frequency,
    # theta is orientation
    x, y = np.meshgrid(np.arange(-size_tmp/2, size_tmp/2),
                       np.arange(-size_tmp/2, size_tmp/2))
    r = np.sqrt(np.power(x, 2) + np.power(y, 2))
    y[y == 0] = .01  # avoid division by zero in the arctan below
    theta = np.arctan(x / y)
    # shift the lower half-plane so theta covers a continuous range
    theta[f.shape[1] // 2:, :] = theta[f.shape[1] // 2:, :] - np.pi
    theta += 3 * np.pi / 2

    # build the spatial-frequency filter
    if if_low_pass:
        if np.isinf(sf_sigma):
            sf_band = np.zeros(r.shape)
            sf_band[r <= max_sf] = 1
        else:
            sf_band = np.exp(-((r - center_sf) ** 2 / 2 / sf_sigma ** 2))
    elif if_band_pass:
        sf_band = np.zeros(r.shape)
        sf_band[(r >= low_sf) & (r <= high_sf)] = 1
    else:
        # previously fell through with `sf_band` undefined (NameError)
        raise ValueError('one of if_low_pass or if_band_pass must be True')

    # orientation filter: two Gaussian lobes so the 0-pi range wraps correctly
    if np.isinf(orient_sigma):
        pass_band = sf_band
    else:
        orient_band = np.exp(-((theta - (orientation + np.pi / 2)) ** 2 / 2 / orient_sigma ** 2)) + \
                      np.exp(-((np.angle(np.exp(1j * (theta))) + np.pi - (orientation + np.pi / 2)) ** 2
                               / 2 / orient_sigma ** 2))
        pass_band = np.multiply(orient_band, sf_band)

    # only pass the needed components, and reconstruct the image back
    band = np.fft.fftshift(np.fft.ifft2(np.fft.fftshift(np.multiply(f, pass_band))))
    band = np.real(band)

    # normalize for contrast
    band_min = np.min(band[:])
    band_max = np.max(band[:])
    band = (band - band_min) * contrast * 2 / (band_max - band_min) - 1

    # crop if needed
    if size_tmp > size:
        band = band[0:size, 0:size]

    return band
196+

0 commit comments

Comments
 (0)