Skip to content

Commit efc2498

Browse files
authored
Merge pull request #27 from kenblu24/delete
Ability to delete Neuron(s) or Synapse(s) from the model sanely and help you deal with the fallout of the huge changes in model index
2 parents 8490d90 + 533b665 commit efc2498

File tree

4 files changed

+442
-0
lines changed

4 files changed

+442
-0
lines changed

docs/source/api/SNN.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,10 @@ Methods
143143
~SNN.stdp_setup
144144
~SNN.set_stdp_enabled_from_mat
145145
~SNN.set_weights_from_mat
146+
~SNN.delete_neuron
147+
~SNN.delete_neurons
148+
~SNN.delete_synapse
149+
~SNN.delete_synapses
146150

147151

148152
.. _inspecting-the-snn:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ commands = [
9494
["python", "test_display.py",],
9595
["python", "test_reset.py",],
9696
["python", "test_dtype.py",],
97+
["python", "test_deletion.py",],
9798
["python", "test_json.py",],
9899
]
99100

src/superneuromat/neuromorphicmodel.py

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1671,6 +1671,268 @@ def shorten_spike_train(self, time_steps: int | None = None):
16711671
time_steps = max(bool(self.spike_train), self.stdp_time_steps)
16721672
self.spike_train = self.spike_train[-time_steps:]
16731673

