Skip to content

Commit 53f6568

Browse files
committed
Add support for sending spike event & waveform from C++ to Python
1 parent 5839e52 commit 53f6568

File tree

4 files changed

+102
-59
lines changed

4 files changed

+102
-59
lines changed

Modules/examples/bandpass_filter.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -70,25 +70,15 @@ def stop_acquisition(self):
7070
pass
7171

7272
def handle_ttl_event(self, source_node, channel, sample_number, line, state):
73-
"""
74-
Handle each incoming ttl event.
75-
76-
Parameters:
77-
source_node (int): id of the processor this event was generated from
78-
channel (str): name of the event channel
79-
sample_number (int): sample number of the event
80-
line (int): the line on which event was generated (0-255)
81-
state (bool): event state true (ON) or false (OFF)
82-
"""
73+
""" Handle each incoming ttl event """
8374
pass
8475

85-
def start_recording(self, recording_dir):
86-
"""
87-
Called when recording starts
76+
def handle_spike(self, source_node, electrode_name, num_channels, num_samples, sample_number, sorted_id, spike_data):
77+
""" Handle each incoming spike """
78+
pass
8879

89-
Parameters:
90-
recording_dir - recording directory to be used by future record nodes.
91-
"""
80+
def start_recording(self, recording_dir):
81+
""" Called when recording starts """
9282
pass
9383

9484
def stop_recording(self):

Modules/template/processor_template.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def process(self, data):
2121
Process each incoming data buffer.
2222
2323
Parameters:
24-
data - numpy array.
24+
data - N x M numpy array, where N = num_channles, M = num of samples in the buffer.
2525
"""
2626
try:
2727
pass
@@ -48,6 +48,22 @@ def handle_ttl_event(self, source_node, channel, sample_number, line, state):
4848
state (bool): event state True (ON) or False (OFF)
4949
"""
5050
pass
51+
52+
def handle_spike(self, source_node, electrode_name, num_channels, num_samples, sample_number, sorted_id, spike_data):
53+
"""
54+
Handle each incoming spike.
55+
56+
Parameters:
57+
source_node (int): id of the processor this spike was generated from
58+
electrode_name (str): name of the electrode
59+
num_channels (int): number of channels associated with the electrode type
60+
num_samples (int): total number of samples in the spike waveform
61+
sample_number (int): sample number of the spike
62+
sorted_id (int): the sorted ID for this spike
63+
spike_data (numpy array): N x M numpy array, where N = num_channels & M = num_samples (read-only).
64+
"""
65+
# print("SPIKE RECEIVED! Source Node:", source_node, ", Electrode name:", electrode_name, ", Num channels:", num_channels, ", Sample num:", sample_number, ", Sorted ID:", sorted_id)
66+
pass
5167

