Skip to content

Commit 4547ef6

Browse files
authored
Merge pull request #117 from PKU-NIP-Lab/enhance-measure
Enhance measure/input/brainpylib
2 parents ddf0e3a + 67207f0 commit 4547ef6

File tree

12 files changed

+658
-414
lines changed

12 files changed

+658
-414
lines changed

brainpy/inputs/__init__.py

Lines changed: 1 addition & 199 deletions
Original file line numberDiff line numberDiff line change
@@ -6,203 +6,5 @@
66
You can access them through ``brainpy.inputs.XXX``.
77
"""
88

9-
import numpy as np
10-
11-
from brainpy import math as bm
12-
13-
__all__ = [
14-
'section_input',
15-
'constant_input', 'constant_current',
16-
'spike_input', 'spike_current',
17-
'ramp_input', 'ramp_current',
18-
]
19-
20-
21-
def section_input(values, durations, dt=None, return_length=False):
22-
"""Format an input current with different sections.
23-
24-
For example:
25-
26-
If you want to get an input where the size is 0 bwteen 0-100 ms,
27-
and the size is 1. between 100-200 ms.
28-
29-
>>> section_input(values=[0, 1],
30-
>>> durations=[100, 100])
31-
32-
Parameters
33-
----------
34-
values : list, np.ndarray
35-
The current values for each period duration.
36-
durations : list, np.ndarray
37-
The duration for each period.
38-
dt : float
39-
Default is None.
40-
return_length : bool
41-
Return the final duration length.
42-
43-
Returns
44-
-------
45-
current_and_duration : tuple
46-
(The formatted current, total duration)
47-
"""
48-
assert len(durations) == len(values), f'"values" and "durations" must be the same length, while ' \
49-
f'we got {len(values)} != {len(durations)}.'
50-
51-
dt = bm.get_dt() if dt is None else dt
52-
53-
# get input current shape, and duration
54-
I_duration = sum(durations)
55-
I_shape = ()
56-
for val in values:
57-
shape = bm.shape(val)
58-
if len(shape) > len(I_shape):
59-
I_shape = shape
60-
61-
# get the current
62-
start = 0
63-
I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape, dtype=bm.float_)
64-
for c_size, duration in zip(values, durations):
65-
length = int(duration / dt)
66-
I_current[start: start + length] = c_size
67-
start += length
68-
69-
if return_length:
70-
return I_current, I_duration
71-
else:
72-
return I_current
73-
74-
75-
def constant_input(I_and_duration, dt=None):
76-
"""Format constant input in durations.
77-
78-
For example:
79-
80-
If you want to get an input where the size is 0 bwteen 0-100 ms,
81-
and the size is 1. between 100-200 ms.
82-
83-
>>> import brainpy.math as bm
84-
>>> constant_input([(0, 100), (1, 100)])
85-
>>> constant_input([(bm.zeros(100), 100), (bm.random.rand(100), 100)])
86-
87-
Parameters
88-
----------
89-
I_and_duration : list
90-
This parameter receives the current size and the current
91-
duration pairs, like `[(Isize1, duration1), (Isize2, duration2)]`.
92-
dt : float
93-
Default is None.
94-
95-
Returns
96-
-------
97-
current_and_duration : tuple
98-
(The formatted current, total duration)
99-
"""
100-
dt = bm.get_dt() if dt is None else dt
101-
102-
# get input current dimension, shape, and duration
103-
I_duration = 0.
104-
I_shape = ()
105-
for I in I_and_duration:
106-
I_duration += I[1]
107-
shape = bm.shape(I[0])
108-
if len(shape) > len(I_shape):
109-
I_shape = shape
110-
111-
# get the current
112-
start = 0
113-
I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape, dtype=bm.float_)
114-
for c_size, duration in I_and_duration:
115-
length = int(duration / dt)
116-
I_current[start: start + length] = c_size
117-
start += length
118-
return I_current, I_duration
119-
120-
121-
constant_current = constant_input
122-
123-
124-
def spike_input(sp_times, sp_lens, sp_sizes, duration, dt=None):
125-
"""Format current input like a series of short-time spikes.
126-
127-
For example:
128-
129-
If you want to generate a spike train at 10 ms, 20 ms, 30 ms, 200 ms, 300 ms,
130-
and each spike lasts 1 ms and the spike current is 0.5, then you can use the
131-
following funtions:
132-
133-
>>> spike_input(sp_times=[10, 20, 30, 200, 300],
134-
>>> sp_lens=1., # can be a list to specify the spike length at each point
135-
>>> sp_sizes=0.5, # can be a list to specify the current size at each point
136-
>>> duration=400.)
137-
138-
Parameters
139-
----------
140-
sp_times : list, tuple
141-
The spike time-points. Must be an iterable object.
142-
sp_lens : int, float, list, tuple
143-
The length of each point-current, mimicking the spike durations.
144-
sp_sizes : int, float, list, tuple
145-
The current sizes.
146-
duration : int, float
147-
The total current duration.
148-
dt : float
149-
The default is None.
150-
151-
Returns
152-
-------
153-
current : bm.ndarray
154-
The formatted input current.
155-
"""
156-
dt = bm.get_dt() if dt is None else dt
157-
assert isinstance(sp_times, (list, tuple))
158-
if isinstance(sp_lens, (float, int)):
159-
sp_lens = [sp_lens] * len(sp_times)
160-
if isinstance(sp_sizes, (float, int)):
161-
sp_sizes = [sp_sizes] * len(sp_times)
162-
163-
current = bm.zeros(int(np.ceil(duration / dt)), dtype=bm.float_)
164-
for time, dur, size in zip(sp_times, sp_lens, sp_sizes):
165-
pp = int(time / dt)
166-
p_len = int(dur / dt)
167-
current[pp: pp + p_len] = size
168-
return current
169-
170-
171-
spike_current = spike_input
172-
173-
174-
def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None):
175-
"""Get the gradually changed input current.
176-
177-
Parameters
178-
----------
179-
c_start : float
180-
The minimum (or maximum) current size.
181-
c_end : float
182-
The maximum (or minimum) current size.
183-
duration : int, float
184-
The total duration.
185-
t_start : float
186-
The ramped current start time-point.
187-
t_end : float
188-
The ramped current end time-point. Default is the None.
189-
dt : float, int, optional
190-
The numerical precision.
191-
192-
Returns
193-
-------
194-
current : bm.ndarray
195-
The formatted current
196-
"""
197-
dt = bm.get_dt() if dt is None else dt
198-
t_end = duration if t_end is None else t_end
199-
200-
current = bm.zeros(int(np.ceil(duration / dt)), dtype=bm.float_)
201-
p1 = int(np.ceil(t_start / dt))
202-
p2 = int(np.ceil(t_end / dt))
203-
current[p1: p2] = bm.array(bm.linspace(c_start, c_end, p2 - p1), dtype=bm.float_)
204-
return current
205-
206-
207-
ramp_current = ramp_input
9+
from .currents import *
20810

0 commit comments

Comments
 (0)