@@ -443,6 +443,7 @@ def _get_analogsignal_chunk_header_attached(self, i_start, i_stop, stream_index,
443443
444444 stream_name = self .header ["signal_streams" ][stream_index ]["name" ][:]
445445 stream_is_digital = stream_name in digital_stream_names
446+ stream_is_stim = stream_name == "Stim channel"
446447
447448 field_name = stream_name if stream_is_digital else channel_ids [0 ]
448449
@@ -462,7 +463,18 @@ def _get_analogsignal_chunk_header_attached(self, i_start, i_stop, stream_index,
462463 sl1 = sl0 + (i_stop - i_start )
463464
464465 # For all streams raw_data is a structured memmap with a field for each channel_id
465- if not stream_is_digital :
466+ if stream_is_stim :
467+ # For stim data, we need to extract the raw data first, then demultiplex it
468+ stim_data = np .zeros ((i_stop - i_start , len (channel_ids )), dtype = dtype )
469+ for chunk_index , channel_id in enumerate (channel_ids ):
470+ data_chan = self ._raw_data [channel_id ]
471+ if multiple_samples_per_block :
472+ stim_data [:, chunk_index ] = data_chan [block_start :block_stop ].flatten ()[sl0 :sl1 ]
473+ else :
474+ stim_data [:, chunk_index ] = data_chan [i_start :i_stop ]
475+ # Now demultiplex the stim data
476+ sigs_chunk = self ._demultiplex_stim_data (stim_data , 0 , stim_data .shape [0 ])
477+ elif not stream_is_digital :
466478 sigs_chunk = np .zeros ((i_stop - i_start , len (channel_ids )), dtype = dtype )
467479
468480 for chunk_index , channel_id in enumerate (channel_ids ):
@@ -480,6 +492,8 @@ def _get_analogsignal_chunk_one_file_per_channel(self, i_start, i_stop, stream_i
480492
481493 stream_name = self .header ["signal_streams" ][stream_index ]["name" ][:]
482494 signal_data_memmap_list = self ._raw_data [stream_name ]
495+ stream_is_stim = stream_name == "Stim channel"
496+
483497 channel_indexes_are_slice = isinstance (channel_indexes , slice )
484498 if channel_indexes_are_slice :
485499 num_channels = len (signal_data_memmap_list )
@@ -496,6 +510,10 @@ def _get_analogsignal_chunk_one_file_per_channel(self, i_start, i_stop, stream_i
496510 for chunk_index , channel_index in enumerate (channel_indexes ):
497511 channel_memmap = signal_data_memmap_list [channel_index ]
498512 sigs_chunk [:, chunk_index ] = channel_memmap [i_start :i_stop ]
513+
514+ # If this is stim data, we need to demultiplex it
515+ if stream_is_stim :
516+ sigs_chunk = self ._demultiplex_stim_data (sigs_chunk , 0 , sigs_chunk .shape [0 ])
499517
500518 return sigs_chunk
501519
@@ -505,14 +523,17 @@ def _get_analogsignal_chunk_one_file_per_signal(self, i_start, i_stop, stream_in
505523 raw_data = self ._raw_data [stream_name ]
506524
507525 stream_is_digital = stream_name in digital_stream_names
526+ stream_is_stim = stream_name == "Stim channel"
527+
508528 if stream_is_digital :
509529 stream_id = self .header ["signal_streams" ][stream_index ]["id" ]
510530 mask = self .header ["signal_channels" ]["stream_id" ] == stream_id
511531 signal_channels = self .header ["signal_channels" ][mask ]
512532 channel_ids = signal_channels ["id" ][channel_indexes ]
513533
514534 output = self ._demultiplex_digital_data (raw_data , channel_ids , i_start , i_stop )
515-
535+ elif stream_is_stim :
536+ output = self ._demultiplex_stim_data (raw_data , i_start , i_stop )
516537 else :
517538 output = raw_data [i_start :i_stop , channel_indexes ]
518539
@@ -530,6 +551,42 @@ def _demultiplex_digital_data(self, raw_digital_data, channel_ids, i_start, i_st
530551 output [:, channel_index ] = demultiplex_data [i_start :i_stop ].flatten ()
531552
532553 return output
554+
555+ def _demultiplex_stim_data (self , raw_stim_data , i_start , i_stop ):
556+ """
557+ Demultiplexes the stim data stream.
558+
559+ Parameters
560+ ----------
561+ raw_stim_data : ndarray
562+ The raw stim data
563+ i_start : int
564+ Start index
565+ i_stop : int
566+ Stop index
567+
568+ Returns
569+ -------
570+ output : ndarray
571+ Demultiplexed stim data containing only the current values, preserving channel dimensions
572+ """
573+ # Get the relevant portion of the data
574+ data = raw_stim_data [i_start :i_stop ]
575+
576+ # Extract current value (bits 0-8)
577+ magnitude = np .bitwise_and (data , 0xFF ) # Extract lowest 8 bits
578+ sign_bit = np .bitwise_and (np .right_shift (data , 8 ), 0x01 ) # Extract 9th bit for sign
579+
580+ # Apply sign to current values
581+ current = np .where (sign_bit == 1 , - magnitude , magnitude )
582+
583+ # Note: If needed, other flag bits could be extracted as follows:
584+ # compliance_flag = np.bitwise_and(np.right_shift(data, 15), 0x01).astype(bool) # Bit 16 (MSB)
585+ # charge_recovery_flag = np.bitwise_and(np.right_shift(data, 14), 0x01).astype(bool) # Bit 15
586+ # amp_settle_flag = np.bitwise_and(np.right_shift(data, 13), 0x01).astype(bool) # Bit 14
587+ # These could be returned as a structured array or dictionary if needed
588+
589+ return current
533590
534591 def get_intan_timestamps (self , i_start = None , i_stop = None ):
535592 """
@@ -857,8 +914,8 @@ def read_rhs(filename, file_format: str):
857914 chan_info_stim ["sampling_rate" ] = sr
858915 # stim channel are complicated because they are coded
859916 # with bits, they do not fit the gain/offset rawio strategy
860- chan_info_stim ["units" ] = ""
861- chan_info_stim ["gain" ] = 1.0
917+ chan_info_stim ["units" ] = "A" # Amps
918+ chan_info_stim ["gain" ] = global_info [ "stim_step_size" ]
862919 chan_info_stim ["offset" ] = 0.0
863920 chan_info_stim ["signal_type" ] = 11 # put it in another group
864921 chan_info_stim ["dtype" ] = "uint16"
0 commit comments