Skip to content

Commit 9d8a911

Browse files
committed
datadeps: Properly skip copying non-written args
1 parent 6176dc7 commit 9d8a911

File tree

4 files changed

+115
-12
lines changed

4 files changed

+115
-12
lines changed

src/array/cholesky.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function LinearAlgebra._chol!(A::DArray{T,2}, ::Type{UpperTriangular}) where T
2222

2323
info = [convert(LinearAlgebra.BlasInt, 0)]
2424
try
25-
Dagger.spawn_datadeps(;aliasing=true) do
25+
Dagger.spawn_datadeps() do
2626
for k in range(1, mt)
2727
Dagger.@spawn potrf_checked!(uplo, InOut(Ac[k, k]), Out(info))
2828
for n in range(k+1, nt)

src/datadeps.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ function is_writedep(arg, deps, task::DTask)
221221
end
222222

223223
# Aliasing state setup
224-
function populate_task_info!(state::DataDepsState, spec, task)
224+
function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
225225
# Populate task dependencies
226226
dependencies_to_add = Vector{Tuple{Bool,Bool,AbstractAliasing,<:Any,<:Any}}()
227227

@@ -233,8 +233,8 @@ function populate_task_info!(state::DataDepsState, spec, task)
233233
# Unwrap the Chunk underlying any DTask arguments
234234
arg = arg isa DTask ? fetch(arg; raw=true) : arg
235235

236-
# Skip non-mutable arguments
237-
Base.datatype_pointerfree(typeof(arg)) && continue
236+
# Skip non-aliasing arguments
237+
type_may_alias(typeof(arg)) || continue
238238

239239
# Add all aliasing dependencies
240240
for (dep_mod, readdep, writedep) in deps
@@ -251,7 +251,8 @@ function populate_task_info!(state::DataDepsState, spec, task)
251251
end
252252

253253
# Track the task result too
254-
push!(dependencies_to_add, (true, true, UnknownAliasing(), identity, task))
254+
# N.B. We state no readdep/writedep because, while we can't model the aliasing info for the task result yet, we don't want to synchronize because of this
255+
push!(dependencies_to_add, (false, false, UnknownAliasing(), identity, task))
255256

256257
# Record argument/result dependencies
257258
push!(state.dependencies, task => dependencies_to_add)
@@ -664,7 +665,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
664665
# Is the data written previously or now?
665666
arg, deps = unwrap_inout(arg)
666667
arg = arg isa DTask ? fetch(arg; raw=true) : arg
667-
if Base.datatype_pointerfree(typeof(arg)) || !has_writedep(state, arg, deps, task)
668+
if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps, task)
668669
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (unwritten)"
669670
spec.args[idx] = pos => arg
670671
continue
@@ -736,7 +737,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
736737
for (idx, (_, arg)) in enumerate(task_args)
737738
arg, deps = unwrap_inout(arg)
738739
arg = arg isa DTask ? fetch(arg; raw=true) : arg
739-
Base.datatype_pointerfree(typeof(arg)) && continue
740+
type_may_alias(typeof(arg)) || continue
740741
if queue.aliasing
741742
for (dep_mod, _, writedep) in deps
742743
ainfo = aliasing(arg, dep_mod)
@@ -769,7 +770,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
769770
for (idx, (_, arg)) in enumerate(task_args)
770771
arg, deps = unwrap_inout(arg)
771772
arg = arg isa DTask ? fetch(arg; raw=true) : arg
772-
Base.datatype_pointerfree(typeof(arg)) && continue
773+
type_may_alias(typeof(arg)) || continue
773774
if queue.aliasing
774775
for (dep_mod, _, writedep) in deps
775776
ainfo = aliasing(arg, dep_mod)
@@ -864,7 +865,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
864865
for arg in keys(astate.data_origin)
865866
# Is the data previously written?
866867
arg, deps = unwrap_inout(arg)
867-
if Base.datatype_pointerfree(typeof(arg)) || !has_writedep(state, arg, deps)
868+
if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps)
868869
@dagdebug nothing :spawn_datadeps "Skipped copy-from (unwritten)"
869870
end
870871

src/memory-spaces.jl

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ function unwrap(x::Chunk)
3333
@assert root_worker_id(x.processor) == myid()
3434
MemPool.poolget(x.handle)
3535
end
36+
move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::T, from::F) where {T,F} =
37+
throw(ArgumentError("No `move!` implementation defined for $F -> $T"))
3638
function move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk)
3739
to_w = root_worker_id(to_space)
3840
remotecall_wait(to_w, dep_mod, to_space, from_space, to, from) do dep_mod, to_space, from_space, to, from
@@ -44,6 +46,10 @@ function move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::Chun
4446
end
4547
return
4648
end
49+
function move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::Base.RefValue{T}, from::Base.RefValue{T}) where {T}
50+
to[] = from[]
51+
return
52+
end
4753
function move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::AbstractArray{T,N}, from::AbstractArray{T,N}) where {T,N}
4854
move!(to_space, from_space, dep_mod(to), dep_mod(from))
4955
end
@@ -66,6 +72,23 @@ end
6672

6773
### Aliasing and Memory Spans
6874

