-
Notifications
You must be signed in to change notification settings - Fork 48
Open
Description
Hiya – I am working on some code to estimate grid scales and was borrowing code from this library to plot idealised grid cells using a rectified-cosine model.
I noticed, however, that there was a deviation between the input grid scale and the actual distance between peaks. I haven't pinned down exactly why this happens; I speculate it is related to the width of the cosine waves and their interference pattern.
For example, a grid cell that I intended to have a scale of 20 ended up with a scale of 23, and so on, as shown below:
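The hotfix is to divide the grid scales by a magic number, 1.15 (not ideal), before tiling the environment and summing the cosine waves. (1.15 is close to $2/\sqrt{3} \approx 1.1547$. Three plane waves whose wave vectors are 60° apart interfere to form a hexagonal lattice. That lattice's peak-to-peak spacing is $2/\sqrt{3}$ times the wavelength of each wave. The exact correction is therefore to use a wavelength of $\text{gridscale}\cdot\sqrt{3}/2$. This also explains why a scale of 20 came out as ≈23.)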
This could be implemented somewhere here.
Below is some code to reproduce this behaviour:
# Using Tom George's code for rectified cosine grid cell model
import numpy as np
import matplotlib.pyplot as plt
def rotate(vector, theta):
    """Rotate ``vector`` anticlockwise by ``theta`` radians.

    The result is the standard 2x2 rotation matrix for ``theta`` applied to
    ``vector`` with ``np.dot`` (matrix on the left).
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])
    return np.dot(rotation, vector)
def get_vectors_between(pos1, pos2):
    """Return the displacement ``pos1[i] - pos2[j]`` for every pair (i, j).

    Args:
        pos1: array whose last axis holds coordinates; leading axes are
            flattened, giving M points.
        pos2: array whose last axis holds coordinates; leading axes are
            flattened, giving K points.

    Returns:
        Array of shape (M, K, D) where ``out[i, j] = pos1[i] - pos2[j]``.
    """
    # Broadcasting (M, 1, D) against (1, K, D) yields the same (M, K, D) result
    # as explicitly repeating both inputs, without allocating those two copies.
    left = pos1.reshape(-1, 1, pos1.shape[-1])
    right = pos2.reshape(1, -1, pos2.shape[-1])
    return left - right
def get_flattened_coords(N):
    """Return all integer (x, y) points of an N x N grid as an (N*N, 2) array.

    Points are ordered with x varying fastest:
    (0, 0), (1, 0), ..., (N-1, 0), (0, 1), ...
    """
    ticks = np.arange(N)
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    return np.stack((grid_x, grid_y), axis=-1).reshape(-1, 2)
# --- Customizable parameters ---
N = 100  # Number of sample points along each side of the square arena
gridscales = np.array([10, 20, 30, 40])  # Target grid scale for each neuron
phis = [0, 0, 0, 0]  # Orientation (radians) for each neuron
# --- End of customizable parameters ---

n_neurons = len(gridscales)
phase_offsets = np.full((n_neurons, 2), N / 2)  # Every cell centred mid-arena
width_ratio = 4 / (3 * np.sqrt(3))


def _wave_vectors(orientation):
    """Return a (3, 2) array of unit wave vectors spaced 60 degrees apart.

    The first vector is rotated by pi/6 plus ``orientation`` so that, for an
    orientation of 0, the resulting grid has a peak due east of its centre.
    """
    first = rotate(np.array([1.0, 0.0]), np.pi / 6 + orientation)
    return np.array([first, rotate(first, np.pi / 3), rotate(first, 2 * np.pi / 3)])


# One (3, 2) set of wave vectors per neuron, so w has shape (n_neurons, 3, 2).
w = np.array([_wave_vectors(phis[idx]) for idx in range(n_neurons)])
pos = get_flattened_coords(N)  # (N*N, 2) integer (x, y) sample points
origin = np.ones([n_neurons, 2]) * N / 2  # each cell's central peak sits mid-arena
vecs = get_vectors_between(origin, pos)  # (n_neurons, N*N, 2): origin - pos

# Summing three plane waves whose unit wave vectors are 60 degrees apart gives a
# hexagonal lattice. Its nearest-neighbour peak spacing is 2/sqrt(3) (~1.1547)
# times the wavelength of each individual wave. For neighbouring peaks to be
# `gridscales` apart, each wave must therefore have wavelength
# gridscales * sqrt(3) / 2. (The former `gridscales / 1.15` was an empirical
# approximation of this exact factor.)
adjusted_gridscales = gridscales * np.sqrt(3) / 2

# Wavenumber per neuron, shaped (n_neurons, 1) so it broadcasts over positions.
wavenumbers = (2 * np.pi / adjusted_gridscales)[:, np.newaxis]

# Each wave vector shaped (n_neurons, 1, 2) so it broadcasts against `vecs`
# without materialising tiled copies.
w1 = np.expand_dims(w[:, 0, :], axis=1)
w2 = np.expand_dims(w[:, 1, :], axis=1)
w3 = np.expand_dims(w[:, 2, :], axis=1)
phi_1 = wavenumbers * (vecs * w1).sum(axis=-1)
phi_2 = wavenumbers * (vecs * w2).sum(axis=-1)
phi_3 = wavenumbers * (vecs * w3).sum(axis=-1)
# All three phases are 0 at the origin, so the summed rate peaks at 1 there.
firingrate = (1 / 3) * (np.cos(phi_1) + np.cos(phi_2) + np.cos(phi_3))

# Threshold = the summed-cosine value at width_ratio / 2 wavelengths from a peak,
# measured towards a neighbouring peak (due east for orientation 0). Subtract it,
# rescale so peaks remain at 1, then rectify; field width is thus set by
# width_ratio. With width_ratio = 4 / (3 * sqrt(3)) the cosine argument is
# 2*pi/3, so this threshold evaluates to ~0.
firing_rate_at_full_width = (1 / 3) * (2 * np.cos(np.sqrt(3) * np.pi * width_ratio / 2) + 1)
firingrate -= firing_rate_at_full_width
firingrate /= (1 - firing_rate_at_full_width)
firingrate[firingrate < 0] = 0
# Plotting: one rate map per neuron, laid out side by side in a single row.
fig, ax = plt.subplots(1, len(gridscales), figsize=(12, 4))
for idx, (panel, scale) in enumerate(zip(ax, gridscales)):
    rate_map = firingrate[idx].reshape(N, N)
    panel.imshow(rate_map, cmap='jet', extent=[0, N, 0, N])
    panel.set_title(f"Grid scale: {scale} \n Orientation: {phis[idx]:.2f} rad")
    panel.axis('off')
plt.tight_layout()
plt.show()
## Computing and reporting differences in scale
from skimage.feature import peak_local_max

centre = np.array([N / 2, N / 2])  # each cell's central peak (see `origin`)
for i in range(firingrate.shape[0]):
    # Local maxima of this cell's rate map, as (row, col) offsets from the centre.
    peaks = peak_local_max(firingrate[i].reshape(N, N)) - centre
    sorted_sizes = np.sort(np.linalg.norm(peaks, axis=1))
    # Index 0 is the central peak itself (distance 0); the next six distances
    # are expected to form the first hexagonal ring around it.
    inner_ring = sorted_sizes[1:7]
    if inner_ring.size < 6:
        print(f'Warning: only {inner_ring.size} neighbouring peaks found for scale '
              f'{gridscales[i]}; the arena may be too small to show the full inner ring.')
    if inner_ring.size == 0:
        continue  # nothing to compare against; avoid a NaN mean
    difference = np.mean(inner_ring - gridscales[i])
    print(f'Mean difference ~{round(difference,3)} from intended scale {gridscales[i]}')
    print(f'Distance from centre for inner six peaks:\n {inner_ring} \n ----')
Metadata
Assignees
Labels
No labels