5268
def start_recording(self, recording_dir):
5369
"""

Source/PythonProcessor.cpp

Lines changed: 75 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
2525
#include <filesystem>
2626

2727
#include "PythonProcessor.h"
28-
#include "PythonProcessorEditor.h"
2928

3029
namespace py = pybind11;
3130

@@ -38,8 +37,8 @@ PYBIND11_EMBEDDED_MODULE(oe_pyprocessor, module){
3837
PythonProcessor::PythonProcessor()
3938
: GenericProcessor("Python Processor")
4039
{
41-
pyModule = NULL;
42-
pyObject = NULL;
40+
pyModule = nullptr;
41+
pyObject = nullptr;
4342
moduleReady = false;
4443
scriptPath = "";
4544
moduleName = "";
@@ -58,8 +57,6 @@ PythonProcessor::~PythonProcessor()
5857
if(Py_IsInitialized() > 0)
5958
{
6059
{
61-
delete pyModule;
62-
delete pyObject;
6360
py::gil_scoped_release release;
6461
}
6562
py::finalize_interpreter();
@@ -203,16 +200,41 @@ void PythonProcessor::handleTTLEvent(TTLEventPtr event)
203200
}
204201

205202

206-
// void PythonProcessor::handleSpike(SpikePtr event)
207-
// {
208-
// // py::gil_scoped_acquire acquire;
209-
// try {
210-
// pyObject->attr("handle_spike_event")();
211-
// }
212-
// catch (py::error_already_set& e) {
213-
// handlePythonException(e);
214-
// }
215-
// }
203+
void PythonProcessor::handleSpike(SpikePtr spike)
204+
{
205+
if (spike->getStreamId() == currentStream)
206+
{
207+
auto spikeChanInfo = spike->getChannelInfo();
208+
209+
const int sourceNodeId = spikeChanInfo->getSourceNodeId();
210+
const String electrodeName = spikeChanInfo->getName();
211+
const int numChans = spikeChanInfo->getNumChannels();
212+
const int64 sampleNum = spike->getSampleNumber();
213+
const uint16 sortedId = spike->getSortedId();
214+
const int numSamples = spikeChanInfo->getTotalSamples();
215+
216+
py::array_t<float> spikeData = py::array_t<float>({ numChans, numSamples });
217+
218+
for (int i = 0; i < numChans; ++i)
219+
{
220+
const float* spikeChanDataPtr = spike->getDataPointer(i);
221+
float* numpyChannelPtr = spikeData.mutable_data(i, 0);
222+
memcpy(numpyChannelPtr, spikeChanDataPtr, sizeof(float) * numSamples);
223+
}
224+
225+
// py::gil_scoped_acquire acquire;
226+
if(py::hasattr(*pyObject, "handle_spike"))
227+
{
228+
try {
229+
pyObject->attr("handle_spike")
230+
(sourceNodeId, electrodeName.toRawUTF8(), numChans, numSamples, sampleNum, sortedId, spikeData);
231+
}
232+
catch (py::error_already_set& e) {
233+
handlePythonException(e);
234+
}
235+
}
236+
}
237+
}
216238

217239

218240
// void PythonProcessor::handleBroadcastMessage(String message)
@@ -249,52 +271,67 @@ bool PythonProcessor::startAcquisition()
249271
{
250272
// py::gil_scoped_acquire acquire;
251273

252-
try {
253-
pyObject->attr("start_acquisition")();
254-
}
255-
catch (py::error_already_set& e) {
256-
handlePythonException(e);
274+
if(py::hasattr(*pyObject, "start_acquisition"))
275+
{
276+
try {
277+
pyObject->attr("start_acquisition")();
278+
}
279+
catch (py::error_already_set& e) {
280+
handlePythonException(e);
281+
}
257282
}
258283
return true;
259284
}
260285
return false;
261286
}
262287

263-
bool PythonProcessor::stopAcquisition() {
288+
bool PythonProcessor::stopAcquisition()
289+
{
264290
if (moduleReady)
265291
{
266292
// py::gil_scoped_acquire acquire;
267-
try {
268-
pyObject->attr("stop_acquisition")();
269-
}
270-
catch (py::error_already_set& e) {
271-
handlePythonException(e);
293+
if(py::hasattr(*pyObject, "stop_acquisition"))
294+
{
295+
try {
296+
pyObject->attr("stop_acquisition")();
297+
}
298+
catch (py::error_already_set& e) {
299+
handlePythonException(e);
300+
}
272301
}
273302
}
274303
return true;
275304
}
276305

277-
void PythonProcessor::startRecording() {
306+
void PythonProcessor::startRecording()
307+
{
278308
String recordingDirectory = CoreServices::getRecordingDirectoryName();
279309

280310
// py::gil_scoped_acquire acquire;
281-
try {
282-
pyObject->attr("start_recording")(recordingDirectory.toRawUTF8());
283-
}
284-
catch (py::error_already_set& e) {
285-
handlePythonException(e);
311+
if(moduleReady && py::hasattr(*pyObject, "start_recording"))
312+
{
313+
try {
314+
pyObject->attr("start_recording")(recordingDirectory.toRawUTF8());
315+
}
316+
catch (py::error_already_set& e) {
317+
handlePythonException(e);
318+
}
286319
}
287320
}
288321

289322

290323

291-
void PythonProcessor::stopRecording() {
324+
void PythonProcessor::stopRecording()
325+
{
292326
// py::gil_scoped_acquire acquire;
293-
try {
294-
pyObject->attr("stop_recording")();
295-
}
296-
catch (py::error_already_set& e) {
297-
handlePythonException(e);
327+
if(moduleReady && py::hasattr(*pyObject, "stop_recording"))
328+
{
329+
try {
330+
pyObject->attr("stop_recording")();
331+
}
332+
catch (py::error_already_set& e) {
333+
handlePythonException(e);
334+
}
298335
}
299336
}
300337

Source/PythonProcessor.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ class PythonProcessor : public GenericProcessor
106106
// the plugin's process() method */
107107
void handleTTLEvent(TTLEventPtr event) override;
108108

109-
// /** Handles spikes received by the processor
110-
// Called automatically for each received spike whenever checkForEvents(true) is called from
111-
// the plugin's process() method */
112-
// void handleSpike(SpikePtr spike) override;
109+
/** Handles spikes received by the processor
110+
Called automatically for each received spike whenever checkForEvents(true) is called from
111+
the plugin's process() method */
112+
void handleSpike(SpikePtr spike) override;
113113

114114
// /** Handles broadcast messages sent during acquisition
115115
// Called automatically whenever a broadcast message is sent through the signal chain */

0 commit comments

Comments
 (0)