Skip to content

Commit c931879

Browse files
committed
Initial commit of the evolving_fisher component. Why do we see the Fisher we do?
1 parent 1a36334 commit c931879

File tree

5 files changed

+22917
-0
lines changed

5 files changed

+22917
-0
lines changed

evolving_fisher/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .fisher_calculators import *

evolving_fisher/deep-linear-simulation.ipynb

Lines changed: 21488 additions & 0 deletions
Large diffs are not rendered by default.

evolving_fisher/examine_alxenet_fishers.ipynb

Lines changed: 704 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
2+
import numpy as np
3+
4+
import matplotlib.pyplot as plt
5+
6+
import torch
7+
import torch.nn as nn
8+
from torchvision import models, transforms
9+
10+
import argparse
11+
import pickle
12+
13+
14+
############################ Generators from Linqi's code ##########################################
15+
16+
def gen_sinusoid(sz, A, omega, rho):
    """Generate a square sinusoidal grating as a 2-D torch tensor.

    :param sz: side length of the stimulus in pixels (rounded down to even)
    :param A: grating amplitude
    :param omega: 2-element direction vector [wx, wy] of the spatial frequency
    :param rho: phase offset in radians
    :return: float tensor of shape (2*(sz//2), 2*(sz//2))
    """
    half = int(sz / 2.0)
    coords = torch.tensor(range(-half, half))
    grid_x, grid_y = torch.meshgrid([coords, coords])
    grid_x, grid_y = grid_x.float(), grid_y.float()
    # 0.35 fixes the base spatial frequency; omega sets its direction
    phase = 0.35 * omega[0] * grid_x + 0.35 * omega[1] * grid_y + rho
    return A * torch.cos(phase)
24+
25+
26+
def gen_sinusoid_aperture(ratio, sz, A, omega, rho, polarity):
    """Sinusoidal grating masked by a circular aperture.

    :param ratio: aperture radius as a fraction of the half-image size
    :param sz: side length of the stimulus in pixels
    :param A: grating amplitude
    :param omega: 2-element spatial-frequency direction vector
    :param rho: phase offset in radians
    :param polarity: 1 keeps the grating inside the circle, 0 keeps the surround
    :return: masked grating, same shape as gen_sinusoid(sz, ...)
    """
    grating = gen_sinusoid(sz, A, omega, rho)
    half = int(sz / 2.0)
    coords = torch.tensor(range(-half, half))
    grid_x, grid_y = torch.meshgrid([coords, coords])

    mask = torch.empty(grating.size(), dtype=torch.float)
    dist_sq = grid_x ** 2 + grid_y ** 2
    cutoff = (float(half) * ratio) ** 2
    mask[dist_sq >= cutoff] = 1 - polarity
    mask[dist_sq < cutoff] = polarity

    return grating * mask
38+
39+
40+
def center_surround(ratio, sz, theta_center, theta_surround, A, rho):
    """Composite stimulus: one grating orientation inside the aperture and a
    second orientation in the surround.

    :param ratio: aperture radius as a fraction of the half-image size
    :param sz: side length of the stimulus in pixels
    :param theta_center: orientation (torch scalar, radians) of the center grating
    :param theta_surround: orientation of the surround grating
    :param A: grating amplitude
    :param rho: phase offset
    :return: sum of the center and surround gratings
    """
    omega_center = [torch.cos(theta_center), torch.sin(theta_center)]
    omega_surround = [torch.cos(theta_surround), torch.sin(theta_surround)]
    inner = gen_sinusoid_aperture(ratio, sz, A, omega_center, rho, 1)
    outer = gen_sinusoid_aperture(ratio, sz, A, omega_surround, rho, 0)
    return inner + outer
44+
45+
46+
def sinsoid_noise(ratio, sz, A, omega, rho):
    """Sinusoidal grating inside a circular aperture with Gaussian noise outside.

    (Function name kept as-is -- "sinsoid" is a historical typo that callers
    such as rgb_sine_noise rely on.)

    :param ratio: aperture radius as a fraction of the half-image size
    :param sz: side length of the stimulus in pixels
    :param A: grating amplitude
    :param omega: 2-element spatial-frequency direction vector
    :param rho: phase offset in radians
    :return: tensor combining the masked grating and the surround noise
    """
    radius = int(sz / 2.0)
    sin_aperture = gen_sinusoid_aperture(ratio, sz, A, omega, rho, 1)

    # BUG FIX: the original referenced `normal.Normal`, but `normal` was never
    # imported (NameError at call time). Reach the distribution through the
    # already-imported torch package instead.
    nrm_dist = torch.distributions.Normal(0.0, 0.12)
    noise_patch = nrm_dist.sample(sin_aperture.size())

    [x, y] = torch.meshgrid([torch.tensor(range(-radius, radius)),
                             torch.tensor(range(-radius, radius))])
    aperture = torch.empty(sin_aperture.size(), dtype=torch.float)

    # The noise mask is the complement of the grating aperture:
    # 1 outside the circle (noise visible), 0 inside (grating visible).
    aperture_radius = float(radius) * ratio
    aperture[x ** 2 + y ** 2 >= aperture_radius ** 2] = 1
    aperture[x ** 2 + y ** 2 < aperture_radius ** 2] = 0

    return noise_patch * aperture + sin_aperture
