@@ -172,7 +172,7 @@ def get_delay_data(
172172 self ,
173173 name : str ,
174174 delay_step : Union [int , bm .JaxArray , jnp .DeviceArray ],
175- indices : Union [int , bm .JaxArray , jnp .DeviceArray ] = None ,
175+ * indices : Union [int , bm .JaxArray , jnp .DeviceArray ],
176176 ):
177177 """Get delay data according to the provided delay steps.
178178
@@ -192,18 +192,18 @@ def get_delay_data(
192192 """
193193 if name in self .global_delay_vars :
194194 if isinstance (delay_step , int ):
195- return self .global_delay_vars [name ](delay_step , indices )
195+ return self .global_delay_vars [name ](delay_step , * indices )
196196 else :
197- if indices is None :
198- indices = jnp .arange (delay_step .size )
199- return self .global_delay_vars [name ](delay_step , indices )
197+ if len ( indices ) == 0 :
198+ indices = ( jnp .arange (delay_step .size ), )
199+ return self .global_delay_vars [name ](delay_step , * indices )
200200 elif name in self .local_delay_vars :
201201 if isinstance (delay_step , int ):
202202 return self .local_delay_vars [name ](delay_step )
203203 else :
204- if indices is None :
205- indices = jnp .arange (delay_step .size )
206- return self .local_delay_vars [name ](delay_step , indices )
204+ if len ( indices ) == 0 :
205+ indices = ( jnp .arange (delay_step .size ), )
206+ return self .local_delay_vars [name ](delay_step , * indices )
207207 else :
208208 raise ValueError (f'{ name } is not defined in delay variables.' )
209209
0 commit comments