1674+
def delete_neuron(self, neuron_id: int | Neuron, reindex: bool = True, _delete_synapses: bool = True):
1675+
"""Deletes a neuron from the network.
1676+
1677+
Because neurons and synapses are stored in the SNN as lists of parameters, deleting a neuron may
1678+
cause shifts in the indices of other neurons and synapses. If you are manually modifying
1679+
the lists of neuron or synapse parameters, you may find it hard to keep track of what's what.
1680+
1681+
However, if you use :py:class:`Neuron`\\ s and :py:class:`Synapse`\\ s, or :py:class:`NeuronListView`\\ s
1682+
and :py:class:`SynapseListView`\\ s, then the shift in indices will be automatically handled, and those
1683+
objects will reflect the new indices while still referring to the same neurons and synapses that you'd expect.
1684+
1685+
Parameters
1686+
----------
1687+
neuron_id : int or Neuron
1688+
The ID of the neuron to delete.
1689+
1690+
Returns
1691+
-------
1692+
tuple[dict, dict]
1693+
Returns ``(neuron_mapping, synapse_mapping)``, where ``neuron_mapping`` is a mapping
1694+
of neuron IDs from ``{before: after}`` the neuron was deleted, and ``synapse_mapping`` is a mapping
1695+
of synaptic IDs from ``{before: after}`` the neuron was deleted.
1696+
"""
1697+
if isinstance(neuron_id, Neuron):
1698+
neuron_id = neuron_id.idx
1699+
if not is_intlike_catch(neuron_id):
1700+
raise TypeError("neuron_id must be int or Neuron.")
1701+
1702+
# TODO: what about delay chains?
1703+
1704+
# Delete synapses
1705+
synaptic_ids = [
1706+
idx for idx, (pre, post)
1707+
in enumerate(zip(self.pre_synaptic_neuron_ids, self.post_synaptic_neuron_ids))
1708+
if pre == neuron_id or post == neuron_id
1709+
]
1710+
1711+
smap = {}
1712+
if _delete_synapses:
1713+
smap = self.delete_synapses(synaptic_ids, reindex=reindex)
1714+
1715+
if neuron_id in self._neuron_cache:
1716+
self._neuron_cache[neuron_id].idx = None
1717+
del self._neuron_cache[neuron_id]
1718+
1719+
if reindex:
1720+
mapping = {i: i for i in range(neuron_id)}
1721+
mapping |= {i: i - 1 for i in range(neuron_id + 1, self.num_neurons)}
1722+
1723+
# fix broken indices in synapses
1724+
self.pre_synaptic_neuron_ids = [mapping[i] for i in self.pre_synaptic_neuron_ids if i not in synaptic_ids]
1725+
self.post_synaptic_neuron_ids = [mapping[i] for i in self.post_synaptic_neuron_ids if i not in synaptic_ids]
1726+
1727+
# replace affected indices in neuron lists
1728+
for nlist in self._neuronlist_cache:
1729+
indices = set(nlist.indices)
1730+
overlap = indices & mapping.keys()
1731+
if overlap:
1732+
nlist.indices = [mapping[i] for i in nlist.indices if i != neuron_id and i in mapping]
1733+
self.rebuild_connection_ids()
1734+
1735+
# remap neuron IDs in cache
1736+
for idx in range(neuron_id + 1, self.num_neurons):
1737+
if idx in self._neuron_cache:
1738+
self._neuron_cache[idx].idx = idx - 1
1739+
self._neuron_cache[idx - 1] = self._neuron_cache[idx]
1740+
del self._neuron_cache[idx] # delete item in cache with id self.num_neurons - 1 due to left shift in indices
1741+
1742+
# self.neurons.remove(self.neurons[neuron_id])
1743+
del self.neuron_refractory_periods[neuron_id]
1744+
del self.neuron_refractory_periods_state[neuron_id]
1745+
del self.neuron_states[neuron_id]
1746+
del self.neuron_thresholds[neuron_id]
1747+
del self.neuron_leaks[neuron_id]
1748+
del self.neuron_reset_states[neuron_id]
1749+
1750+
if reindex:
1751+
return mapping, smap
1752+
return {}, {}
1753+
1754+
def delete_neurons(self, neuron_ids: list[int] | list[Neuron], reindex: bool = True):
1755+
"""Deletes neurons from the network.
1756+
1757+
Because neurons and synapses are stored in the SNN as lists of parameters, deleting neurons may
1758+
cause shifts in the indices of other neurons and synapses. If you are manually modifying
1759+
the lists of neuron or synapse parameters, you may find it hard to keep track of what's what.
1760+
1761+
However, if you use :py:class:`Neuron`\\ s and :py:class:`Synapse`\\ s, or :py:class:`NeuronListView`\\ s
1762+
and :py:class:`SynapseListView`\\ s, then the shift in indices will be automatically handled, and those
1763+
objects will reflect the new indices while still referring to the same neurons and synapses that you'd expect.
1764+
1765+
Parameters
1766+
----------
1767+
neuron_ids : list[int] | list[Neuron]
1768+
The IDs of the neurons to delete.
1769+
1770+
Returns
1771+
-------
1772+
tuple[dict, dict]
1773+
Returns ``(neuron_mapping, synapse_mapping)``, where ``neuron_mapping`` is a mapping
1774+
of neuron IDs from ``{before: after}`` the neurons were deleted, and ``synapse_mapping`` is a mapping
1775+
of synaptic IDs from ``{before: after}`` the neurons were deleted.
1776+
"""
1777+
indices = set(int(neuron_id) for neuron_id in neuron_ids if isinstance(neuron_id, (int, Neuron)))
1778+
indices = list(indices)
1779+
indices.sort(reverse=True)
1780+
num_neurons = self.num_neurons
1781+
1782+
synaptic_ids = [
1783+
idx for idx, (pre, post)
1784+
in enumerate(zip(self.pre_synaptic_neuron_ids, self.post_synaptic_neuron_ids))
1785+
if pre in indices or post in indices
1786+
]
1787+
1788+
for neuron_id in indices:
1789+
self.delete_neuron(neuron_id, reindex=False, _delete_synapses=False)
1790+
1791+
smap = self.delete_synapses(synaptic_ids, reindex=reindex)
1792+
1793+
for idx in indices:
1794+
if idx in self._neuron_cache:
1795+
self._neuron_cache[idx].idx = None
1796+
del self._neuron_cache[idx]
1797+
1798+
if not reindex:
1799+
return {}, {}
1800+
1801+
remaining_idxs = (idx for idx in range(num_neurons) if idx not in indices) # sorted
1802+
mapping = {old: new for new, old in enumerate(remaining_idxs)}
1803+
1804+
for neuron in self._neuron_cache.values():
1805+
neuron.idx = mapping[neuron.idx]
1806+
self._neuron_cache = {neuron.idx: neuron for neuron in self._neuron_cache.values()}
1807+
1808+
# fix broken indices in synapses
1809+
self.pre_synaptic_neuron_ids = [mapping[i] for i in self.pre_synaptic_neuron_ids]
1810+
self.post_synaptic_neuron_ids = [mapping[i] for i in self.post_synaptic_neuron_ids]
1811+
1812+
# replace affected indices in neuron lists
1813+
for nlist in self._neuronlist_cache:
1814+
indices = set(nlist.indices)
1815+
overlap = indices & mapping.keys()
1816+
if overlap:
1817+
nlist.indices = [mapping[i] for i in nlist.indices if i not in neuron_ids and i in mapping]
1818+
self.rebuild_connection_ids()
1819+
return mapping, smap
1820+
1821+
def delete_synapse(self, synapse_id: int | Synapse, reindex: bool = True, _rebuild_connection_ids: bool = True):
1822+
"""Deletes a synapse from the network.
1823+
1824+
Because synapses are stored in the SNN as a list, deleting a synapse may
1825+
cause a shift in the indices of other synapses. If you are manually modifying
1826+
the lists of synapse parameters, you may find it hard to keep track of what's what.
1827+
1828+
However, if you use :py:class:`Synapse`\\ s or a :py:class:`SynapseListView`,
1829+
then the shift in indices will be automatically handled, and those objects will
1830+
reflect the new indices while still referring to the same synapses that you'd expect.
1831+
1832+
.. warning::
1833+
1834+
Deleting synapses may result in unexpected behavior, as it can cause
1835+
large shifts in the indices of synapses. Use with caution.
1836+
1837+
Parameters
1838+
----------
1839+
synapse_id : int or Synapse
1840+
The ID of the synapse to delete.
1841+
1842+
Returns
1843+
-------
1844+
dict
1845+
A mapping of synaptic IDs from ``{before: after}`` the synapse was deleted. May be empty.
1846+
"""
1847+
if isinstance(synapse_id, Synapse):
1848+
synapse_id = synapse_id.idx
1849+
if not is_intlike_catch(synapse_id):
1850+
raise TypeError("synapse_id must be int or Synapse.")
1851+
1852+
if synapse_id in self._synapse_cache:
1853+
self._synapse_cache[synapse_id].idx = None
1854+
del self._synapse_cache[synapse_id]
1855+
1856+
if reindex:
1857+
mapping = {i: i for i in range(synapse_id)}
1858+
for syn_id in range(synapse_id, self.num_synapses):
1859+
mapping[syn_id] = syn_id - 1
1860+
print(syn_id)
1861+
if syn_id in self._synapse_cache:
1862+
print(syn_id, ' in cache')
1863+
self._synapse_cache[syn_id].idx = syn_id - 1
1864+
self._synapse_cache[syn_id - 1] = self._synapse_cache[syn_id]
1865+
del mapping[synapse_id]
1866+
del self._synapse_cache[syn_id] # delete the last cached synapse after the shift
1867+
1868+
pair = (self.pre_synaptic_neuron_ids[synapse_id], self.post_synaptic_neuron_ids[synapse_id])
1869+
if pair in self.connection_ids:
1870+
del self.connection_ids[pair]
1871+
del self.pre_synaptic_neuron_ids[synapse_id]
1872+
del self.post_synaptic_neuron_ids[synapse_id]
1873+
del self.synaptic_weights[synapse_id]
1874+
del self.synaptic_delays[synapse_id]
1875+
del self.enable_stdp[synapse_id]
1876+
if reindex:
1877+
for slist in self._synapselist_cache:
1878+
indices = set(slist.indices)
1879+
overlap = indices & mapping.keys()
1880+
if overlap:
1881+
slist.indices = [mapping[i] for i in slist.indices if i != synapse_id and i in mapping]
1882+
if _rebuild_connection_ids:
1883+
self.rebuild_connection_ids()
1884+
return mapping
1885+
return {}
1886+
1887+
def delete_synapses(self, synapse_ids: Sequence[int] | Sequence[Synapse], reindex: bool = True, _rebuild_connection_ids: bool = True):
1888+
"""Deletes a list of synapses from the network.
1889+
1890+
Because synapses are stored in the SNN as a list, deleting synapses may
1891+
cause a shift in the indices of other synapses. If you are manually modifying
1892+
the lists of synapse parameters, you may find it hard to keep track of what's what.
1893+
1894+
However, if you use :py:class:`Synapse`\\ s or a :py:class:`SynapseListView`,
1895+
then the shift in indices will be automatically handled, and those objects will
1896+
reflect the new indices while still referring to the same synapses that you'd expect.
1897+
1898+
.. warning::
1899+
1900+
Deleting synapses may result in unexpected behavior, as it can cause
1901+
large shifts in the indices of synapses. Use with caution.
1902+
1903+
Parameters
1904+
----------
1905+
synapse_ids : list[int] | list[Synapse]
1906+
The IDs of the synapses to delete.
1907+
"""
1908+
indices = set(int(synapse_id) for synapse_id in synapse_ids if isinstance(synapse_id, (int, Synapse)))
1909+
indices = list(indices)
1910+
indices.sort(reverse=True)
1911+
num_synapses = self.num_synapses
1912+
for synapse_id in indices:
1913+
self.delete_synapse(synapse_id, reindex=False)
1914+
if not reindex:
1915+
return {}
1916+
1917+
remaining_idxs = (idx for idx in range(num_synapses) if idx not in indices) # sorted
1918+
mapping = {old: new for new, old in enumerate(remaining_idxs)}
1919+
1920+
for synapse in self._synapse_cache.values():
1921+
synapse.idx = mapping[synapse.idx]
1922+
self._synapse_cache = {synapse.idx: synapse for synapse in self._synapse_cache.values()}
1923+
1924+
if _rebuild_connection_ids:
1925+
self.rebuild_connection_ids()
1926+
1927+
for slist in self._synapselist_cache:
1928+
indices = set(slist.indices)
1929+
overlap = indices & mapping.keys()
1930+
if not overlap:
1931+
continue
1932+
slist.indices = [mapping[i] for i in slist.indices if i in mapping]
1933+
1934+
return mapping
1935+
16741936
_internal_vars = [
16751937
"_neuron_thresholds", "_neuron_leaks", "_neuron_reset_states", "_internal_states",
16761938
"_neuron_refractory_periods", "_neuron_refractory_periods_original", "_weights",

0 commit comments

Comments
 (0)