Skip to content

Commit 51fdb5e

Browse files
authored
Merge pull request #401 from JuliaParallel/jps/darray-correct-eltype
DArray: Figure out correct eltype
2 parents f94be28 + 6e0704c commit 51fdb5e

File tree

4 files changed

+15
-7
lines changed

4 files changed

+15
-7
lines changed

.buildkite/pipeline.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
queue: "juliaecosystem"
55
sandbox_capable: "true"
66
os: linux
7-
arch: x86_64
87
command: "julia --project -e 'using Pkg; Pkg.develop(;path=\"lib/TimespanLogging\")'"
98
.bench: &bench
109
if: build.message =~ /\[run benchmarks\]/

src/array/darray.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ mutable struct DArray{T,N,F} <: ArrayOp{T, N}
116116
subdomains::AbstractArray{ArrayDomain{N}, N}
117117
chunks::AbstractArray{Any, N}
118118
concat::F
119-
function DArray{T,N,F}(domain, subdomains, chunks, concat::Function) where {T, N,F}
120-
new(domain, subdomains, chunks, concat)
119+
function DArray{T,N,F}(domain, subdomains, chunks, concat::Function) where {T,N,F}
120+
new{T,N,F}(domain, subdomains, chunks, concat)
121121
end
122122
end
123123

@@ -227,14 +227,14 @@ end
227227
228228
If a `DArray` tree has a `Thunk` in it, make the whole thing a big thunk.
229229
"""
230-
function Base.fetch(c::DArray)
230+
function Base.fetch(c::DArray{T}) where T
231231
if any(istask, chunks(c))
232232
thunks = chunks(c)
233233
sz = size(thunks)
234234
dmn = domain(c)
235235
dmnchunks = domainchunks(c)
236236
fetch(Dagger.spawn(Options(meta=true), thunks...) do results...
237-
t = eltype(results[1])
237+
t = eltype(fetch(results[1]))
238238
DArray(t, dmn, dmnchunks, reshape(Any[results...], sz))
239239
end)
240240
else

src/array/map-reduce.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ function stage(ctx::Context, node::Map)
1919
f = node.f
2020
for i=eachindex(domains)
2121
inps = map(x->chunks(x)[i], inputs)
22-
thunks[i] = Dagger.@spawn map(f, inps...)
22+
thunks[i] = Dagger.@spawn map(f, inps...)
2323
end
24-
DArray(Any, domain(primary), domainchunks(primary), thunks)
24+
RT = Base.promote_op(node.f, map(eltype, node.inputs)...)
25+
return DArray(RT, domain(primary), domainchunks(primary), thunks)
2526
end
2627

2728
map(f, x::ArrayOp, xs::ArrayOp...) = Map(f, (x, xs...))

test/array.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ end
1313
@testset "DArray constructor" begin
1414
x = fetch(rand(Blocks(2,2), 3,3))
1515
@test collect(x) == DArray{Float64, 2}(x.domain, x.subdomains, x.chunks) |> collect
16+
@test x isa DArray{Float64, 2}
1617
end
1718

1819
@testset "rand" begin
@@ -39,6 +40,13 @@ end
3940
@test r[1:10] != r[11:20]
4041
end
4142

43+
@testset "map" begin
44+
X1 = fetch(ones(Blocks(10, 10), 100, 100))
45+
X2 = fetch(map(x->x+1, X1))
46+
@test typeof(X1) === typeof(X2)
47+
@test collect(X1) .+ 1 == collect(X2)
48+
end
49+
4250
@testset "sum" begin
4351
X = ones(Blocks(10, 10), 100, 100)
4452
@test sum(X) == 10000

0 commit comments

Comments
 (0)