Skip to content

Commit 2048604

Browse files
committed
Initial commit
0 parents  commit 2048604

File tree

6 files changed

+238
-0
lines changed

6 files changed

+238
-0
lines changed

Manifest.toml

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
3+
[[Base64]]
4+
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
5+
6+
[[Dagger]]
7+
deps = ["Distributed", "LinearAlgebra", "MemPool", "Profile", "Random", "Serialization", "SharedArrays", "SparseArrays", "Statistics", "StatsBase"]
8+
git-tree-sha1 = "e77f451e4c1f9acbf794cb6377ec42130ff10f56"
9+
repo-rev = "jps/compute-resource"
10+
repo-url = "https://github.com/JuliaParallel/Dagger.jl.git"
11+
uuid = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
12+
version = "0.8.0"
13+
14+
[[DataAPI]]
15+
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
16+
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
17+
version = "1.1.0"
18+
19+
[[DataStructures]]
20+
deps = ["InteractiveUtils", "OrderedCollections"]
21+
git-tree-sha1 = "73eb18320fe3ba58790c8b8f6f89420f0a622773"
22+
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
23+
version = "0.17.11"
24+
25+
[[Distributed]]
26+
deps = ["Random", "Serialization", "Sockets"]
27+
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
28+
29+
[[InteractiveUtils]]
30+
deps = ["Markdown"]
31+
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
32+
33+
[[Libdl]]
34+
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
35+
36+
[[LinearAlgebra]]
37+
deps = ["Libdl"]
38+
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
39+
40+
[[Logging]]
41+
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
42+
43+
[[Markdown]]
44+
deps = ["Base64"]
45+
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
46+
47+
[[MemPool]]
48+
deps = ["DataStructures", "Distributed", "Mmap", "Random", "Serialization", "Sockets", "Test"]
49+
git-tree-sha1 = "d52799152697059353a8eac1000d32ba8d92aa25"
50+
uuid = "f9f48841-c794-520a-933b-121f7ba6ed94"
51+
version = "0.2.0"
52+
53+
[[Missings]]
54+
deps = ["DataAPI"]
55+
git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5"
56+
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
57+
version = "0.4.3"
58+
59+
[[Mmap]]
60+
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
61+
62+
[[OrderedCollections]]
63+
deps = ["Random", "Serialization", "Test"]
64+
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
65+
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
66+
version = "1.1.0"
67+
68+
[[Printf]]
69+
deps = ["Unicode"]
70+
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
71+
72+
[[Profile]]
73+
deps = ["Printf"]
74+
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
75+
76+
[[Random]]
77+
deps = ["Serialization"]
78+
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
79+
80+
[[Requires]]
81+
deps = ["UUIDs"]
82+
git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b"
83+
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
84+
version = "1.0.1"
85+
86+
[[SHA]]
87+
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
88+
89+
[[Serialization]]
90+
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
91+
92+
[[SharedArrays]]
93+
deps = ["Distributed", "Mmap", "Random", "Serialization"]
94+
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
95+
96+
[[Sockets]]
97+
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
98+
99+
[[SortingAlgorithms]]
100+
deps = ["DataStructures", "Random", "Test"]
101+
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
102+
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
103+
version = "0.3.1"
104+
105+
[[SparseArrays]]
106+
deps = ["LinearAlgebra", "Random"]
107+
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
108+
109+
[[Statistics]]
110+
deps = ["LinearAlgebra", "SparseArrays"]
111+
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
112+
113+
[[StatsBase]]
114+
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
115+
git-tree-sha1 = "19bfcb46245f69ff4013b3df3b977a289852c3a1"
116+
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
117+
version = "0.32.2"
118+
119+
[[Test]]
120+
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
121+
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
122+
123+
[[UUIDs]]
124+
deps = ["Random", "SHA"]
125+
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
126+
127+
[[Unicode]]
128+
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

