File tree Expand file tree Collapse file tree 5 files changed +42
-6
lines changed Expand file tree Collapse file tree 5 files changed +42
-6
lines changed Original file line number Diff line number Diff line change @@ -16,7 +16,21 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1616Serialization = " 9e88b42a-f829-5b0c-bbe9-9e923198166b"
1717Sockets = " 6462fe0b-24de-5631-8697-dd941f90decc"
1818
19+ [extras ]
20+ AMDGPU = " 21141c5a-9bdb-4563-92ae-f87d6854732e"
21+ CUDA = " 052768ef-5323-5732-b1bb-66c8b64840ba"
22+
23+ [weakdeps ]
24+ AMDGPU = " 21141c5a-9bdb-4563-92ae-f87d6854732e"
25+ CUDA = " 052768ef-5323-5732-b1bb-66c8b64840ba"
26+
27+ [extensions ]
28+ AMDGPUExt = " AMDGPU"
29+ CUDAExt = " CUDA"
30+
1931[compat ]
32+ AMDGPU = " 0.3, 0.4"
33+ CUDA = " 3, 4"
2034DocStringExtensions = " 0.8, 0.9"
2135MPIPreferences = " 0.1.6"
2236Requires = " ~0.5, 1.0"
Original file line number Diff line number Diff line change 1- import . AMDGPU
1+ module AMDGPUExt
2+
3+ import MPI
4+ isdefined (Base, :get_extension ) ? (import AMDGPU) : (import .. AMDGPU)
5+ import MPI: MPIPtr, Buffer, Datatype
6+
27
38function Base. cconvert (:: Type{MPIPtr} , A:: AMDGPU.ROCArray{T} ) where T
49 A
1924function Buffer (arr:: AMDGPU.ROCArray )
2025 Buffer (arr, Cint (length (arr)), Datatype (eltype (arr)))
2126end
27+
28+ end # AMDGPUExt
Original file line number Diff line number Diff line change 1- import . CUDA
1+ module CUDAExt
2+
3+ import MPI
4+ isdefined (Base, :get_extension ) ? (import CUDA) : (import .. CUDA)
5+ import MPI: MPIPtr, Buffer, Datatype
26
37function Base. cconvert (:: Type{MPIPtr} , buf:: CUDA.CuArray{T} ) where T
48 Base. cconvert (CUDA. CuPtr{T}, buf) # returns DeviceBuffer
1923function Buffer (arr:: CUDA.CuArray )
2024 Buffer (arr, Cint (length (arr)), Datatype (eltype (arr)))
2125end
26+
27+ end # CUDAExt
Original file line number Diff line number Diff line change 11module MPI
22
33using Libdl, Serialization
4- using Requires
54using DocStringExtensions
65import MPIPreferences
76
@@ -80,6 +79,10 @@ include("misc.jl")
8079
8180include (" deprecated.jl" )
8281
82+ if ! isdefined (Base, :get_extension )
83+ using Requires
84+ end
85+
8386function __init__ ()
8487 MPIPreferences. check_unchanged ()
8588
@@ -136,8 +139,10 @@ function __init__()
136139
137140 run_load_time_hooks ()
138141
139- @require AMDGPU= " 21141c5a-9bdb-4563-92ae-f87d6854732e" include (" rocm.jl" )
140- @require CUDA= " 052768ef-5323-5732-b1bb-66c8b64840ba" include (" cuda.jl" )
142+ @static if ! isdefined (Base, :get_extension )
143+ @require AMDGPU= " 21141c5a-9bdb-4563-92ae-f87d6854732e" include (" ../ext/AMDGPUExt.jl" )
144+ @require CUDA= " 052768ef-5323-5732-b1bb-66c8b64840ba" include (" ../ext/CUDAExt.jl" )
145+ end
141146end
142147
143148end
Original file line number Diff line number Diff line change @@ -6,7 +6,11 @@ using MPIPreferences
66using DoubleFloats
77if get (ENV , " JULIA_MPI_TEST_ARRAYTYPE" , " " ) == " CuArray"
88 import CUDA
9- CUDA. version ()
9+ if isdefined (CUDA, :versioninfo )
10+ CUDA. versioninfo ()
11+ else
12+ CUDA. version ()
13+ end
1014 CUDA. precompile_runtime ()
1115 ArrayType = CUDA. CuArray
1216elseif get (ENV ," JULIA_MPI_TEST_ARRAYTYPE" ," " ) == " ROCArray"
You can’t perform that action at this time.
0 commit comments