1+
2+ import numpy as np
3+
4+ import matplotlib .pyplot as plt
5+
6+ import torch
7+ import torch .nn as nn
8+ from torchvision import models , transforms
9+
10+ import argparse
11+ import pickle
12+
13+
14+ ############################ Generators from Linqi's code ##########################################
15+
def gen_sinusoid(sz, A, omega, rho):
    """Generate a 2-D sinusoidal grating of size sz x sz.

    The grid spans [-sz/2, sz/2) in both dimensions; 0.35 is a fixed
    spatial-frequency scale applied to the wave vector.

    :param sz: side length in pixels
    :param A: amplitude
    :param omega: 2-element direction (wave-vector) components
    :param rho: phase offset
    :return: sz x sz float tensor, A * cos(0.35*(omega[0]*x + omega[1]*y) + rho)
    """
    half = int(sz / 2.0)
    coords = torch.tensor(range(-half, half))
    grid_x, grid_y = torch.meshgrid([coords, coords])
    phase = 0.35 * omega[0] * grid_x.float() + 0.35 * omega[1] * grid_y.float()
    return A * torch.cos(phase + rho)
24+
25+
def gen_sinusoid_aperture(ratio, sz, A, omega, rho, polarity):
    """Sinusoidal grating masked by a circular aperture.

    polarity=1 keeps the grating inside the aperture (zero outside);
    polarity=0 keeps it outside (zero inside).

    :param ratio: aperture radius as a fraction of the half-size
    :param sz, A, omega, rho: forwarded to gen_sinusoid
    :param polarity: 1 or 0, selects inside vs. outside
    :return: sz x sz float tensor
    """
    grating = gen_sinusoid(sz, A, omega, rho)
    half = int(sz / 2.0)
    coords = torch.tensor(range(-half, half))
    grid_x, grid_y = torch.meshgrid([coords, coords])

    mask = torch.empty(grating.size(), dtype=torch.float)
    dist2 = grid_x ** 2 + grid_y ** 2
    r2 = (float(half) * ratio) ** 2
    mask[dist2 >= r2] = 1 - polarity
    mask[dist2 < r2] = polarity

    return grating * mask
38+
39+
def center_surround(ratio, sz, theta_center, theta_surround, A, rho):
    """Composite stimulus: a grating at theta_center inside the aperture
    plus a grating at theta_surround outside it.

    :param ratio: aperture radius as a fraction of the half-size
    :param theta_center, theta_surround: orientations (tensors, per torch.cos)
    :return: sz x sz float tensor
    """
    inner = gen_sinusoid_aperture(
        ratio, sz, A, [torch.cos(theta_center), torch.sin(theta_center)], rho, 1)
    outer = gen_sinusoid_aperture(
        ratio, sz, A, [torch.cos(theta_surround), torch.sin(theta_surround)], rho, 0)
    return inner + outer
44+
45+
def sinsoid_noise(ratio, sz, A, omega, rho):
    """Sinusoidal grating inside a circular aperture with Gaussian noise
    (std 0.12) filling the surround.

    Fix: the original referenced `normal.Normal`, but `normal` was never
    imported anywhere in this file (NameError at runtime); use the
    fully-qualified torch.distributions.Normal instead.

    :param ratio: aperture radius as a fraction of the half-size
    :param sz: stimulus side length in pixels
    :param A: grating amplitude
    :param omega: 2-element direction vector for the grating
    :param rho: grating phase offset
    :return: sz x sz float tensor (noise outside aperture + grating inside)
    """
    radius = int(sz / 2.0)
    sin_aperture = gen_sinusoid_aperture(ratio, sz, A, omega, rho, 1)

    # zero-mean Gaussian noise patch covering the whole stimulus
    nrm_dist = torch.distributions.Normal(0.0, 0.12)
    noise_patch = nrm_dist.sample(sin_aperture.size())

    [x, y] = torch.meshgrid([torch.tensor(range(-radius, radius)),
                             torch.tensor(range(-radius, radius))])
    aperture = torch.empty(sin_aperture.size(), dtype=torch.float)

    # binary surround mask: 1 outside the aperture, 0 inside
    aperture_radius = float(radius) * ratio
    aperture[x ** 2 + y ** 2 >= aperture_radius ** 2] = 1
    aperture[x ** 2 + y ** 2 < aperture_radius ** 2] = 0

    # noise survives only outside; the grating exists only inside
    return noise_patch * aperture + sin_aperture
