Skip to content

Commit fd3cbd0

Browse files
Clip small negative values in van_rossum_distance before np.sqrt (NeuralEnsemble#680)
* Clipped small negative values in van_rossum_distance to prevent Nan from np.sqrt; added test case for same * refined van rossum distance funtion warning and regression test per review * Grammar improved in warning message * Adjusted spacing in warning message
1 parent db5a5f0 commit fd3cbd0

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

elephant/spike_train_dissimilarity.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"""
2222

2323
from __future__ import division, print_function, unicode_literals
24-
24+
import warnings
2525
import numpy as np
2626
import quantities as pq
2727
from neo.core import SpikeTrain
@@ -363,6 +363,18 @@ def van_rossum_distance(spiketrains, time_constant=1.0 * pq.s, sort=True):
363363
for i, j in np.ndindex(k_dist.shape):
364364
vr_dist[i, j] = (
365365
k_dist[i, i] + k_dist[j, j] - k_dist[i, j] - k_dist[j, i])
366+
367+
# Clip small negative values
368+
if np.any(vr_dist < 0):
369+
warnings.warn(
370+
"van_rossum_distance: very small negative values encountered; "
371+
"setting them to zero. Potentially due to floating point error, "
372+
"which can occur if spike times are represented as small floating "
373+
"point values (e.g., in seconds). A possible way to prevent this "
374+
"warning is to use a time unit with better numerical precision, "
375+
"e.g., from seconds to milliseconds.", RuntimeWarning)
376+
vr_dist = np.maximum(vr_dist, 0.0)
377+
366378
return np.sqrt(vr_dist)
367379

368380

elephant/test/test_spike_train_dissimilarity.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import elephant.kernels as kernels
1515
from elephant.spike_train_generation import StationaryPoissonProcess
1616
import elephant.spike_train_dissimilarity as stds
17-
17+
import warnings
1818
from elephant.datasets import download_datasets
1919

2020

@@ -73,7 +73,6 @@ def setUp(self):
7373
self.tau7 = 0.01 * s
7474
self.q7 = 1.0 / self.tau7
7575
self.t = np.linspace(0, 200, 20000001) * ms
76-
7776
def test_wrong_input(self):
7877
self.assertRaises(TypeError, stds.victor_purpura_distance,
7978
[self.array1, self.array2], self.q3)
@@ -600,6 +599,27 @@ def test_van_rossum_distance(self):
600599
[self.st21], self.tau3)[0, 0], 0)
601600
self.assertEqual(len(stds.van_rossum_distance([], self.tau3)), 0)
602601

602+
def test_van_rossum_distance_regression_small_negative_values(self):
603+
"""
604+
Regression test for issue #679
605+
Very small negative value in van_rossum_distance function.
606+
Occurs due to floating point precision when
607+
spike times are represented as small values
608+
These values should be clipped to zero to avoid nans.
609+
"""
610+
611+
st24 = SpikeTrain([0.1782, 0.2286, 0.2804, 0.4972, 0.5504],
612+
units='s', t_stop=4.0)
613+
tau8 = 0.1 * s
614+
# Check small negative values edge case
615+
with warnings.catch_warnings(record=True) as w:
616+
warnings.simplefilter("always")
617+
result = stds.van_rossum_distance([st24, st24], tau8)
618+
self.assertTrue(any("very small negative values encountered"
619+
in str(warn.message) for warn in w))
620+
self.assertEqual(result[0, 1], 0.0)
621+
self.assertFalse(np.any(np.isnan(result)))
622+
603623

604624
if __name__ == '__main__':
605625
unittest.main()

0 commit comments

Comments
 (0)