62+
63+
64+
def rgb_sinusoid(theta):
    """Full-field sinusoidal grating at orientation `theta`, replicated across
    the three channels of a (1, 3, 224, 224) batch tensor.

    :param theta: orientation as a torch scalar (radians)
    :return: tensor of shape (1, 3, 224, 224)
    """
    grating = gen_sinusoid(224, A=1, omega=[torch.cos(theta), torch.sin(theta)], rho=0)
    batch = torch.zeros(1, 3, 224, 224)
    batch[0] = grating  # broadcasts the same pattern into all three channels
    return batch
71+
72+
73+
def rgb_sine_aperture(theta):
    """Aperture-masked grating at orientation `theta`, replicated across the
    three channels of a (1, 3, 224, 224) batch tensor.

    :param theta: orientation as a torch scalar (radians)
    :return: tensor of shape (1, 3, 224, 224)
    """
    grating = gen_sinusoid_aperture(0.85, 224, A=1,
                                    omega=[torch.cos(theta), torch.sin(theta)],
                                    rho=0, polarity=1)
    batch = torch.zeros(1, 3, 224, 224)
    batch[0] = grating  # broadcasts the same pattern into all three channels
    return batch
81+
82+
83+
def rgb_sine_noise(theta):
    """Grating-in-aperture with noisy surround at orientation `theta`,
    replicated across the three channels of a (1, 3, 224, 224) batch tensor.

    :param theta: orientation as a torch scalar (radians)
    :return: tensor of shape (1, 3, 224, 224)
    """
    stimulus = sinsoid_noise(0.75, 224, A=1,
                             omega=[torch.cos(theta), torch.sin(theta)], rho=0)
    batch = torch.zeros(1, 3, 224, 224)
    batch[0] = stimulus  # broadcasts the same pattern into all three channels
    return batch
90+
91+
92+
def rgb_center_surround(theta_center, theta_surround):
    """Center-surround stimulus replicated across the three channels of a
    (1, 3, 224, 224) batch tensor.

    :param theta_center: orientation (torch scalar, radians) of the center grating
    :param theta_surround: orientation of the surround grating
    :return: tensor of shape (1, 3, 224, 224)
    """
    stimulus = center_surround(0.75, 224, theta_center, theta_surround, A=1, rho=0)
    batch = torch.zeros(1, 3, 224, 224)
    batch[0] = stimulus  # broadcasts the same pattern into all three channels
    return batch
98+
99+
100+
def show_stimulus(I):
    """Display a stimulus tensor as a grayscale matplotlib image, no axes.

    :param I: 2-D torch tensor (detached and converted to numpy for display)
    """
    plt.figure()
    plt.axis('off')
    image = I.detach().numpy()
    plt.imshow(image, cmap=plt.gray())
    plt.show()
105+
106+
##################################################################################################
107+
108+
109+
def get_fisher_orientations(model, layer, n_angles=120, n_phases=1, delta=1e-2):
    """ Takes a full model (unchopped) along with a layer specification, and returns the fisher
    information with respect to orientation of that layer (averaged over phase of sine grating).

    Also allows choosing the finite-difference delta.

    :param model: torch model (already on the GPU) to probe
    :param layer: layer index, forwarded to get_response
    :param n_angles: number of orientations sampled uniformly on [0, pi]
    :param n_phases: number of grating phases to average over
    :param delta: finite-difference half-step for the orientation derivative
    :return: list of per-angle Fisher information values (torch scalars)
    """

    phases = np.linspace(0, np.pi, n_phases)
    angles = np.linspace(0, np.pi, n_angles)

    # Negative mask: True outside a circle of radius 50 px centered at (112, 112).
    # BUG FIX: `np.bool` was removed from NumPy (>=1.24); a plain boolean array
    # from a vectorized comparison replaces the nested Python loops.
    # NOTE(review): this mask is currently unused below -- kept for parity.
    ii, jj = np.ogrid[:224, :224]
    unit_circle = (ii - 112) ** 2 + (jj - 112) ** 2 >= 50 ** 2

    fishers_at_angle = []
    for angle in angles:
        # All phases go into one batch tensor so each sign needs one forward pass.
        all_phases_plus = torch.zeros(n_phases, 3, 224, 224).cuda()
        all_phases_minus = torch.zeros(n_phases, 3, 224, 224).cuda()

        # NOTE(review): `phase` is never passed to the stimulus generator, so all
        # rows are identical (rho is fixed at 0 inside rgb_sine_aperture) -- the
        # phase average is a no-op as written; confirm intent before changing.
        for i, phase in enumerate(phases):
            all_phases_plus[i] = rgb_sine_aperture(torch.tensor(angle + delta))
            all_phases_minus[i] = rgb_sine_aperture(torch.tensor(angle - delta))

        # responses at orientation theta +/- delta
        plus_resp = get_response(all_phases_plus, model, layer)

        size = plus_resp.size()

        minus_resp = get_response(all_phases_minus, model, layer)

        # central finite-difference derivative, then flatten per phase example
        df_dtheta = get_derivative(delta, plus_resp, minus_resp)
        df_dtheta = df_dtheta.view(n_phases, -1)

        # average the squared derivative norm over phases
        fisher = get_fisher(df_dtheta)
        fishers_at_angle.append(fisher)
    print("response size", size)
    return fishers_at_angle
