Skip to content

Commit efb4d61

Browse files
feat: add ForLoop
1 parent 3658f4d commit efb4d61

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

src/code.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions,
55

66
export toexpr, Assignment, (), Let, Func, DestructuredArgs, LiteralExpr,
77
SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex,
8-
SpawnFetch, Multithreaded, cse
8+
SpawnFetch, Multithreaded, ForLoop, cse
99

1010
import ..SymbolicUtils
1111
import ..SymbolicUtils.Rewriters
@@ -705,6 +705,27 @@ function toexpr(exp::LiteralExpr, st)
705705
recurse_expr(exp.ex, st)
706706
end
707707

708+
"""
709+
ForLoop(itervar, range, body)
710+
711+
Generate a `for` loop of the form
712+
```julia
713+
for itervar in range
714+
body
715+
end
716+
```
717+
"""
718+
struct ForLoop <: CodegenPrimitive
719+
itervar
720+
range
721+
body
722+
end
723+
724+
function toexpr(f::ForLoop, st)
725+
:(for $(toexpr(f.itervar, st)) in $(toexpr(f.range, st))
726+
$(toexpr(f.body, st))
727+
end)
728+
end
708729

709730
### Code-related utilities
710731

@@ -935,4 +956,10 @@ function cse!(x::SpawnFetch{T}, state::CSEState) where {T}
935956
return SpawnFetch{T}(map(cse, x.exprs), cse!(x.args, state), x.combine)
936957
end
937958

959+
function cse!(x::ForLoop, state::CSEState)
960+
# cse the range with current scope, CSE the body with a new scope
961+
new_state = new_scope(state)
962+
return ForLoop(x.itervar, cse!(x.range, state), apply_cse(cse!(x.body, new_state), new_state))
963+
end
964+
938965
end

test/code.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,3 +305,21 @@ end
305305
@test arr[i] == 0
306306
end
307307
end
308+
309+
@testset "`ForLoop`" begin
310+
@syms a b c::Array
311+
ex = ForLoop(a, term(range, b^2, b^2 + 3), SetArray(false, c, [AtIndex(a, a + 1)]))
312+
expr = quote
313+
let b = 2, c = zeros(Int, 10)
314+
$(toexpr(ex))
315+
c
316+
end
317+
end
318+
arr = eval(expr)
319+
@test arr[4] == 5
320+
@test arr[5] == 6
321+
@test arr[6] == 7
322+
@test arr[7] == 8
323+
@test all(iszero, arr[1:3])
324+
@test all(iszero, arr[8:end])
325+
end

test/cse.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,24 @@ end
201201
rgf = @RuntimeGeneratedFunction(toexpr(innerfn))
202202
@test rgf(5, sin(1)) == trueval
203203
end
204+
205+
@testset "ForLoop" begin
206+
@syms a b c::Array
207+
ex = ForLoop(a, term(range, b^2, b^2 + 3), SetArray(false, c, [AtIndex(a, a^2 + sin(a^2))]))
208+
csex = cse(ex)
209+
@test findfirst(isequal(csex.body.range), Code.lhs.(csex.pairs)) !== nothing
210+
@test findfirst(isequal(csex.body.body.body.elems[1].elem), Code.lhs.(csex.body.body.pairs)) !== nothing
211+
expr = quote
212+
let b = 2, c = zeros(10)
213+
$(toexpr(ex))
214+
c
215+
end
216+
end
217+
arr = eval(expr)
218+
@test arr[4] == 4^2 + sin(4^2)
219+
@test arr[5] == 5^2 + sin(5^2)
220+
@test arr[6] == 6^2 + sin(6^2)
221+
@test arr[7] == 7^2 + sin(7^2)
222+
@test all(iszero, arr[1:3])
223+
@test all(iszero, arr[8:end])
224+
end

0 commit comments

Comments
 (0)