Skip to content

Commit 8490d90

Browse files
authored
Merge pull request #26 from kenblu24/accessor-props
Added methods to NeuronListView to add and delete spikes.
2 parents 3d67f49 + 2e0880c commit 8490d90

File tree

4 files changed

+238
-3
lines changed

4 files changed

+238
-3
lines changed

src/superneuromat/accessor_classes.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,16 +1474,158 @@ def __iter__(self):
14741474
def ispikes(self):
14751475
return self.m.ispikes[:, self.indices]
14761476

1477+
def add_spike(self, time: int, idx: int, value=1.0, exist='error'):
1478+
"""Adds an external spike in the SNN
1479+
1480+
Parameters
1481+
----------
1482+
time : int
1483+
The time step at which the external spike is added
1484+
idx : int
1485+
The neuron for which the external spike is added
1486+
value : float
1487+
The value of the external spike (default: 1.0)
1488+
exist : str
1489+
action for existing spikes on a neuron at a given time step.
1490+
Should be one of ['error', 'overwrite', 'add', 'dontadd']. (default: 'error')
1491+
1492+
if exist='add', the existing spike value is added to the new value.
1493+
1494+
Raises
1495+
------
1496+
TypeError
1497+
if:
1498+
1499+
* time cannot be precisely cast to int
1500+
* neuron_id is not an int
1501+
* value is not an int or float
1502+
1503+
ValueError
1504+
if spike already exists at that neuron and timestep and exist='error',
1505+
or if exist is an invalid setting.
1506+
1507+
See Also
1508+
--------
1509+
SNN.add_spike
1510+
Neuron.add_spike
1511+
"""
1512+
self.m.add_spike(time, self.indices[idx], value, exist)
1513+
1514+
def add_spikes(
1515+
self,
1516+
spikes: float | Sequence[float] | Sequence[Sequence[float]] | np.ndarray[(int,), dtype] | np.ndarray[(int, int), dtype],
1517+
time_offset: int = 0,
1518+
exist: str = 'error',
1519+
):
1520+
"""Add a time-series of spikes to this neuron.
1521+
1522+
Parameters
1523+
----------
1524+
spikes : numpy.typing.ArrayLike
1525+
time_offset : int, default=0
1526+
The number of time steps to offset the spikes by.
1527+
exist : str, default='error'
1528+
Action if a queued spike already exists at the given time step.
1529+
Should be one of ['error', 'overwrite', 'add', 'dontadd'].
1530+
1531+
Note: ``0.0``-valued spikes are not added unless ``exist='overwrite'``.
1532+
1533+
1534+
If the input is a scalar, a single spike is sent to each neuron in this NeuronListView
1535+
at time ``time_offset``.
1536+
1537+
If the input is a 1-dimensional array, it is assumed to be the values of spikes to send to each neuron
1538+
in this NeuronListView at time ``time_offset``.
1539+
1540+
If the input is a 2-dimensional array, it is assumed to be, for each time step, a list of values to send
1541+
to each neuron in this NeuronListView, starting at time ``time_offset``. That is, the first row of the array
1542+
corresponds to the first time step, the second row to the second time step, and so on.
1543+
"""
1544+
arr = np.asarray(spikes, dtype=self.m.default_dtype)
1545+
if arr.ndim == 0:
1546+
arr = np.broadcast_to(arr, (1, len(self.indices)))
1547+
elif arr.ndim == 1:
1548+
if arr.shape != (len(self.indices),):
1549+
msg = ("add_spikes() received a 1-dimensional array, which is assumed to be "
1550+
"the values of spikes to send to each neuron in this NeuronListView. "
1551+
f"Expected {len(self.indices)} values, but received {arr.shape[0]}.")
1552+
raise ValueError(msg)
1553+
arr = arr.reshape((1, -1))
1554+
elif arr.ndim == 2:
1555+
if arr.shape[1] != len(self.indices):
1556+
msg = ("add_spikes() received a 2-dimensional array, which is assumed to be, "
1557+
"for each time step, a list of values to send to each neuron in this NeuronListView. "
1558+
f"Expected arr.shape[1] == {len(self.indices)}, but received values "
1559+
f"for {arr.shape[1]} neurons.")
1560+
raise ValueError(msg)
1561+
for time, vec in enumerate(arr):
1562+
for idx, value in zip(self.indices, vec):
1563+
if value == 0.0 and exist != 'overwrite':
1564+
continue
1565+
self.m.add_spike(time + time_offset, idx, value, exist)
1566+
14771567
def pretty_spike_train(self, max_steps=None, max_neurons=None, use_unicode=True, indices=None):
1568+
"""Returns a list[str] showing the spike train for each neuron in this NeuronListView.
1569+
1570+
See Also
1571+
--------
1572+
SNN.pretty_spike_train
1573+
"""
14781574
if indices is None:
14791575
indices = self.indices
14801576
return util.pretty_spike_train(self.ispikes, max_steps, max_neurons, use_unicode, indices)
14811577

