Skip to content

Grid cells' actual grid scales are larger than the input grid scales. #127

@charlesdgburns

Description

@charlesdgburns

Hiya – I am working on some code to estimate grid scales and have been borrowing code from this library to plot idealised grid cells using a rectified-cosine model.
I noticed, however, that there is a discrepancy between the gridscale input and the actual distance between firing peaks. I haven't pinned down exactly why this happens; I suspect it is related to the width of the cosine waves and their interference pattern.

A grid cell that I intended to have a scale of 20 ended up with a scale of 23, etc., as shown below:

Image

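The hotfix is to divide the gridscales by a magic number, 1.15 (not ideal), before tiling the environment and summing the cosine waves. (This factor is approximately 2/√3 ≈ 1.1547: three cosines of wavelength λ whose wave vectors are 60° apart interfere to form a hexagonal lattice whose peaks are 2λ/√3 apart.)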
This could be implemented somewhere here.

Image

Below is some code to reproduce this behaviour:

# Using Tom George's code for rectified cosine grid cell model

import numpy as np
import matplotlib.pyplot as plt

def rotate(vector, theta):
    """Rotate ``vector`` anticlockwise by ``theta`` radians.

    Args:
        vector: A length-2 vector. A (2, k) array is also accepted, in which
            case every column is rotated.
        theta: Rotation angle in radians; positive values turn anticlockwise.

    Returns:
        The rotated vector(s), ``R(theta) . vector``.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])
    return np.dot(rotation, vector)

def get_vectors_between(pos1, pos2):
    """Return every pairwise displacement vector ``pos1[i] - pos2[j]``.

    Args:
        pos1: Array of positions; reshaped to ``(n1, d)`` where ``d`` is its
            last dimension.
        pos2: Array of positions with the same last dimension ``d``; reshaped
            to ``(n2, d)``.

    Returns:
        Array of shape ``(n1, n2, d)`` with ``vectors[i, j] = pos1[i] - pos2[j]``.
    """
    pos1_ = pos1.reshape(-1, 1, pos1.shape[-1])
    pos2_ = pos2.reshape(1, -1, pos2.shape[-1])
    # Broadcasting (n1, 1, d) against (1, n2, d) forms all pairwise differences
    # directly, without materialising two repeated (n1, n2, d) copies first.
    return pos1_ - pos2_

def get_flattened_coords(N):
    """List the integer ``(x, y)`` coordinates of every cell of an N x N grid.

    Cells are ordered row by row, with ``x`` varying fastest, so a per-cell
    array can be reshaped back to an ``(N, N)`` image indexed ``[y, x]``.

    Returns:
        Array of shape ``(N * N, 2)`` whose columns are ``x`` and ``y``.
    """
    axis_values = np.arange(N)
    side = axis_values.size
    xs = np.tile(axis_values, side)    # 0, 1, ..., N-1, 0, 1, ...
    ys = np.repeat(axis_values, side)  # 0, 0, ..., 1, 1, ...
    return np.stack((xs, ys), axis=1)

# --- Customizable parameters ---
N = 100  # Side length of the square environment, in pixels
gridscales = np.array([10, 20, 30, 40])  # Intended grid scale for each neuron
phis = [0, 0, 0, 0]  # Orientations (in radians) for each neuron
# --- End of customizable parameters ---


n_neurons = len(gridscales)
# One phase offset per neuron, all placed at the centre of the environment.
phase_offsets = np.full((n_neurons, 2), N / 2)

width_ratio = 4 / (3 * np.sqrt(3))


def _three_wave_vectors(orientation):
    """Return the three unit wave vectors (as rows) for one grid cell.

    The first vector points at pi/6 + ``orientation``, so that with zero
    orientation the grid has a peak due east; the other two are the first
    rotated by pi/3 and by 2*pi/3.
    """
    first = rotate(np.array([1.0, 0.0]), np.pi / 6 + orientation)
    return np.array([first,
                     rotate(first, np.pi / 3),
                     rotate(first, 2 * np.pi / 3)])


w = np.array([_three_wave_vectors(phis[i]) for i in range(n_neurons)])

pos = get_flattened_coords(N)
origin = np.ones([n_neurons, 2]) * N/2
# (n_neurons, N*N, 2): each neuron's centre minus every pixel position.
vecs = get_vectors_between(origin, pos)

# Tile parameters for efficient calculation
w1 = np.tile(np.expand_dims(w[:, 0, :], axis=1), reps=(1, pos.shape[0], 1))
w2 = np.tile(np.expand_dims(w[:, 1, :], axis=1), reps=(1, pos.shape[0], 1))
w3 = np.tile(np.expand_dims(w[:, 2, :], axis=1), reps=(1, pos.shape[0], 1))

# Convert the requested grid scale (distance between neighbouring firing
# fields) into the wavelength of the three cosines.
# Peaks of the summed pattern occur where all three phases are multiples of
# 2*pi. The wave vectors point at pi/6, pi/2 and 5*pi/6 (plus the orientation).
# The nearest neighbouring peak therefore lies perpendicular to the second wave
# vector, at a distance d that satisfies
#     (2*pi / wavelength) * d * cos(pi/6) = 2*pi,
# so d = 2 * wavelength / sqrt(3), which is about 1.1547 * wavelength.
# That ratio is the source of the empirical "1.15" factor. Using the exact
# wavelength = gridscale * sqrt(3) / 2 makes the peak spacing equal gridscales.
adjusted_gridscales = gridscales * np.sqrt(3) / 2

tiled_gridscales = np.tile(np.expand_dims(adjusted_gridscales, axis=1), reps=(1, pos.shape[0]))


phi_1 = ((2 * np.pi) / tiled_gridscales) * (vecs * w1).sum(axis=-1)
phi_2 = ((2 * np.pi) / tiled_gridscales) * (vecs * w2).sum(axis=-1)
phi_3 = ((2 * np.pi) / tiled_gridscales) * (vecs * w3).sum(axis=-1)

# Unrectified pattern in [-0.5, 1], equal to 1 exactly at the grid peaks.
firingrate = (1 / 3) * (np.cos(phi_1) + np.cos(phi_2) + np.cos(phi_3))

# Rectify: shift the pattern down by its value at the edge of a firing field,
# rescale so that peaks stay at 1, and clip everything below the edge to 0.
#
# Take the direction from a peak towards a neighbouring peak (due east at zero
# orientation). Along it, the wave-vector projections are proportional to
# cos(pi/6), cos(pi/2) = 0 and cos(5*pi/6). The unrectified rate at distance d
# is therefore (1/3) * (2*cos(sqrt(3)*pi*d / wavelength) + 1), where wavelength
# is the per-neuron cosine wavelength used for the phases above.
# Evaluating this at d = width_ratio * wavelength / 2 gives the threshold below.
# Each field therefore extends width_ratio * wavelength / 2 from its peak along
# that direction.
firing_rate_at_full_width = (1 / 3) * (2*np.cos(np.sqrt(3)*np.pi*width_ratio/2) + 1)
firingrate -= firing_rate_at_full_width
firingrate /= (1 - firing_rate_at_full_width)
firingrate[firingrate < 0] = 0

# Plotting
# squeeze=False always returns a 2-D array of Axes, so indexing also works
# when only a single grid scale is plotted.
fig, axes = plt.subplots(1, len(gridscales), figsize=(12, 4), squeeze=False)
ax = axes[0]
for i, each_cell in enumerate(gridscales):
    # Rows of the reshaped map are y and columns are x, because the pixel
    # positions come from get_flattened_coords. origin='lower' puts y = 0 at
    # the bottom, matching the coordinate system, so anticlockwise orientations
    # are drawn anticlockwise rather than mirrored.
    ax[i].imshow(firingrate[i].reshape(N, N), cmap='jet', extent=[0, N, 0, N], origin='lower')
    ax[i].set_title(f"Grid scale: {each_cell} \n Orientation: {phis[i]:.2f} rad")
    ax[i].axis('off')

plt.tight_layout()
plt.show()

## Computing and reporting differences in scale
from skimage.feature import peak_local_max

# Every neuron's grid is centred at (N/2, N/2), matching `origin` above.
centre = np.array([N / 2, N / 2])
for rate_map, intended_scale in zip(firingrate, gridscales):
    peaks = peak_local_max(rate_map.reshape(N, N))
    # Peaks are (row, col); the centre is the same on both axes, so this gives
    # each peak's offset from the centre.
    sorted_sizes = np.sort(np.linalg.norm(peaks - centre, axis=1))
    # sorted_sizes[0] is the central peak itself (distance 0); the next six
    # should be its hexagonal neighbours.
    neighbours = sorted_sizes[1:7]
    if neighbours.size < 6:
        # Large scales in a small environment may leave neighbours outside
        # the image; say so instead of silently averaging fewer peaks.
        print(f'Warning: only {neighbours.size} neighbouring peaks found for '
              f'intended scale {intended_scale}; the estimate may be unreliable.')
    if neighbours.size == 0:
        continue
    difference = np.mean(neighbours - intended_scale)
    print(f'Mean difference ~{round(difference,3)} from intended scale {intended_scale}')
    print(f'Distance from centre for inner six peaks:\n {neighbours} \n ----')

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions