|
| 1 | +from __future__ import division |
| 2 | +import numpy as np |
| 3 | +import scipy.optimize as optimize |
| 4 | + |
| 5 | +def initialize_centers_widths_weights(activations,locations,num_factors): |
| 6 | + temp_data = [] |
| 7 | + temp_R = [] |
| 8 | + for i in range(len(activations)): |
| 9 | + temp_data.append(activations[i].numpy()) |
| 10 | + temp_R.append(locations[i].numpy()) |
| 11 | + activations = temp_data |
| 12 | + locations = temp_R |
| 13 | + del temp_data,temp_R |
| 14 | + mean_activations,mean_locations = mean_image(activations,locations) |
| 15 | + template_center_mean,template_width_mean = hotspot_initialization(mean_activations, |
| 16 | + mean_locations, |
| 17 | + num_factors) |
| 18 | + F = radial_basis(mean_locations,template_center_mean,template_width_mean).T |
| 19 | + trans_F = F.T.copy() |
| 20 | + template_weights_mean = np.linalg.solve(trans_F.dot(F), trans_F.dot(mean_activations)) |
| 21 | + subject_weights_mean = [] |
| 22 | + subject_center_mean = [] |
| 23 | + subject_width_mean = [] |
| 24 | + for i in range(len(activations)): |
| 25 | + F = radial_basis(locations[i], template_center_mean, template_width_mean).T |
| 26 | + trans_F = F.T.copy() |
| 27 | + subject_weights_mean.append(np.linalg.solve(trans_F.dot(F), trans_F.dot(mean_activations))) |
| 28 | + subject_center_mean.append(template_center_mean) |
| 29 | + subject_width_mean.append(template_width_mean) |
| 30 | + |
| 31 | + return template_center_mean,template_width_mean,template_weights_mean,\ |
| 32 | + np.array(subject_center_mean),np.array(subject_width_mean),np.array(subject_weights_mean) |
| 33 | + |
| 34 | + |
| 35 | +def hotspot_initialization(activations, locations , num_factors): |
| 36 | + mean_activations = abs(activations - np.nanmean(activations)) |
| 37 | + centers = np.zeros(shape=(num_factors,locations.shape[1])) |
| 38 | + widths = np.zeros(shape=(num_factors,)) |
| 39 | + |
| 40 | + for k in range(num_factors): |
| 41 | + ind = np.nanargmax(mean_activations) |
| 42 | + centers[k,:] = locations[ind,:] |
| 43 | + widths[k] = init_width(activations,locations,activations[ind],centers[k,:]) |
| 44 | + mean_activations = mean_activations - radial_basis(locations,centers[k,:],widths[k]) |
| 45 | + return centers,widths |
| 46 | + |
| 47 | +def mean_image(activations, locations): |
| 48 | + |
| 49 | + mean_locations = locations[0] |
| 50 | + |
| 51 | + for i in range(1, len(locations)): |
| 52 | + mean_locations = np.vstack([mean_locations, locations[i]]) ## fix this to accomodate differing lengths |
| 53 | + |
| 54 | + mean_locations = np.unique(mean_locations,axis=0) |
| 55 | + mean_activations = np.zeros(shape=(mean_locations.shape[0],)) |
| 56 | + n = np.zeros(shape=(mean_activations.shape)) |
| 57 | + |
| 58 | + for i in range(len(activations)): |
| 59 | + C = intersect(mean_locations,locations[i]) |
| 60 | + mean_locations_ind = get_common_indices(mean_locations,C) |
| 61 | + subject_locations_ind = get_common_indices(locations[i],C) |
| 62 | + mean_locations_ind = np.sort(mean_locations_ind) |
| 63 | + subject_locations_ind = np.sort(subject_locations_ind) |
| 64 | + mean_activations[mean_locations_ind] = mean_activations[mean_locations_ind] + \ |
| 65 | + np.mean(activations[i][:,subject_locations_ind],axis=0) |
| 66 | + n[mean_locations_ind] = n[mean_locations_ind] + 1 |
| 67 | + mean_activations = mean_activations/n |
| 68 | + |
| 69 | + return mean_activations,mean_locations |
| 70 | + |
| 71 | + |
| 72 | +def init_width(activations,locations,weight,c): |
| 73 | + |
| 74 | + start_width = 0 |
| 75 | + objective = lambda w: np.sum(np.abs(activations - weight*radial_basis(locations,c,w))) |
| 76 | + result = optimize.minimize(objective,x0=start_width) |
| 77 | + |
| 78 | + return result.x |
| 79 | + |
| 80 | +def radial_basis(locations, centers, log_widths): |
| 81 | + """The radial basis function used as the shape for the factors""" |
| 82 | + # V x 3 -> 1 x V x 3 |
| 83 | + locations = np.expand_dims(locations,0) |
| 84 | + if len(centers.shape) > 3: |
| 85 | + # 1 x V x 3 -> 1 x 1 x V x 3 |
| 86 | + locations = np.expand_dims(locations,0) |
| 87 | + # S x K x 3 -> S x K x 1 x 3 |
| 88 | + centers = np.expand_dims(centers,len(centers.shape)-1) |
| 89 | + # S x K x V x 3 |
| 90 | + delta2s = (locations - centers)**2 |
| 91 | + # S x K -> S x K x 1 |
| 92 | + log_widths = np.expand_dims(log_widths,len(log_widths.shape)) |
| 93 | + return np.exp(-delta2s.sum(len(delta2s.shape) - 1) / np.exp(log_widths)) |
| 94 | + |
| 95 | +def intersect(A,B): |
| 96 | + return np.array([x for x in set(tuple(x) for x in A) & set(tuple(x) for x in B)]) |
| 97 | + |
| 98 | +def get_common_indices(X,C): |
| 99 | + return np.where((X==C[:,None]).all(-1))[1] |
0 commit comments