diff --git a/docs/src/datadeps.md b/docs/src/datadeps.md index 9f0c5b51e..5586d256c 100644 --- a/docs/src/datadeps.md +++ b/docs/src/datadeps.md @@ -179,9 +179,20 @@ Dagger.spawn_datadeps() do end ``` -You can pass any number of aliasing modifiers to `Deps`. This is particularly -useful for declaring aliasing with `Diagonal`, `Bidiagonal`, `Tridiagonal`, and -`SymTridiagonal` access, as these "wrappers" make a copy of their parent array -and thus can't be used to "mask" access to the parent like `UpperTriangular` -and `UnitLowerTriangular` can (which is valuable for writing memory-efficient, -generic algorithms in Julia). +We call `InOut(Diagonal)` an "aliasing modifier". The purpose of `Deps` is to +pass an argument (here, `A`) as-is, while specifying to Datadeps what portions +of the argument will be accessed (in this case, the diagonal elements) and how +(read/write/both). You can pass any number of aliasing modifiers to `Deps`. + +`Deps` is particularly useful for declaring aliasing with `Diagonal`, +`Bidiagonal`, `Tridiagonal`, and `SymTridiagonal` access, as these "wrappers" +make a copy of their parent array and thus can't be used to "mask" access to the +parent like `UpperTriangular` and `UnitLowerTriangular` can (which is valuable +for writing memory-efficient, generic algorithms in Julia). + +### Supported Aliasing Modifiers + +- Any function that returns the original object or a view of the original object +- `UpperTriangular`/`LowerTriangular`/`UnitUpperTriangular`/`UnitLowerTriangular` +- `Diagonal`/`Bidiagonal`/`Tridiagonal`/`SymTridiagonal` (via `Deps`, e.g. to read from the diagonal of `X`: `Dagger.@spawn sum(Deps(X, In(Diagonal)))`) +- `Symbol` for field access (via `Deps`, e.g. to write to `X.value`: `Dagger.@spawn setindex!(Deps(X, InOut(:value)), :value, 42)` diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index b0aa248ce..6b75a5038 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -162,7 +162,12 @@ function memory_spans(oa::ObjectAliasing) return [span] end -aliasing(x, T) = aliasing(T(x)) +function aliasing(x, dep_mod) + if dep_mod isa Symbol + return aliasing(getfield(x, dep_mod)) + end + return aliasing(dep_mod(x)) +end function aliasing(x::T) where T if isbits(x) return NoAliasing() diff --git a/test/datadeps.jl b/test/datadeps.jl index b8830f58e..0de70c318 100644 --- a/test/datadeps.jl +++ b/test/datadeps.jl @@ -109,6 +109,7 @@ function test_task_dominators(logs::Dict, tid::Int, doms::Vector; all_tids::Vect end @everywhere do_nothing(Xs...) = nothing +@everywhere mut_ref!(R) = (R[] .= 0;) function test_datadeps(;args_chunks::Bool, args_thunks::Bool, args_loc::Int, @@ -425,10 +426,17 @@ function test_datadeps(;args_chunks::Bool, end # Inner Scope - @test_throws Dagger.Sch.SchedulingException Dagger.spawn_datadeps() do + @test_throws Dagger.Sch.SchedulingException Dagger.spawn_datadeps() do Dagger.@spawn scope=Dagger.ExactScope(Dagger.ThreadProc(1, 5000)) 1+1 end + # Field aliasing + X = Ref(rand(1000)) + @test all(x->x==0, fetch(Dagger.spawn_datadeps() do + Dagger.@spawn mut_ref!(Deps(X, InOut(:x))) + Dagger.@spawn getfield(Deps(X, In(:x)), :x) + end)) + # Add-to-copy A = rand(1000) B = rand(1000)