75+
type_may_alias(::Type{String}) = false
76+
type_may_alias(::Type{Symbol}) = false
77+
type_may_alias(::Type{<:Type}) = false
78+
type_may_alias(::Type{C}) where C<:Chunk{T} where T = type_may_alias(T)
79+
function type_may_alias(::Type{T}) where T
80+
if isbitstype(T)
81+
return false
82+
elseif ismutabletype(T)
83+
return true
84+
elseif isstructtype(T)
85+
for FT in fieldtypes(T)
86+
type_may_alias(FT) && return true
87+
end
88+
end
89+
return false
90+
end
91+
6992
may_alias(::MemorySpace, ::MemorySpace) = true
7093
may_alias(space1::CPURAMMemorySpace, space2::CPURAMMemorySpace) = space1.owner == space2.owner
7194

@@ -104,8 +127,71 @@ memory_spans(::NoAliasing) = MemorySpan{CPURAMMemorySpace}[]
104127
struct UnknownAliasing <: AbstractAliasing end
105128
memory_spans(::UnknownAliasing) = [MemorySpan{CPURAMMemorySpace}(C_NULL, typemax(UInt))]
106129

130+
warn_unknown_aliasing(T) =
131+
@warn "Cannot resolve aliasing for object of type $T\nExecution may become sequential"
132+
133+
struct CombinedAliasing <: AbstractAliasing
134+
sub_ainfos::Vector{AbstractAliasing}
135+
end
136+
function memory_spans(ca::CombinedAliasing)
137+
# FIXME: Don't hardcode CPURAMMemorySpace
138+
all_spans = MemorySpan{CPURAMMemorySpace}[]
139+
for sub_a in ca.sub_ainfos
140+
append!(all_spans, memory_spans(sub_a))
141+
end
142+
return all_spans
143+
end
144+
Base.:(==)(ca1::CombinedAliasing, ca2::CombinedAliasing) =
145+
ca1.sub_ainfos == ca2.sub_ainfos
146+
Base.hash(ca1::CombinedAliasing, h::UInt) =
147+
hash(ca1.sub_ainfos, hash(CombinedAliasing, h))
148+
149+
struct ObjectAliasing <: AbstractAliasing
150+
ptr::Ptr{Cvoid}
151+
sz::UInt
152+
end
153+
function ObjectAliasing(x::T) where T
154+
@nospecialize x
155+
ptr = pointer_from_objref(x)
156+
sz = sizeof(T)
157+
return ObjectAliasing(ptr, sz)
158+
end
159+
function memory_spans(oa::ObjectAliasing)
160+
rptr = RemotePtr{Cvoid}(oa.ptr)
161+
span = MemorySpan{CPURAMMemorySpace}(rptr, oa.sz)
162+
return [span]
163+
end
164+
107165
aliasing(x, T) = aliasing(T(x))
108-
aliasing(x) = isbits(x) ? NoAliasing() : UnknownAliasing()
166+
function aliasing(x::T) where T
167+
if isbits(x)
168+
return NoAliasing()
169+
elseif isstructtype(T)
170+
as = AbstractAliasing[]
171+
# If the object itself is mutable, it can alias
172+
if ismutabletype(T)
173+
push!(as, ObjectAliasing(x))
174+
end
175+
# Check all object fields (recursive)
176+
for field in fieldnames(T)
177+
sub_as = aliasing(getfield(x, field))
178+
if sub_as isa NoAliasing
179+
continue
180+
elseif sub_as isa CombinedAliasing
181+
append!(as, sub_as.sub_ainfos)
182+
else
183+
push!(as, sub_as)
184+
end
185+
end
186+
return CombinedAliasing(as)
187+
else
188+
warn_unknown_aliasing(T)
189+
return UnknownAliasing()
190+
end
191+
end
192+
aliasing(::String) = NoAliasing() # FIXME: Not necessarily true
193+
aliasing(::Symbol) = NoAliasing()
194+
aliasing(::Type) = NoAliasing()
109195
aliasing(x::Chunk, T) = remotecall_fetch(root_worker_id(x.processor), x, T) do x, T
110196
aliasing(unwrap(x), T)
111197
end
@@ -129,6 +215,7 @@ function aliasing(x::Array{T}) where T
129215
else
130216
# FIXME: Also ContiguousAliasing of container
131217
#return IteratedAliasing(x)
218+
warn_unknown_aliasing(T)
132219
return UnknownAliasing()
133220
end
134221
end
@@ -173,6 +260,7 @@ function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array}
173260
else
174261
# FIXME: Also ContiguousAliasing of container
175262
#return IteratedAliasing(x)
263+
warn_unknown_aliasing(T)
176264
return UnknownAliasing()
177265
end
178266
end

test/datadeps.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,21 @@
11
using LinearAlgebra, Graphs
22

3-
empty!(Dagger.DAGDEBUG_CATEGORIES)
4-
push!(Dagger.DAGDEBUG_CATEGORIES, :spawn_datadeps)
3+
@testset "Memory Aliasing" begin
4+
A = rand(4)
5+
a = Dagger.aliasing(A)
6+
@test a isa Dagger.ContiguousAliasing
7+
@test a.span.ptr.addr == UInt(pointer(A))
8+
@test a.span.len == sizeof(Float64) * length(A)
9+
10+
r = Ref(3)
11+
a = Dagger.aliasing(r)
12+
@test a isa Dagger.CombinedAliasing
13+
@test length(a.sub_ainfos) == 1
14+
s = only(a.sub_ainfos)
15+
@test s isa Dagger.ObjectAliasing
16+
@test s.ptr == pointer_from_objref(r)
17+
@test s.sz == sizeof(3)
18+
end
519

620
function with_logs(f)
721
Dagger.enable_logging!(;taskdeps=true, taskargs=true)

0 commit comments

Comments
 (0)