Skip to content
40 changes: 8 additions & 32 deletions ext/JustPICAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ JustPIC.TA(::Type{AMDGPUBackend}) = ROCArray
function ROCCellArray(
::Type{T}, ::UndefInitializer, dims::NTuple{N, Int}
) where {T <: CellArrays.Cell, N}
return CellArrays.CellArray{T, N, 0, AMDGPU.ROCCellArrayArray{eltype(T), 3}}(undef, dims)
return CellArrays.CellArray{T, N, 0, AMDGPU.ROCArray{eltype(T), 3}}(undef, dims)
end
function ROCCellArray(
::Type{T}, ::UndefInitializer, dims::Int...
Expand All @@ -33,26 +33,7 @@ function AMDGPU.ROCArray(::Type{T}, chain::JustPIC.MarkerChain) where {T <: Numb
coords_gpu = ntuple(i -> ROCArray(T, coords[i]), Val(length(coords)))
coords0_gpu = ntuple(i -> ROCArray(T, coords0[i]), Val(length(coords0)))
return MarkerChain(
CUDABackend,
coords_gpu,
coords0_gpu,
ROCArray(h_vertices),
ROCArray(h_vertices0),
cell_vertices,
ROCArray(Bool, index),
max_xcell,
min_xcell,
)
end

function AMDGPU.ROCArray(::Type{T}, chain::JustPIC.MarkerChain) where {T <: Number}
(;
cell_vertices, coords, coords0, h_vertices, h_vertices0, index, max_xcell, min_xcell,
) = chain
coords_gpu = ntuple(i -> ROCArray(T, coords[i]), Val(length(coords)))
coords0_gpu = ntuple(i -> ROCArray(T, coords0[i]), Val(length(coords0)))
return MarkerChain(
CUDABackend,
AMDGPUBackend,
coords_gpu,
coords0_gpu,
ROCArray(h_vertices),
Expand Down Expand Up @@ -109,7 +90,7 @@ function AMDGPU.ROCArray(chain::JustPIC.MarkerChain)
coords_gpu = ntuple(i -> ROCArray(coords[i]), Val(length(coords)))
coords0_gpu = ntuple(i -> ROCArray(coords0[i]), Val(length(coords0)))
return MarkerChain(
CUDABackend,
AMDGPUBackend,
coords_gpu,
coords0_gpu,
ROCArray(h_vertices),
Expand All @@ -121,11 +102,6 @@ function AMDGPU.ROCArray(chain::JustPIC.MarkerChain)
)
end

function AMDGPU.ROCArray(phase_ratios::JustPIC.PhaseRatios)
(; vertex, center) = phase_ratios
return JustPIC.PhaseRatios(AMDGPUBackend, ROCArray(center), ROCArray(vertex))
end

function AMDGPU.ROCArray(::Type{T}, CA::CellArray) where {T <: Number}
ni = size(CA)
# Array initializations
Expand Down Expand Up @@ -216,10 +192,10 @@ module _2D
end

function JustPIC._2D.init_particles(
::Type{AMDGPUBackend}, nxcell, max_xcell, min_xcell, x, y, z; buffer = 1 - 1.0e-5
::Type{AMDGPUBackend}, nxcell, max_xcell, min_xcell, x, y; buffer = 1 - 1.0e-5
)
return init_particles(
AMDGPUBackend, nxcell, max_xcell, min_xcell, x, y, z; buffer = buffer
AMDGPUBackend, nxcell, max_xcell, min_xcell, x, y; buffer = buffer
)
end

Expand All @@ -228,9 +204,9 @@ module _2D
nxcell,
max_xcell,
min_xcell,
coords::NTuple{3, AbstractArray},
dxᵢ::NTuple{3, T},
nᵢ::NTuple{3, I};
coords::NTuple{2, AbstractArray},
dxᵢ::NTuple{2, T},
nᵢ::NTuple{2, I};
buffer = 1 - 1.0e-5,
) where {T, I}
return init_particles(
Expand Down
2 changes: 1 addition & 1 deletion src/PhaseRatios/midpoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end
) where {N, T}

# index corresponding to the cell center
cell_center = getindex.(xci, I)
cell_center = ntuple(i -> xci[i][I[i]], Val(N))
cell_face = @. cell_center + di * offsets / 2
ni = size(phases)
NC = nphases(ratio_faces)
Expand Down
8 changes: 4 additions & 4 deletions src/PhaseRatios/vertices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ end
any(isnan, p) && continue
# check if it's within half cell
tmp = false
for x in zip(p, cell_vertex, di)
if abs(x[1] - x[2]) ≥ x[3] / 2
for i in eachindex(p)
if abs(p[i] - cell_vertex[i]) ≥ di[i] / 2
tmp = true
break
end
Expand Down Expand Up @@ -82,8 +82,8 @@ end
any(isnan, p) && continue
# check if it's within half cell
tmp = false
for x in zip(p, cell_vertex, di)
if abs(x[1] - x[2]) ≥ x[3] / 2
for i in eachindex(p)
if abs(p[i] - cell_vertex[i]) ≥ di[i] / 2
tmp = true
break
end
Expand Down
3 changes: 2 additions & 1 deletion test/test_2D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ end
values(Δ),
values(Nc)
) # random position by default

println("Type particles: ", typeof(particles))
println("Type index: ", typeof(particles.index))
# Initialise phase field
particle_args = phases, = init_cell_arrays(particles, Val(1)) # cool

Expand Down
4 changes: 2 additions & 2 deletions test/test_save_load.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ const backend = JustPIC.CPUBackend

if isCUDA || isAMDGPU
T = isCUDA ? CuArray : ROCArray
Backend = isCUDA ? CUDABackend : AMDGPUBackend
Backend = isCUDA ? CUDABackend : JustPIC.AMDGPUBackend

particles2 = Array(particles)
phases2 = Array(phases)
Expand Down Expand Up @@ -239,7 +239,7 @@ end

if isCUDA || isAMDGPU
T = isCUDA ? CuArray : ROCArray
Backend = isCUDA ? CUDABackend : AMDGPUBackend
Backend = isCUDA ? CUDABackend : JustPIC.AMDGPUBackend

particles2 = Array(particles)
phases2 = Array(phases)
Expand Down
Loading