Skip to content

Commit 4444d88

Browse files
committed
fix bug
1 parent 8060516 commit 4444d88

File tree

3 files changed

+9
-10
lines changed

3 files changed

+9
-10
lines changed

brainpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
__version__ = "2.1.9"
3+
__version__ = "2.1.10"
44

55

66
try:

brainpy/dyn/base.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def get_delay_data(
172172
self,
173173
name: str,
174174
delay_step: Union[int, bm.JaxArray, jnp.DeviceArray],
175-
indices: Union[int, bm.JaxArray, jnp.DeviceArray] = None,
175+
*indices: Union[int, bm.JaxArray, jnp.DeviceArray],
176176
):
177177
"""Get delay data according to the provided delay steps.
178178
@@ -192,18 +192,18 @@ def get_delay_data(
192192
"""
193193
if name in self.global_delay_vars:
194194
if isinstance(delay_step, int):
195-
return self.global_delay_vars[name](delay_step, indices)
195+
return self.global_delay_vars[name](delay_step, *indices)
196196
else:
197-
if indices is None:
198-
indices = jnp.arange(delay_step.size)
199-
return self.global_delay_vars[name](delay_step, indices)
197+
if len(indices) == 0:
198+
indices = (jnp.arange(delay_step.size), )
199+
return self.global_delay_vars[name](delay_step, *indices)
200200
elif name in self.local_delay_vars:
201201
if isinstance(delay_step, int):
202202
return self.local_delay_vars[name](delay_step)
203203
else:
204-
if indices is None:
205-
indices = jnp.arange(delay_step.size)
206-
return self.local_delay_vars[name](delay_step, indices)
204+
if len(indices) == 0:
205+
indices = (jnp.arange(delay_step.size), )
206+
return self.local_delay_vars[name](delay_step, *indices)
207207
else:
208208
raise ValueError(f'{name} is not defined in delay variables.')
209209

brainpy/nn/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1478,7 +1478,6 @@ def plot_node_graph(self,
14781478
legends = dict() if legends is None else legends
14791479
ax.legend(proxie, labels, scatterpoints=1, markerscale=2, loc='best', **legends)
14801480
if show:
1481-
plt.tight_layout()
14821481
plt.show()
14831482

14841483

0 commit comments

Comments
 (0)