File tree Expand file tree Collapse file tree 1 file changed +7
-5
lines changed
pytensor/link/numba/dispatch Expand file tree Collapse file tree 1 file changed +7
-5
lines changed Original file line number Diff line number Diff line change @@ -222,14 +222,16 @@ def add_output_storage_post_proc_stmt(
222222 # the storage array.
223223 # This is needed when the output storage array does not have a length
224224 # equal to the number of taps plus `n_steps`.
225+ # If the storage size only allows one entry, there's nothing to rotate
225226 output_storage_post_proc_stmts .append (
226227 dedent (
227228 f"""
228- if (i + { tap_size } ) > { storage_size } :
229+ if 1 < { storage_size } < (i + { tap_size } ):
229230 { outer_in_name } _shift = (i + { tap_size } ) % ({ storage_size } )
230- { outer_in_name } _left = { outer_in_name } [:{ outer_in_name } _shift]
231- { outer_in_name } _right = { outer_in_name } [{ outer_in_name } _shift:]
232- { outer_in_name } = np.concatenate(({ outer_in_name } _right, { outer_in_name } _left))
231+ if { outer_in_name } _shift > 0:
232+ { outer_in_name } _left = { outer_in_name } [:{ outer_in_name } _shift]
233+ { outer_in_name } _right = { outer_in_name } [{ outer_in_name } _shift:]
234+ { outer_in_name } = np.concatenate(({ outer_in_name } _right, { outer_in_name } _left))
233235 """
234236 ).strip ()
235237 )
@@ -417,4 +419,4 @@ def scan({", ".join(outer_in_names)}):
417419
418420 scan_op_fn = compile_function_src (scan_op_src , "scan" , {** globals (), ** global_env })
419421
420- return numba_basic .numba_njit (scan_op_fn )
422+ return numba_basic .numba_njit (scan_op_fn , boundscheck = False )
You can’t perform that action at this time.
0 commit comments