Skip to content

Commit 68e548d

Browse files
committed
A refactor to prepare for a merge with the stocker lab
1 parent c931879 commit 68e548d

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

43 files changed

+926
-693
lines changed

.idea/Illusions.iml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/workspace.xml

Lines changed: 435 additions & 278 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

from_lines/__init__.py renamed to Local_orientation_of_cross_lines/__init__.py

File renamed without changes.

from_lines/angle_decoder/__init__.py renamed to Local_orientation_of_cross_lines/angle_decoder/__init__.py

File renamed without changes.

from_lines/angle_decoder/angle_decoder_linear.py renamed to Local_orientation_of_cross_lines/angle_decoder/angle_decoder_linear.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,27 +27,32 @@ def forward(self, x):
2727
class AngleDecoder(torch.nn.Module):
2828
"""This class takes the inputs of the pretrained VGG function and gives out the angle
2929
"""
30-
def __init__(self, layer, noise = 0, nonlinear = False):
31-
super(OrientationDecoder, self).__init__()
30+
def __init__(self, layer, noise = 0, nonlinear = False,n_filts=6):
31+
super(AngleDecoder, self).__init__()
3232
self.layer = layer
3333
self.noise = noise
34+
self.n_filts = n_filts
3435
maxpool_indices = [ 4, 9, 16, 23, 30]
3536
assert layer in maxpool_indices
3637

3738
# load the pretrained network
3839
self.vgg_chopped = VGG_chopped(layer)
3940

