Skip to content
34 changes: 5 additions & 29 deletions ext/JustPICAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ JustPIC.TA(::Type{AMDGPUBackend}) = ROCArray
function ROCCellArray(
::Type{T}, ::UndefInitializer, dims::NTuple{N, Int}
) where {T <: CellArrays.Cell, N}
return CellArrays.CellArray{T, N, 0, AMDGPU.ROCCellArrayArray{eltype(T), 3}}(undef, dims)
return CellArrays.CellArray{T, N, 0, AMDGPU.ROCArray{eltype(T), 3}}(undef, dims)
end
function ROCCellArray(
::Type{T}, ::UndefInitializer, dims::Int...
Expand All @@ -33,26 +33,7 @@ function AMDGPU.ROCArray(::Type{T}, chain::JustPIC.MarkerChain) where {T <: Numb
coords_gpu = ntuple(i -> ROCArray(T, coords[i]), Val(length(coords)))
coords0_gpu = ntuple(i -> ROCArray(T, coords0[i]), Val(length(coords0)))
return MarkerChain(
CUDABackend,
coords_gpu,
coords0_gpu,
ROCArray(h_vertices),
ROCArray(h_vertices0),
cell_vertices,
ROCArray(Bool, index),
max_xcell,
min_xcell,
)
end

function AMDGPU.ROCArray(::Type{T}, chain::JustPIC.MarkerChain) where {T <: Number}
(;
cell_vertices, coords, coords0, h_vertices, h_vertices0, index, max_xcell, min_xcell,
) = chain
coords_gpu = ntuple(i -> ROCArray(T, coords[i]), Val(length(coords)))
coords0_gpu = ntuple(i -> ROCArray(T, coords0[i]), Val(length(coords0)))
return MarkerChain(
CUDABackend,
AMDGPUBackend,
coords_gpu,
coords0_gpu,
ROCArray(h_vertices),
Expand Down Expand Up @@ -109,7 +90,7 @@ function AMDGPU.ROCArray(chain::JustPIC.MarkerChain)
coords_gpu = ntuple(i -> ROCArray(coords[i]), Val(length(coords)))
coords0_gpu = ntuple(i -> ROCArray(coords0[i]), Val(length(coords0)))
return MarkerChain(
CUDABackend,
AMDGPUBackend,
coords_gpu,
coords0_gpu,
ROCArray(h_vertices),
Expand All @@ -121,11 +102,6 @@ function AMDGPU.ROCArray(chain::JustPIC.MarkerChain)
)
end

function AMDGPU.ROCArray(phase_ratios::JustPIC.PhaseRatios)
(; vertex, center) = phase_ratios
return JustPIC.PhaseRatios(AMDGPUBackend, ROCArray(center), ROCArray(vertex))
end

function AMDGPU.ROCArray(::Type{T}, CA::CellArray) where {T <: Number}
ni = size(CA)
# Array initializations
Expand Down Expand Up @@ -216,10 +192,10 @@ module _2D
end

function JustPIC._2D.init_particles(
::Type{AMDGPUBackend}, nxcell, max_xcell, min_xcell, x, y, z; buffer = 1 - 1.0e-5
::Type{AMDGPUBackend}, nxcell, max_xcell, min_xcell, x, y; buffer = 1 - 1.0e-5
)
return init_particles(
AMDGPUBackend, nxcell, max_xcell, min_xcell, x, y, z; buffer = buffer
AMDGPUBackend, nxcell, max_xcell, min_xcell, x, y; buffer = buffer
)
end

Expand Down
4 changes: 2 additions & 2 deletions test/test_save_load.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ const backend = JustPIC.CPUBackend

if isCUDA || isAMDGPU
T = isCUDA ? CuArray : ROCArray
Backend = isCUDA ? CUDABackend : AMDGPUBackend
Backend = isCUDA ? CUDABackend : JustPIC.AMDGPUBackend

particles2 = Array(particles)
phases2 = Array(phases)
Expand Down Expand Up @@ -239,7 +239,7 @@ end

if isCUDA || isAMDGPU
T = isCUDA ? CuArray : ROCArray
Backend = isCUDA ? CUDABackend : AMDGPUBackend
Backend = isCUDA ? CUDABackend : JustPIC.AMDGPUBackend

particles2 = Array(particles)
phases2 = Array(phases)
Expand Down
Loading