62+
63+
def rgb_sinusoid(theta):
    """Full-field grating at orientation theta, replicated across the
    3 color channels as a (1, 3, 224, 224) batch tensor."""
    grating = gen_sinusoid(224, A=1,
                           omega=[torch.cos(theta), torch.sin(theta)], rho=0)
    return grating.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1)
71+
72+
def rgb_sine_aperture(theta):
    """Aperture-masked grating at orientation theta (aperture ratio 0.85),
    replicated across the 3 color channels as a (1, 3, 224, 224) tensor."""
    stim = gen_sinusoid_aperture(0.85, 224, A=1,
                                 omega=[torch.cos(theta), torch.sin(theta)],
                                 rho=0, polarity=1)
    return stim.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1)
81+
82+
def rgb_sine_noise(theta):
    """Grating-in-aperture with a noise surround (aperture ratio 0.75),
    replicated across the 3 color channels as a (1, 3, 224, 224) tensor."""
    stim = sinsoid_noise(0.75, 224, A=1,
                         omega=[torch.cos(theta), torch.sin(theta)], rho=0)
    return stim.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1)
90+
91+
def rgb_center_surround(theta_center, theta_surround):
    """Center-surround stimulus (aperture ratio 0.75), replicated across
    the 3 color channels as a (1, 3, 224, 224) batch tensor."""
    stim = center_surround(0.75, 224, theta_center, theta_surround, A=1, rho=0)
    return stim.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1)
98+
99+
def show_stimulus(I):
    """Render a 2-D stimulus tensor as a grayscale matplotlib figure."""
    image = I.detach().numpy()
    plt.figure()
    plt.axis('off')
    plt.imshow(image, cmap=plt.gray())
    plt.show()
105+
106+ ##################################################################################################
107+
108+
def get_fisher_orientations(model, layer, n_angles=120, n_phases=1, delta=1e-2):
    """Fisher information w/r/t orientation at a given layer of a full
    (unchopped) model, estimated by central finite differences and averaged
    over sine-grating phase.

    Fix: `np.bool` was removed in NumPy 1.24; use the builtin `bool`.

    :param model: the network (assumed to already be on the GPU)
    :param layer: layer index forwarded to get_response
    :param n_angles: number of orientations sampled in [0, pi]
    :param n_phases: number of phases averaged over
    :param delta: finite-difference step size
    :return: list of n_angles Fisher-information values
    """
    phases = np.linspace(0, np.pi, n_phases)
    angles = np.linspace(0, np.pi, n_angles)

    # Negative mask: True outside a circle of radius 50 px centered at
    # (112, 112). NOTE(review): computed but never used below — kept for
    # parity with the original; confirm whether it was meant to mask
    # responses before removing.
    ii, jj = np.ogrid[:224, :224]
    unit_circle = ((ii - 112) ** 2 + (jj - 112) ** 2 >= 50 ** 2).astype(bool)

    fishers_at_angle = []
    for angle in angles:
        # Batch all phases into one tensor so the network runs once per sign.
        all_phases_plus = torch.zeros(n_phases, 3, 224, 224).cuda()
        all_phases_minus = torch.zeros(n_phases, 3, 224, 224).cuda()

        for i, phase in enumerate(phases):
            # NOTE(review): `phase` is not forwarded — rgb_sine_aperture uses
            # rho=0, so every phase slot holds the same stimulus. Confirm
            # whether the phase average was intended to vary rho.
            all_phases_plus[i] = rgb_sine_aperture(torch.tensor(angle + delta))
            all_phases_minus[i] = rgb_sine_aperture(torch.tensor(angle - delta))

        # responses at the two perturbed orientations
        plus_resp = get_response(all_phases_plus, model, layer)
        size = plus_resp.size()
        minus_resp = get_response(all_phases_minus, model, layer)

        # central-difference derivative, flattened to one row per phase
        df_dtheta = get_derivative(delta, plus_resp, minus_resp)
        df_dtheta = df_dtheta.view(n_phases, -1)

        # average squared sensitivity over phases
        fisher = get_fisher(df_dtheta)
        fishers_at_angle.append(fisher)
        print("response size", size)
    return fishers_at_angle