40-
n_feats = {4: 64 * 112 * 112,
41+
# n_feats = {4: 64 * 112 * 112,
42+
# 9: 128 * 56 * 56,
43+
# 16: 256 * 28 * 28,
44+
# 23: 512 * 14 * 14,
45+
# 30: 512 * 7 * 7}
46+
47+
n_feats = {4: 64 * (n_filts*2)**2,
4148
9: 128 * 56 * 56,
4249
16: 256 * 28 * 28,
4350
23: 512 * 14 * 14,
4451
30: 512 * 7 * 7}
4552

4653
if nonlinear:
4754
self.decoder = torch.nn.Sequential(
48-
torch.nn.Linear(n_feats[layer], 50),
49-
torch.nn.Dropout(.5),
50-
torch.nn.Linear(50, 1)
55+
torch.nn.Linear(n_feats[layer], 500),
5156
)
5257
else:
5358
self.decoder = torch.nn.Sequential(
@@ -60,11 +65,16 @@ def __init__(self, layer, noise = 0, nonlinear = False):
6065

6166
def forward(self, x):
6267
x = self.vgg_chopped(x)
68+
# take just the center bit. (HARDCODED LAYER=4)
69+
x = x[:,:,(56-self.n_filts):(56+self.n_filts),(56-self.n_filts):(56+self.n_filts)]
6370
# flatten
64-
x = x.view(-1,self.n_feat)
71+
x = x.contiguous().view(-1,self.n_feat)
72+
73+
74+
6575
# add noise
66-
x += self.noise * torch.randn(*x.size())
76+
x += self.noise * torch.randn_like(x)
6777
# get angle
68-
x = self.deconv(x)
78+
x = self.decoder(x)
6979

7080
return x
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import pandas as pd
2+
import torch
3+
4+
def loader_generator(h5_path, batch_size, as_tensor=True, n_samples=10000):
    """Yield batches of rows read from an HDF5 file, cycling forever.

    Reads ``batch_size`` rows at a time from the dataframe stored at
    ``h5_path`` and wraps back to the start once ``n_samples`` rows (or
    the end of the stored data) have been consumed.

    Args:
        h5_path: path to an HDF5 file readable by ``pandas.read_hdf``.
        batch_size: number of rows per yielded batch.
        as_tensor: if True, yield ``torch.Tensor`` batches; otherwise
            numpy arrays.
        n_samples: total number of rows to cycle over.

    Yields:
        When rows have more than one column, a batch reshaped to
        ``(batch_size, -1, 224, 224)`` — assumes each row is a flattened
        stack of 224x224 planes (TODO confirm against the writer side).
        Otherwise the raw values.
    """
    stop = n_samples
    start = 0
    curr_index = start
    while True:
        dataframe = pd.read_hdf(h5_path, start=curr_index,
                                stop=min(curr_index + batch_size, stop))
        curr_index += batch_size

        # Not enough rows left to form a full batch: wrap to the start.
        # BUG FIX: the original reset whenever curr_index reached `stop`,
        # which silently discarded the final *full* batch of every epoch;
        # here a full batch is always yielded before wrapping (below).
        if dataframe.shape[0] < batch_size:
            curr_index = start
            continue

        if as_tensor:
            if dataframe.shape[1] > 1:
                out = torch.Tensor(dataframe.values).view(batch_size, -1, 224, 224)
            else:
                out = torch.Tensor(dataframe.values)
        else:
            if dataframe.shape[1] > 1:
                out = dataframe.values.reshape((batch_size, -1, 224, 224))
            else:
                out = dataframe.values

        yield out

        # End of the sampled range: start the next epoch from the top.
        if curr_index >= stop:
            curr_index = start
def data_iterator(h5_path, batch_size, as_tensor=True):
    """Return an iterator over batches produced by ``loader_generator``."""
    batches = loader_generator(h5_path, batch_size, as_tensor)
    return iter(batches)

from_lines/angle_decoder/gen_and_save_lines_angle_decoder.py renamed to Local_orientation_of_cross_lines/angle_decoder/gen_and_save_lines_angle_decoder.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,41 @@
88

99
import argparse
1010
import os
11+
from tqdm import tqdm as tqdm
1112

12-
13-
def gen_many_lines_on_white(noise=True, mask=None):
13+
def gen_many_lines_on_white(noise=True, mask=None,N=1000, test = False):
1414
""" Two intersecting lines of random central angle and relative angle and position.
1515
16+
if test = True, then we sample from angles with final digits in the range 0-1
17+
else the rest.
18+
19+
1620
"""
21+
22+
23+
24+
1725
all_inputs = list()
1826
angles = list()
19-
n_samples = 10000
20-
for n in range(n_samples):
21-
# usually near the center
22-
centerloc = np.random.randint(80, 224 - 80, 2)
23-
fixed_angle = np.random.uniform(0, np.pi)
24-
relative_angle = np.random.uniform(0, np.pi / 2)
27+
n_samples = N
28+
29+
30+
if test:
31+
def gen_angle():
32+
tens = np.random.randint(0,18)*10
33+
rest = np.random.uniform(0,1)
34+
return (tens+rest)/180*np.pi-np.pi
35+
else:
36+
def gen_angle():
37+
tens = np.random.randint(0,18)*10
38+
rest = np.random.uniform(1,10)
39+
return (tens+rest)/180*np.pi-np.pi
40+
41+
for n in tqdm(range(n_samples)):
42+
#at center
43+
centerloc = (112,112)#np.random.randint(80, 224 - 80, 2)
44+
fixed_angle = np.pi/2
45+
relative_angle = gen_angle()
2546

2647
mask = np.zeros((224, 224)).astype(np.bool) if mask is None else mask
2748

@@ -47,11 +68,14 @@ def gen_many_lines_on_white(noise=True, mask=None):
4768
if __name__ == '__main__':
4869
parser = argparse.ArgumentParser()
4970

50-
parser.add_argument("--noise", help="The variance of gaussian noise added to the raw images",
71+
parser.add_argument("--noise", help="Whether to add gaussian noise added to the raw images",
72+
action = 'store_true')
73+
parser.add_argument("--n_images", help="How many to build",
74+
type=int, default = 10000)
75+
parser.add_argument("--test", help="Test set or train set?",
5176
action = 'store_true')
52-
5377
parser.add_argument("--image_directory", type = str,
54-
default='/home/abenjamin/DNN_illusions/fast_data/features/straight_lines/',
78+
default='/home/abenjamin/DNN_illusions/fast_data/features/rel_angle/',
5579
help="""Path to the folder in which we store the `lines.h5` and `lines_targets.h5` files.
5680
If lines_targets.h5 does not exist, we just plot the input and model output.""")
5781

@@ -63,8 +87,7 @@ def gen_many_lines_on_white(noise=True, mask=None):
6387
if (i - 112) ** 2 + (j - 112) ** 2 >= 100 ** 2:
6488
unit_circle[i, j] = True
6589

66-
67-
all_inputs, all_targets = gen_many_lines_on_white(args.noise, unit_circle)
90+
all_inputs, all_targets = gen_many_lines_on_white(args.noise, unit_circle, N=args.n_images, test = args.test)
6891

6992
all_inputs = pd.DataFrame(np.stack(all_inputs))
7093
all_inputs.to_hdf(args.image_directory+'lines.h5', key="l", mode='w')
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import torch
2+
from torch.autograd import Variable
3+
4+
import argparse
5+
import os
6+
from data_loader_utils import data_iterator
7+
from angle_decoder_linear import AngleDecoder
8+
9+
#visualize
10+
from matplotlib.colors import ListedColormap
11+
import matplotlib as mpl
12+
import matplotlib
13+
from colorspacious import cspace_convert
14+
import numpy as np
15+
from matplotlib import pyplot as plt
16+
17+
def gen_test_lines(mask=None):
    """Generate 180 test images of two intersecting lines at the image center.

    The central line angle is fixed at 0 while the relative angle sweeps
    from 0 up to just under 90 degrees in half-degree steps.

    Args:
        mask: optional (224, 224) boolean mask of pixels to blank out;
            defaults to no masking.

    Returns:
        ``(all_inputs, angles)``: a list of flattened image tensors and
        the list of relative angles (radians) used for each image.
    """
    all_inputs = list()
    angles = list()
    n_samples = 180
    for n in range(n_samples):
        centerloc = (112, 112)
        fixed_angle = 0
        relative_angle = n * np.pi / 360  # up to 90 deg

        # BUG FIX: np.bool was removed in NumPy 1.24; builtin bool is the
        # documented replacement and produces the same boolean array.
        mask = np.zeros((224, 224)).astype(bool) if mask is None else mask

        # NOTE(review): generate_intersecting_rgb and numpy_to_torch are not
        # imported anywhere in this file — presumably they come from a
        # sibling line-generation module; confirm the import list.
        numpy_im = generate_intersecting_rgb(centerloc, fixed_angle, relative_angle,
                                             negative_mask=mask,
                                             linewidth=1)

        all_inputs.append(numpy_to_torch(numpy_im).view(-1))
        angles.append(relative_angle)

    return all_inputs, angles
42+
def pass_test_images(model, samples, gpu=True, batch_size=10):
    """Run a stack of test images through ``model`` and return numpy outputs.

    Args:
        model: the trained decoder network (callable on a batched tensor).
        samples: list of image tensors; stacked along a new batch dim.
        gpu: if True, move the batch to CUDA before the forward pass.
        batch_size: unused; kept for backward compatibility.

    Returns:
        numpy array of the model outputs (moved back to the CPU if needed).
    """
    batch = torch.stack(samples)

    # BUG FIX: the original moved the batch to the GPU into an unused
    # variable (`feats = samples.cuda()`) and then ran the model on the
    # CPU copy, which fails for a CUDA-resident model. Run the model on
    # the tensor that was actually moved.
    if gpu:
        batch = batch.cuda()
    data = Variable(batch)

    output = model(data).detach()
    if gpu:
        output = output.cpu()
    del data

    return output.numpy()
def load_model(args):
    """Build an AngleDecoder and restore its trained weights.

    Args:
        args: parsed CLI namespace providing ``layer``, ``nonlinear``,
            and ``model_path``.

    Returns:
        The restored model, switched to eval mode.
    """
    decoder = AngleDecoder(args.layer, 0, args.nonlinear)
    state = torch.load(args.model_path)
    decoder.load_state_dict(state)
    decoder.eval()
    return decoder
def save_and_visualize(outputs, targets):
    """Plot decoded relative angles [range(180)] against the true angles.

    NOTE(review): this is currently a stub — the body below is
    commented-out reference code carried over from an orientation-image
    decoder and does nothing yet; the function returns None. The
    ``outputs``/``targets`` arguments are presumably the arrays produced
    by ``pass_test_images`` and ``gen_test_lines`` — confirm before
    implementing.
    """
    #
    #
    # for i,(input,output,target) in enumerate(images):
    #
    #     check_sizes(input, output)
    #
    #     plt.figure(figsize=(15, 5))
    #     plt.subplot(131)
    #     plt.imshow(np.moveaxis(input,0,2))
    #     ax = plt.gca()
    #     ax.set_axis_off()
    #     ax.set_title("Input")
    #
    #     plt.subplot(132)
    #     ax2 = show_orientation_image(output)
    #     ax2.set_title("Decoded orientation")
    #
    #     try:#if target is not None:
    #         plt.subplot(133)
    #         ax3 = show_orientation_image(target)
    #         ax3.set_title("Target orientation")
    #     except:
    #         print("Recomputing orientation image with kernel size {}".format(kernel_size))
    #         filts = get_quadratures(kernel_size)
    #
    #         # note that we invert to get the map
    #         target = get_orientation_map(1-np.mean(check_on_float_scale(input),axis=0), filts)
    #         plt.subplot(133)
    #         ax3 = show_orientation_image(target)
    #         ax3.set_title("Target orientation")
    #
    #
    #     plt.tight_layout()
    #     if save:
    #         plt.savefig("Decoded_test_image_{}.png".format(i))
    #         #save just decoded png
    #         print(type(output), output.shape)
    #         output_img = convert_to_orientation_image(output)
    #         target_img = convert_to_orientation_image(target)
    #         # matplotlib.image.imsave('only_original_img_{}'.format(i), input)
    #         matplotlib.image.imsave('only_decoded_img_{}.png'.format(i), output_img)
    #         matplotlib.image.imsave('only_target_img_{}.png'.format(i), target_img)
    #     plt.show()
    #
    #
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # NOTE(review): `default` on a required positional is inert — argparse
    # always demands model_path on the command line, so this default path
    # is never used.
    parser.add_argument("model_path", help="relative path to the saved model",
                        type=str, default='/home/abenjamin/DNN_illusions/data/models/Angle_decoder_4.pt')
    parser.add_argument("layer", help="which layer of VGG the model was trained to decode from",
                        type=int)

    parser.add_argument('--no-cuda', action='store_true',
                        help='Disable CUDA')
    parser.add_argument("--card", help="which card to use",
                        type=int, default=0)
    parser.add_argument('--nonlinear', action='store_true',
                        help='Use the decoder with nonlinear 2 layer network')
    args = parser.parse_args()
    args.gpu = not args.no_cuda

    # Restrict CUDA to the requested card before any GPU work happens.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.card)

    # The model is loaded on the CPU and moved to the GPU if requested.
    model = load_model(args)
    if args.gpu:
        model = model.cuda()

    # Mask everything outside a centered circle of radius 100.
    # BUG FIX: np.bool was removed in NumPy 1.24; build the boolean mask
    # with a vectorized comparison instead of the O(224*224) Python loop —
    # same values, no deprecated alias.
    yy, xx = np.ogrid[:224, :224]
    unit_circle = (yy - 112) ** 2 + (xx - 112) ** 2 >= 100 ** 2

    samples, targets = gen_test_lines(unit_circle)

    outputs = pass_test_images(model, samples, args.gpu)

    save_and_visualize(outputs, targets)

0 commit comments

Comments
 (0)