Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 25 additions & 24 deletions ext/MPIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,30 +261,31 @@ function Base.getindex(x::DevitoMPIAbstractArray{T,N}, I::Vararg{Int,N}) where {
v
end

function Base.setindex!(x::DevitoMPIAbstractArray{T,N}, v::T, I::Vararg{Int,N}) where {T,N}
myrank = MPI.Comm_rank(MPI.COMM_WORLD)
if myrank == 0
@warn "`setindex!` for Devito MPI Arrays has suboptimal performance. consider using `copy!`"
end
wanted_rank = find_rank(x, I...)
if wanted_rank == 0
received_v = v
else
message_tag = 2*MPI.Comm_size(MPI.COMM_WORLD)
source_rank = 0
send_mesg = [v]
recv_mesg = 0 .* send_mesg
rreq = ( myrank == wanted_rank ? MPI.Irecv!(recv_mesg, source_rank, message_tag, MPI.COMM_WORLD) : MPI.Request())
sreq = ( myrank == source_rank ? MPI.Isend(send_mesg, wanted_rank, message_tag, MPI.COMM_WORLD) : MPI.Request() )
stats = MPI.Waitall!([rreq, sreq])
received_v = recv_mesg[1]
end
if myrank == wanted_rank
J = ntuple(idim-> Devito.shift_localindicies( I[idim], localindices(x)[idim]), N)
setindex!(x.p, received_v, J...)
end
MPI.Barrier(MPI.COMM_WORLD)
end
# 2025-09-03 JKW this is never ever used in practice - remove?
# function Base.setindex!(x::DevitoMPIAbstractArray{T,N}, v::T, I::Vararg{Int,N}) where {T,N}
# myrank = MPI.Comm_rank(MPI.COMM_WORLD)
# if myrank == 0
# @warn "`setindex!` for Devito MPI Arrays has suboptimal performance. consider using `copy!`"
# end
# wanted_rank = find_rank(x, I...)
# if wanted_rank == 0
# received_v = v
# else
# message_tag = 2*MPI.Comm_size(MPI.COMM_WORLD)
# source_rank = 0
# send_mesg = [v]
# recv_mesg = 0 .* send_mesg
# rreq = ( myrank == wanted_rank ? MPI.Irecv!(recv_mesg, source_rank, message_tag, MPI.COMM_WORLD) : MPI.Request())
# sreq = ( myrank == source_rank ? MPI.Isend(send_mesg, wanted_rank, message_tag, MPI.COMM_WORLD) : MPI.Request() )
# stats = MPI.Waitall!([rreq, sreq])
# received_v = recv_mesg[1]
# end
# if myrank == wanted_rank
# J = ntuple(idim-> Devito.shift_localindicies( I[idim], localindices(x)[idim]), N)
# setindex!(x.p, received_v, J...)
# end
# MPI.Barrier(MPI.COMM_WORLD)
# end

Base.size(x::SparseDiscreteFunction{T,N,DevitoMPITrue}) where {T,N} = size(data(x))

Expand Down
50 changes: 31 additions & 19 deletions test/devitoprotests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Devito, PyCall, Test

@testset "ABox Expanding Source" begin
@test_skip @testset "ABox Expanding Source" begin
g = Grid(shape=(8,8), extent=(7.0,7.0))
nt = 3
coords = [0.5 2.5; 2.5 2.5; 0.5 4.5; 2.5 4.5]
Expand All @@ -20,7 +20,7 @@ end
# TODO (9/2/2025) - failing with decoupler, mloubout is looking into the issue
if get(ENV, "DEVITO_DECOUPLER", "0") != "1"
# TODO - 2024-08-15 JKW these two ABox tests are broken -- some kind of API change?
@testset "ABox Time Function" begin
@test_skip @testset "ABox Time Function" begin
g = Grid(shape=(5,5), extent=(4.0,4.0))
nt = 3
coords = [2. 2. ;]
Expand All @@ -46,7 +46,7 @@ end

# TODO (9/2/2025)- failing with decoupler, mloubout is looking into the issue
if get(ENV, "DEVITO_DECOUPLER", "0") != "1"
@testset "ABox Intersection Time Function" begin
@test_skip @testset "ABox Intersection Time Function" begin
mid = SubDomain("mid",[("middle",2,2),("middle",0,0)])
g = Grid(shape=(5,5), extent=(4.0,4.0), subdomains=mid)
nt = 3
Expand All @@ -73,7 +73,7 @@ if get(ENV, "DEVITO_DECOUPLER", "0") != "1"
end
end

@testset "FloatX dtypes with $(mytype), $(DT), $(CT)" for mytype ∈ [Float32, Float64], (nb, DT, CT) in zip([8, 16], [FloatX8, FloatX16], [UInt8, UInt16])
@test_skip @testset "FloatX dtypes with $(mytype), $(DT), $(CT)" for mytype ∈ [Float32, Float64], (nb, DT, CT) in zip([8, 16], [FloatX8, FloatX16], [UInt8, UInt16])
g = Grid(shape=(5,5))
dtype = DT(1.5f0, 4.5f0)
atol = Devito.scale(dtype)
Expand Down Expand Up @@ -116,7 +116,7 @@ end
@test all(data(f) .== 1.5f0)
end

@testset "FloatX addition" for DT ∈ (FloatX8, FloatX16)
@test_skip @testset "FloatX addition" for DT ∈ (FloatX8, FloatX16)
dtype = DT(1.5f0, 4.5f0)
a = dtype(1.5f0)
b = dtype(1.5f0)
Expand All @@ -125,7 +125,7 @@ end
@test Base.:+(a,1.5f0) ≈ dtype(1.5f0 + 1.5f0).value
end

@testset "FloatX subtraction" for DT ∈ (FloatX8, FloatX16)
@test_skip @testset "FloatX subtraction" for DT ∈ (FloatX8, FloatX16)
dtype = DT(1.5f0, 4.5f0)
a = dtype(3.0f0)
b = dtype(1.5f0)
Expand All @@ -134,7 +134,7 @@ end
@test Base.:-(a,1.5f0) ≈ dtype(3.0f0 - 1.5f0).value
end

@testset "FloatX multiplication" for DT ∈ (FloatX8, FloatX16)
@test_skip @testset "FloatX multiplication" for DT ∈ (FloatX8, FloatX16)
dtype = DT(1.5f0, 4.5f0)
a = dtype(1.5f0)
b = dtype(1.5f0)
Expand All @@ -143,7 +143,7 @@ end
@test Base.:*(a,1.5f0) ≈ dtype(1.5f0 * 1.5f0).value
end

@testset "FloatX division" for DT ∈ (FloatX8, FloatX16)
@test_skip @testset "FloatX division" for DT ∈ (FloatX8, FloatX16)
dtype = DT(1.5f0, 4.5f0)
a = dtype(3.0f0)
b = dtype(1.5f0)
Expand All @@ -152,7 +152,7 @@ end
@test Base.:/(a,1.5f0) ≈ dtype(3.0f0 / 1.5f0).value
end

@testset "FloatX comparison" for DT ∈ (FloatX8, FloatX16)
@test_skip @testset "FloatX comparison" for DT ∈ (FloatX8, FloatX16)
dtype = DT(1.5f0, 4.5f0)
a = dtype(1.5f0)
b = dtype(1.5f0)
Expand All @@ -164,20 +164,26 @@ end
@test Base.isapprox(a,1.5f0)
end

@testset "FloatX convert" for DT ∈ (FloatX8, FloatX16)
@test_skip @testset "FloatX convert" for DT ∈ (FloatX8, FloatX16)
dtype = DT(1.5f0, 4.5f0)
a = dtype(1.5f0)
@test Base.convert(typeof(a),1.5f0) == a
@test Base.convert(Float32,a) ≈ 1.5f0
end

@testset "FloatX eps with $(mytype), $(DT), $(CT)" for mytype ∈ [Float32, Float64], (DT, CT) in zip([FloatX8, FloatX16], [UInt8, UInt16])
g = Grid(shape=(5,5))
dtype = DT(mytype(1.5), mytype(4.5))
@test eps(dtype) ≈ eps(mytype)
@testset "FloatX promote_rule tests" begin
fmin,fmax = 1.5, 4.5
f32u08 = Devito.FloatX{fmin,fmax,Float32,UInt8}(Float32(2))
f32u16 = Devito.FloatX{fmin,fmax,Float32,UInt16}(Float32(2))
f64u08 = Devito.FloatX{fmin,fmax,Float64,UInt8}(Float64(2))
f64u16 = Devito.FloatX{fmin,fmax,Float64,UInt16}(Float64(2))
@test promote_type(typeof(f32u08), typeof(f32u16)) == typeof(f32u16)
@test promote_type(typeof(f64u08), typeof(f64u16)) == typeof(f64u16)
@test promote_type(typeof(f32u08), typeof(f64u08)) == typeof(f64u08)
@test promote_type(typeof(f32u08), typeof(f64u16)) == typeof(f64u16)
end

@testset "FloatX arrays with $(mytype), $(DT), $(CT), autopad=$(autopad)" for mytype ∈ [Float32, Float64], (DT, CT) in zip([FloatX8, FloatX16], [UInt8, UInt16]), autopad ∈ (true,false)
@test_skip @testset "FloatX arrays with $(mytype), $(DT), $(CT), autopad=$(autopad)" for mytype ∈ [Float32, Float64], (DT, CT) in zip([FloatX8, FloatX16], [UInt8, UInt16]), autopad ∈ (true,false)
configuration!("autopadding", autopad)
g = Grid(shape=(5,5))
dtype = DT(mytype(-1.1), mytype(+1.1))
Expand All @@ -189,11 +195,17 @@ end
@test isapprox(Devito.decompress.(data(f)), Devito.decompress.(data(g)))
end

@test_skip @testset "FloatX eps with $(mytype), $(DT), $(CT)" for mytype ∈ [Float32, Float64], (DT, CT) in zip([FloatX8, FloatX16], [UInt8, UInt16])
g = Grid(shape=(5,5))
dtype = DT(mytype(1.5), mytype(4.5))
@test eps(dtype) ≈ eps(mytype)
end

devito_arch = get(ENV, "DEVITO_ARCH", "gcc")

# TODO (9/2/2025) - failing with decoupler, mloubout is looking into the issue
if get(ENV, "DEVITO_DECOUPLER", "0") != "1"
@testset "CCall with printf" begin
@test_skip @testset "CCall with printf" begin
# CCall test written to use gcc
carch = devito_arch in ["gcc", "clang"] ? devito_arch : "gcc"
@pywith switchconfig(;compiler=get(ENV, "CC", carch)) begin
Expand Down Expand Up @@ -225,7 +237,7 @@ compression = []
(lowercase(devito_arch) == "nvc") && (push!(compression, "bitcomp"))
(lowercase(devito_arch) in ["gcc", "clang"]) && (push!(compression, "cvxcompress"))

@testset "Serialization with compression=$(compression)" for compression in compression
@test_skip @testset "Serialization with compression=$(compression)" for compression in compression
if compression == "bitcomp"
configuration!("compiler", "nvc")
else
Expand Down Expand Up @@ -268,7 +280,7 @@ compression = []
end
end

@testset "Serialization serial2str" begin
@test_skip @testset "Serialization serial2str" begin
nt = 11
space_order = 8
grid = Grid(shape=(21,21,21), dtype=Float32)
Expand All @@ -287,7 +299,7 @@ end
end

# JKW: removing for now, not sure what is even being tested here
# @testset "Serialization with CCall T=$T" for T in (Float32,Float64)
# @test_skip @testset "Serialization with CCall T=$T" for T in (Float32,Float64)
# space_order = 2
# time_M = 3
# filename = "testserialization.bin"
Expand Down
Loading