Project.toml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
name = "DaggerGPU"
2+
uuid = "68e73e28-2238-4d5a-bf97-e5d4aa3c4be2"
3+
authors = ["Julian P Samaroo <[email protected]>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
8+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
9+
10+
[compat]
11+
julia = "1.0"
12+
13+
[extras]
14+
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
15+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
16+
ROCArrays = "ddf941ca-5d6a-11e9-36cc-a3fed13dd2fc"
17+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
18+
19+
[targets]
20+
test = ["CuArrays", "Distributed", "ROCArrays", "Test"]

src/DaggerGPU.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
module DaggerGPU
2+
3+
using Dagger, Requires
4+
5+
macro gpuproc(PROC, T)
6+
quote
7+
Dagger.iscompatible(proc::$PROC, opts, x::AbstractArray{AT}) where AT =
8+
isbitstype(AT)
9+
Dagger.move(ctx, from_proc::OSProc, to_proc::$PROC, x::AbstractArray) =
10+
$T(x)
11+
Dagger.move(ctx, from_proc::$PROC, to_proc::OSProc, x) = x
12+
Dagger.move(ctx, from_proc::$PROC, to_proc::OSProc, x::$T) =
13+
collect(x)
14+
Dagger.execute!(proc::$PROC, func, args...) = func(args...)
15+
end
16+
end
17+
18+
processor(kind::Symbol) = processor(Val(kind))
19+
processor(::Val) = Dagger.ThreadProc
20+
21+
function __init__()
22+
@require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
23+
include("cuarrays.jl")
24+
end
25+
@require ROCArrays="ddf941ca-5d6a-11e9-36cc-a3fed13dd2fc" begin
26+
include("rocarrays.jl")
27+
end
28+
end
29+
30+
end

src/cuarrays.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using .CuArrays
2+
import .CuArrays: CUDAapi, CUDAdrv
3+
4+
struct CuArrayProc <: Dagger.Processor
5+
device
6+
end
7+
8+
@gpuproc(CuArrayProc, CuArray)
9+
10+
11+
push!(Dagger.PROCESSOR_CALLBACKS, proc -> begin
12+
if CUDAapi.has_cuda()
13+
@eval processor(::Val{:CUDA}) = CuArrayProc
14+
return CuArrayProc(first(devices()))
15+
end
16+
end)

src/rocarrays.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using .ROCArrays
2+
import .ROCArrays: AMDGPUnative.HSARuntime
3+
4+
struct ROCArrayProc <: Dagger.Processor
5+
device
6+
end
7+
8+
@gpuproc(ROCArrayProc, ROCArray)
9+
10+
11+
push!(Dagger.PROCESSOR_CALLBACKS, proc -> begin
12+
if ROCArrays.configured
13+
@eval processor(::Val{:ROC}) = ROCArrayProc
14+
return ROCArrayProc(HSARuntime.get_default_agent())
15+
end
16+
end)

test/runtests.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using Distributed
2+
using Test
3+
addprocs(2)
4+
5+
@everywhere begin
6+
7+
using Distributed, Dagger, DaggerGPU
8+
using CuArrays, ROCArrays
9+
10+
end
11+
12+
cuproc = DaggerGPU.processor(:CUDA)
13+
rocproc = DaggerGPU.processor(:ROC)
14+
15+
cuproc === Dagger.ThreadProc && @warn "No CUDA devices available"
16+
rocproc === Dagger.ThreadProc && @warn "No ROCm devices available"
17+
18+
as = [delayed(x->x+1)(1) for i in 1:10]
19+
b = delayed((xs...)->[sum(xs)])(as...)
20+
21+
opts = Dagger.Sch.ThunkOptions(;proctypes=[cuproc])
22+
c1 = delayed(sum; options=opts)(b)
23+
opts = Dagger.Sch.ThunkOptions(;proctypes=[rocproc])
24+
c2 = delayed(sum; options=opts)(b)
25+
26+
opts = Dagger.Sch.ThunkOptions(;proctypes=[Dagger.ThreadProc])
27+
d = delayed((x,y)->x+y; options=opts)(c1,c2)
28+
@test collect(d) == 40

0 commit comments

Comments
 (0)