Skip to content

Commit b1ca6d9

Browse files
committed
Merge branch 'master' of ssh://github.com/jpsamaroo/DaggerGPU.jl
2 parents 11873cd + fede992 commit b1ca6d9

File tree

7 files changed

+30
-7
lines changed

7 files changed

+30
-7
lines changed

.github/workflows/TagBot.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
name: TagBot
2+
on:
3+
schedule:
4+
- cron: 0 * * * *
5+
jobs:
6+
TagBot:
7+
runs-on: ubuntu-latest
8+
steps:
9+
- uses: JuliaRegistries/TagBot@v1
10+
with:
11+
token: ${{ secrets.GITHUB_TOKEN }}

Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
55

66
[[Dagger]]
77
deps = ["Distributed", "LinearAlgebra", "MemPool", "Profile", "Random", "Serialization", "SharedArrays", "SparseArrays", "Statistics", "StatsBase"]
8-
git-tree-sha1 = "e77f451e4c1f9acbf794cb6377ec42130ff10f56"
8+
git-tree-sha1 = "8262f275c3acf3787fc70d1bd99aec011a02ddb8"
99
repo-rev = "jps/compute-resource"
1010
repo-url = "https://github.com/JuliaParallel/Dagger.jl.git"
1111
uuid = "d58978e5-989f-55fb-8d15-ea34adc7bf54"

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
88
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
99

1010
[compat]
11-
julia = "1.0"
11+
Dagger = "0.8"
12+
Requires = "1.0"
13+
julia = "1"
1214

1315
[extras]
1416
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"

src/DaggerGPU.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ end
1717

1818
processor(kind::Symbol) = processor(Val(kind))
1919
processor(::Val) = Dagger.ThreadProc
20+
cancompute(kind::Symbol) = cancompute(Val(kind))
21+
cancompute(::Val) = false
2022

2123
function __init__()
2224
@require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin

src/cuarrays.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ end
77

88
@gpuproc(CuArrayProc, CuArray)
99

10+
processor(::Val{:CUDA}) = CuArrayProc
11+
cancompute(::Val{:CUDA}) = CUDAapi.has_cuda()
1012

1113
push!(Dagger.PROCESSOR_CALLBACKS, proc -> begin
1214
if CUDAapi.has_cuda()
13-
@eval processor(::Val{:CUDA}) = CuArrayProc
14-
return CuArrayProc(first(devices()))
15+
return CuArrayProc(first(CUDAdrv.devices()))
1516
end
1617
end)

src/rocarrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ end
77

88
@gpuproc(ROCArrayProc, ROCArray)
99

10+
processor(::Val{:ROC}) = ROCArrayProc
11+
cancompute(::Val{:ROC}) = ROCArrays.configured
1012

1113
push!(Dagger.PROCESSOR_CALLBACKS, proc -> begin
1214
if ROCArrays.configured
13-
@eval processor(::Val{:ROC}) = ROCArrayProc
1415
return ROCArrayProc(HSARuntime.get_default_agent())
1516
end
1617
end)

test/runtests.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,14 @@ end
1212
cuproc = DaggerGPU.processor(:CUDA)
1313
rocproc = DaggerGPU.processor(:ROC)
1414

15-
cuproc === Dagger.ThreadProc && @warn "No CUDA devices available"
16-
rocproc === Dagger.ThreadProc && @warn "No ROCm devices available"
15+
if !DaggerGPU.cancompute(:CUDA)
16+
@warn "No CUDA devices available, falling back to ThreadProc"
17+
cuproc = Dagger.ThreadProc
18+
end
19+
if !DaggerGPU.cancompute(:ROC)
20+
@warn "No ROCm devices available, falling back to ThreadProc"
21+
rocproc = Dagger.ThreadProc
22+
end
1723

1824
as = [delayed(x->x+1)(1) for i in 1:10]
1925
b = delayed((xs...)->[sum(xs)])(as...)

0 commit comments

Comments
 (0)