14821578
def print_spike_train(self, max_steps=None, max_neurons=None, use_unicode=True, indices=None):
1579+
"""Prints the spike train for each neuron in this NeuronListView.
1580+
1581+
See Also
1582+
--------
1583+
SNN.print_spike_train
1584+
"""
14831585
if indices is None:
14841586
indices = self.indices
14851587
util.print_spike_train(self.ispikes, max_steps, max_neurons, use_unicode, indices)
14861588

1589+
def clear_input_spikes(self, t: int | slice | list | np.ndarray | None = None,
1590+
destination: int | slice | list | np.ndarray | None = None,
1591+
remove_empty: bool = True):
1592+
"""Delete input spikes from the SNN.
1593+
1594+
Parameters
1595+
----------
1596+
t : int | slice | list | np.ndarray | None, default=None
1597+
The time step(s) from which to delete input spikes.
1598+
If ``None``, delete input spikes across all time steps.
1599+
destination : int | Neuron | slice | list | np.ndarray | None, default=None
1600+
The neuron(s) from which to delete input spikes.
1601+
If ``None``, delete all input spikes from only neurons in this NeuronListView.
1602+
remove_empty : bool, default=True
1603+
If ``True``, remove empty time steps from the input spike train.
1604+
1605+
See Also
1606+
--------
1607+
SNN.clear_input_spikes
1608+
"""
1609+
if destination is None:
1610+
destination = set(self.indices)
1611+
else:
1612+
if isinstance(destination, (int, np.integer)):
1613+
destination = [destination]
1614+
elif isinstance(destination, slice):
1615+
destination = slice_indices(destination, len(self))
1616+
elif isinstance(destination, np.ndarray):
1617+
destination = list(set(destination.tolist()))
1618+
else:
1619+
try:
1620+
destination = [int(idx) for idx in destination]
1621+
except (TypeError, ValueError) as err:
1622+
msg = f"clear_input_spikes() expected int, slice, list, or None, but received {type(destination)}"
1623+
raise TypeError(msg) from err
1624+
destination = list(set(destination))
1625+
destination = [self.indices[idx] for idx in destination]
1626+
1627+
self.m.clear_input_spikes(t, destination, remove_empty)
1628+
14871629

14881630
class NeuronIterator(ModelListIterator):
14891631
accessor_type = Neuron

src/superneuromat/neuromorphicmodel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,7 +1496,7 @@ def clear_spike_train(self):
14961496
self.spike_train = []
14971497

