diff --git a/src/thunk.jl b/src/thunk.jl index dc961f303..e173e1c28 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -371,10 +371,12 @@ function _par(mod, ex::Expr; lazy=true, recur=true, opts=()) body = nothing arg1 = nothing arg2 = nothing + value = nothing if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end) || @capture(ex, allargs__->body_) || @capture(ex, arg1_[allargs__]) || + @capture(ex, arg1_[allargs__] = value_) || @capture(ex, arg1_.arg2_) || @capture(ex, (;allargs__)) || @capture(ex, bf_.(allargs__)) @@ -387,8 +389,13 @@ function _par(mod, ex::Expr; lazy=true, recur=true, opts=()) # Getproperty (A.B) f = Base.getproperty allargs = Any[arg1, QuoteNode(arg2)] + elseif value !== nothing + # setindex! (A[2,3] = 4) + f = _setindex!_return_value + pushfirst!(allargs, value) + pushfirst!(allargs, arg1) else - # Indexing (A[2,3]) + # getindex (A[2,3]) f = Base.getindex pushfirst!(allargs, arg1) end @@ -444,6 +451,11 @@ _par(mod, ex; kwargs...) = throw(ArgumentError("Invalid Dagger task expression: _par_inner(mod, ex; kwargs...) = ex _par_inner(mod, ex::Expr; kwargs...) = _par(mod, ex; kwargs...) +function _setindex!_return_value(A, value, idxs...) + setindex!(A, value, idxs...) + return value +end + """ Dagger.spawn(f, args...; kwargs...) -> DTask diff --git a/test/thunk.jl b/test/thunk.jl index 73879545b..06ba25144 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -139,6 +139,25 @@ end @test t isa Dagger.DTask @test fetch(t) == 42 end + @testset "setindex!" begin + A = Dagger.@mutable rand(4, 4) + + t = @spawn A[1, 2] = 3.0 + @test t isa Dagger.DTask + @test fetch(t) == 3.0 + @test fetch(@spawn A[1, 2]) == 3.0 + + t = @spawn A[2] = 4.0 + @test t isa Dagger.DTask + @test fetch(t) == 4.0 + @test fetch(@spawn A[2]) == 4.0 + + R = Dagger.@mutable Ref(42) + t = @spawn R[] = 43 + @test t isa Dagger.DTask + @test fetch(t) == 43 + @test fetch(@spawn R[]) == 43 + end @testset "NamedTuple" begin t = @spawn (;a=1, b=2) @test t isa Dagger.DTask