161+
162+
def get_derivative(delta, plus_resp, minus_resp):
    """
    Central finite-difference derivative of the network activation with respect
    to the perturbed stimulus parameter.

    :param delta: half-width of the perturbation
    :param plus_resp: activations at parameter + delta
    :param minus_resp: activations at parameter - delta

    :return: (plus_resp - minus_resp) / (2 * delta), same shape as the inputs

    >>> f = lambda x: x**2
    >>> d = get_derivative(1e-3, f(2+1e-3), f(2-1e-3))
    >>> assert np.isclose(d,4)
    >>> d = get_derivative(1e-3, f(5+1e-3), f(5-1e-3))
    >>> assert np.isclose(d,10)
    """
    step = 2 * delta
    return (plus_resp - minus_resp) / step
180+
181+
182+
def get_fisher(df_dtheta):
    """
    Fisher information under the Gaussian-noise assumption: the squared norm of
    the derivative, averaged over the rows (examples) of the input.

    :param df_dtheta: 2-D tensor; each row is a flattened derivative df/dtheta
    :return: scalar torch tensor (mean squared norm over rows)
    """
    squared_norms = [torch.dot(row, row) for row in df_dtheta]
    return sum(squared_norms) / len(df_dtheta)
198+
199+
200+
def numpy_to_torch(rgb_image, cuda=True):
    """
    Prepares an image for passing through a pytorch network.
    :param rgb_image: numpy array, shape (x, y, 3), RGB channel order
    :param cuda: move the result onto the GPU when True
    :return: float torch tensor of shape (3, x, y) with channels flipped to BGR

    >>> numpy_to_torch(np.ones((224,224,3)), cuda=False).size()
    torch.Size([3, 224, 224])
    """
    # rgb to bgr (advanced indexing makes a contiguous copy)
    bgr = torch.from_numpy(rgb_image[:, :, [2, 1, 0]])
    chw = bgr.permute(2, 0, 1).float()  # HWC -> CHW
    return chw.cuda() if cuda else chw
217+
218+
219+
def get_response(torch_image, model, layer):
    """
    Gets the activations of `model` at child index `layer` for an image batch.

    Assumes the network is on the GPU already.

    :param torch_image: torch image tensor, 3-D (single image) or 4-D (batch);
        normalized with ImageNet statistics before the forward pass
    :param layer: which layer? 4,9,16,23,30
    :param model: the network to probe (VGG/AlexNet-style, or one whose first
        child is the indexable feature stack)
    :return: detached torch tensor of the activations captured at that layer
    """
    # ImageNet channel statistics, shaped (3, 1, 1) to broadcast over H and W
    mean = torch.Tensor([[[0.485]], [[0.456]], [[0.406]]]).cuda()
    std = torch.Tensor([[[0.229]], [[0.224]], [[0.225]]]).cuda()

    torch_image = (torch_image - mean) / std

    # promote a single image to a batch of one
    if len(torch_image.size()) == 3:
        torch_image = torch_image[None]

    # capture the layer output via a forward hook
    captured = []

    def hook(module, input, output):
        captured.append(output.detach())

    # vgg/alexnet expose many direct children and are indexed directly;
    # otherwise (e.g. a wrapped resnet) index into the first child
    children = list(model.children())
    if len(children) > 4:
        handle = children[layer].register_forward_hook(hook)
    else:
        handle = children[0][layer].register_forward_hook(hook)

    _ = model(torch_image)

    # clean up the hook before returning
    handle.remove()
    response = captured[0]
    del captured
    return response

0 commit comments

Comments
 (0)