14981498
def clear_input_spikes(self, t: int | slice | list | np.ndarray | None = None,
1499-
destination: int | Neuron | slice | list | np.ndarray | None = None,
1499+
destination: int | Neuron | slice | list | np.ndarray | set | None = None,
15001500
remove_empty: bool = True):
15011501
"""Delete input spikes from the SNN.
15021502
@@ -1524,7 +1524,7 @@ def clear_input_spikes(self, t: int | slice | list | np.ndarray | None = None,
15241524
# normalize times to delete
15251525
if isinstance(t, slice):
15261526
times_to_delete = set(self.input_spikes.keys()) & set(slice_indices(t, max(self.input_spikes)))
1527-
elif isinstance(t, int):
1527+
elif isinstance(t, (int, np.integer)):
15281528
times_to_delete = [t] if t in self.input_spikes else []
15291529
elif t is None:
15301530
times_to_delete = list(self.input_spikes.keys())
@@ -1539,7 +1539,7 @@ def clear_input_spikes(self, t: int | slice | list | np.ndarray | None = None,
15391539
raise TypeError(msg) from err
15401540

15411541
# normalize destinations to delete
1542-
if isinstance(destination, (int, Neuron)):
1542+
if isinstance(destination, (int, np.integer, Neuron)):
15431543
destination = [int(destination)]
15441544
elif destination is None:
15451545
pass

src/superneuromat/util.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@ def pretty_spike_train(
188188
use_unicode : bool, default=True
189189
If ``True``, use unicode characters to represent spikes.
190190
Otherwise fallback to ascii characters.
191+
indices : list[int] | None, default=None
192+
If provided, show these indices in the header of the output.
193+
Otherwise, enumerate them from 0 to the number of neurons in the spike train.
191194
"""
192195
lines = []
193196
steps = len(spike_train)
@@ -258,5 +261,8 @@ def print_spike_train(spike_train, max_steps=None, max_neurons=None, use_unicode
258261
use_unicode : bool, default=True
259262
If ``True``, use unicode characters to represent spikes.
260263
Otherwise fallback to ascii characters.
264+
indices : list[int] | None, default=None
265+
If provided, show these indices in the header of the output.
266+
Otherwise, enumerate them from 0 to the number of neurons in the spike train.
261267
"""
262268
print('\n'.join(pretty_spike_train(spike_train, max_steps, max_neurons, use_unicode, indices)))

tests/test_spikes.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,5 +207,92 @@ def remove_empty(d: dict):
207207
9: {"nids": [0], "values": [1.0]},
208208
}
209209

210+
def test_neuronlist_clear_input_spikes(self):
211+
"""Test the clear_input_spikes function"""
212+
print("begin test_neuronlist_clear_input_spikes")
213+
snn = SNN()
214+
for _i in range(4):
215+
snn.create_neuron()
216+
217+
inputs = snn.neurons[2:]
218+
a, b = inputs
219+
a.add_spikes([
220+
(0, 1.0),
221+
(1, 1.0),
222+
(2, 1.0),
223+
(3, 1.0),
224+
])
225+
b.add_spikes([1, 1])
226+
snn.neurons[0].add_spike(4, 9.0)
227+
print(snn.input_spikes_info())
228+
assert snn.input_spikes == {
229+
0: {"nids": [2, 3], "values": [1.0, 1.0]},
230+
1: {"nids": [2, 3], "values": [1.0, 1.0]},
231+
2: {"nids": [2], "values": [1.0]},
232+
3: {"nids": [2], "values": [1.0]},
233+
4: {"nids": [0], "values": [9.0]},
234+
}
235+
inputs.clear_input_spikes(destination=1)
236+
inputs.clear_input_spikes(t=slice(1, 3))
237+
assert snn.input_spikes == {
238+
0: {"nids": [2], "values": [1.0]},
239+
3: {"nids": [2], "values": [1.0]},
240+
4: {"nids": [0], "values": [9.0]},
241+
}
242+
inputs.clear_input_spikes(destination=[0])
243+
assert snn.input_spikes == {
244+
4: {"nids": [0], "values": [9.0]},
245+
}
246+
snn.input_spikes = {
247+
0: {"nids": [1, 2, 3], "values": [2.0, 5.0, 1.0]},
248+
1: {"nids": [2, 3], "values": [1.0, 1.0]},
249+
2: {"nids": [2], "values": [1.0]},
250+
3: {"nids": [2], "values": [1.0]},
251+
4: {"nids": [0], "values": [9.0]},
252+
}
253+
inputs.clear_input_spikes(destination=slice(0, 1))
254+
assert snn.input_spikes == {
255+
0: {"nids": [1, 3], "values": [2.0, 1.0]},
256+
1: {"nids": [3], "values": [1.0]},
257+
4: {"nids": [0], "values": [9.0]},
258+
}
259+
260+
def test_neuronlist_add_spike(self):
261+
"""Test the add_spike function"""
262+
print("begin test_neuronlist_add_spike")
263+
snn = SNN()
264+
for _i in range(4):
265+
snn.create_neuron()
266+
267+
inputs = snn.neurons[2:]
268+
inputs.add_spike(0, 1, 2.0)
269+
assert snn.input_spikes == {0: {"nids": [3], "values": [2.0]}}
270+
271+
def test_neuronlist_add_input_spikes(self):
272+
"""Test the add_input_spikes function"""
273+
print("begin test_neuronlist_add_input_spikes")
274+
snn = SNN()
275+
for _i in range(4):
276+
snn.create_neuron()
277+
278+
inputs = snn.neurons[2:]
279+
inputs.add_spikes([1, 2])
280+
assert snn.input_spikes == {0: {"nids": [2, 3], "values": [1.0, 2.0]}}
281+
inputs.add_spikes([[0, 2], [3, 4]], exist='overwrite')
282+
inputs.add_spikes([[0, 0], [0, 0]], exist='ignore')
283+
assert snn.input_spikes == {
284+
0: {"nids": [2, 3], "values": [0.0, 2.0]},
285+
1: {"nids": [2, 3], "values": [3.0, 4.0]}
286+
}
287+
snn.clear_input_spikes()
288+
assert snn.input_spikes == {}
289+
inputs.add_spikes(2, time_offset=5)
290+
assert snn.input_spikes == {5: {"nids": [2, 3], "values": [2.0, 2.0]}}
291+
with self.assertRaises(ValueError):
292+
inputs.add_spikes([0, 1, 2])
293+
with self.assertRaises(ValueError):
294+
inputs.add_spikes([[0, 1, 2], [3, 4, 5]], time_offset=5)
295+
296+
210297
if __name__ == "__main__":
211298
unittest.main()

0 commit comments

Comments
 (0)