Skip to content

Commit 711ca0e

Browse files
feat: allow SetArray to return the array being set
1 parent 6ff8faf commit 711ca0e

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

src/code.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,10 +416,11 @@ end
416416
inbounds::Bool
417417
arr
418418
elems # Either iterator of Pairs or just an iterator
419+
return_arr::Bool
419420
end
420421

421422
"""
422-
SetArray(inbounds, arr, elems)
423+
SetArray(inbounds::Bool, arr, elems[, return_arr::Bool])
423424
424425
An expression representing setting of elements of `arr`.
425426
@@ -430,9 +431,14 @@ is performed in its place.
430431
431432
`inbounds` is a boolean flag, `true` surrounds the resulting expression
432433
in an `@inbounds`.
434+
435+
`return_arr` is a flag which controls whether the generated `begin..end` block
436+
returns the `arr`. Defaults to `false`, in which case the block returns `nothing`.
433437
"""
434438
SetArray
435439

440+
SetArray(inbounds, arr, elems) = SetArray(inbounds, arr, elems, false)
441+
436442
@matchable struct AtIndex <: CodegenPrimitive
437443
i
438444
elem
@@ -446,7 +452,7 @@ function toexpr(s::SetArray, st)
446452
ex = quote
447453
$([:($(toexpr(s.arr, st))[$(ex isa AtIndex ? toexpr(ex.i, st) : i)] = $(toexpr(ex, st)))
448454
for (i, ex) in enumerate(s.elems)]...)
449-
nothing
455+
$(s.return_arr ? toexpr(s.arr, st) : nothing)
450456
end
451457
s.inbounds ? :(@inbounds $ex) : ex
452458
end
@@ -906,7 +912,7 @@ function cse!(x::MakeArray, state::CSEState)
906912
end
907913

908914
function cse!(x::SetArray, state::CSEState)
909-
return SetArray(x.inbounds, x.arr, cse!(x.elems, state))
915+
return SetArray(x.inbounds, x.arr, cse!(x.elems, state), x.return_arr)
910916
end
911917

912918
function cse!(x::MakeSparseArray, state::CSEState)

test/code.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,21 @@ end
323323
@test all(iszero, arr[1:3])
324324
@test all(iszero, arr[8:end])
325325
end
326+
327+
@testset "`SetArray` with `return_arr`" begin
328+
@syms a b c::Array
329+
ex = SetArray(false, c, [3, 2, 1], false)
330+
expr = quote
331+
let b = 2, c = zeros(Int, 3)
332+
$(toexpr(ex))
333+
end
334+
end
335+
@test eval(expr) === nothing
336+
ex = SetArray(false, c, [3, 2, 1], true)
337+
expr = quote
338+
let b = 2, c = zeros(Int, 3)
339+
$(toexpr(ex))
340+
end
341+
end
342+
@test eval(expr) == [3, 2, 1]
343+
end

test/cse.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ end
107107
marr = MakeArray(arr, Array)
108108
sparr = sparse([1, 2, 3, 4], [1, 2, 3, 4], vec(arr))
109109
msparr = MakeSparseArray(sparr)
110-
sarr = SetArray(false, :buffer, [[a^2 + c^2], AtIndex(3, arr), AtIndex(4, msparr)])
110+
sarr = SetArray(false, :buffer, [[a^2 + c^2], AtIndex(3, arr), AtIndex(4, msparr)], true)
111111

112112
csex = cse(sarr)
113113
# test that simple array is CSEd
@@ -127,7 +127,6 @@ end
127127
expr = quote
128128
let a = 1, b = 2, c = 3, buffer = Any[0, "A", 0, 0]
129129
$(toexpr(csex))
130-
buffer
131130
end
132131
end
133132
val = eval(expr)

0 commit comments

Comments
 (0)