161+
def get_derivative(delta, plus_resp, minus_resp):
    """
    Central finite-difference derivative of the network activation with
    respect to the perturbed angle.

    :param delta: the perturbation size
    :param plus_resp: activations at the positive perturbation
    :param minus_resp: activations at the negative perturbation
    :return: elementwise derivative estimate, (plus - minus) / (2 * delta)

    >>> f = lambda x: x**2
    >>> abs(get_derivative(1e-3, f(2+1e-3), f(2-1e-3)) - 4) < 1e-6
    True
    >>> abs(get_derivative(1e-3, f(5+1e-3), f(5-1e-3)) - 10) < 1e-6
    True
    """
    span = 2 * delta
    return (plus_resp - minus_resp) / span
180+
181+
def get_fisher(df_dtheta):
    """
    Fisher information under a Gaussian-noise assumption: the mean squared
    norm of the derivative rows, averaging over dimension 0 (examples).

    :param df_dtheta: 2-D tensor, one flattened derivative per example
    :return: scalar (0-d tensor) Fisher information
    """
    total = sum(torch.dot(row, row) for row in df_dtheta)
    return total / len(df_dtheta)
198+
199+
def numpy_to_torch(rgb_image, cuda=True):
    """
    Prepare a numpy RGB image for a pytorch network: reverse the channel
    order to BGR and move the channel axis to the front.

    :param rgb_image: numpy array of shape (x, y, 3)
    :param cuda: when True, move the result to the GPU
    :return: float tensor of shape (3, x, y), channels in BGR order
    """
    bgr_order = [2, 1, 0]  # rgb -> bgr
    chw = torch.from_numpy(rgb_image[:, :, bgr_order]).permute(2, 0, 1).float()
    return chw.cuda() if cuda else chw
217+
218+
def get_response(torch_image, model, layer):
    """
    Response of the network at a given layer to an image (or batch),
    captured with a forward hook.

    Fix: the ImageNet mean/std constants were hard-coded to .cuda(), which
    crashed on CPU-only machines even with a CPU model and image; they now
    follow the input tensor's device (GPU callers are unaffected).

    :param torch_image: float tensor, (3, H, W) or (N, 3, H, W)
    :param model: the network; assumed to be on the same device as the image
    :param layer: child index to hook (e.g. 4, 9, 16, 23, 30 for VGG)
    :return: detached activation tensor recorded at that layer
    """
    # ImageNet normalization constants, shaped (3, 1, 1) for broadcasting,
    # allocated on the same device as the input
    device = torch_image.device
    mean = torch.Tensor([[[0.485]], [[0.456]], [[0.406]]]).to(device)
    std = torch.Tensor([[[0.229]], [[0.224]], [[0.225]]]).to(device)
    torch_image = (torch_image - mean) / std

    # promote a single image (3, H, W) to a batch of one
    if len(torch_image.size()) == 3:
        torch_image = torch_image[None]

    # capture the hooked layer's output
    outputs_at_layer = []

    def hook(module, input, output):
        outputs_at_layer.append(output.detach())

    # Heuristic from the original: models with many direct children
    # (vgg/alexnet-style) are indexed directly; otherwise index into the
    # first child (e.g. a features Sequential).
    children = list(model.children())
    if len(children) > 4:
        handle = children[layer].register_forward_hook(hook)
    else:
        handle = children[0][layer].register_forward_hook(hook)

    _ = model(torch_image)

    # clean up the hook before returning
    handle.remove()
    r = outputs_at_layer[0]
    del outputs_at_layer
    return r