Skip to content

Commit 29fecdc

Browse files
Merge pull request #228 from ChevronETC/wask/sparse-coordinates-for-complex
Coordinates for complex SparseFunction
2 parents f2abbd5 + 270b9e4 commit 29fecdc

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Devito"
22
uuid = "06ed14b1-0e40-4084-abdf-764a285f8c42"
33
authors = ["Sam Kaplan <Sam.Kaplan@chevron.com>"]
4-
version = "1.4.0"
4+
version = "1.4.1"
55

66
[deps]
77
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"

src/Devito.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1086,10 +1086,12 @@ data(x::SubFunction{T,N,M}) where {T,N,M} = data_allocated(x)
10861086
coordinates(x::SparseDiscreteFunction)
10871087
10881088
Returns a Devito function associated with the coordinates of a sparse time function.
1089-
Note that contrary to typical Julia convention, coordinate order is from slow-to-fast (Python ordering).
1089+
Note 1: contrary to typical Julia convention, coordinate order is from slow-to-fast (Python ordering).
10901090
Thus, for a 3D grid, the sparse time function coordinates would be ordered x,y,z.
1091+
Note 2: we need to handle complex data types, because coordinats are purely real
10911092
"""
10921093
coordinates(x::SparseDiscreteFunction{T,N,M}) where {T,N,M} = SubFunction{T,2,M}(x.o.coordinates)
1094+
coordinates(x::SparseDiscreteFunction{Complex{T},N,M}) where {T,N,M} = SubFunction{T,2,M}(x.o.coordinates)
10931095

10941096
"""
10951097
coordinates_data(x::SparseDiscreteFunction)

test/serialtests.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ end
203203
@test size_with_halo(sf) == (npoint,)
204204
end
205205

206-
@testset "Sparse function coordinates, n=$n" for n in ( (10,11), (10,11,12) )
206+
@testset "Sparse function coordinates, Real, n=$n" for n in ( (10,11), (10,11,12) )
207207
grid = Grid(shape=n, dtype=Float32)
208208
sf = SparseFunction(name="sf", npoint=10, grid=grid)
209209
@test typeof(coordinates(sf)) <: SubFunction{Float32,2}
@@ -216,6 +216,19 @@ end
216216
@test _sf_coords x
217217
end
218218

219+
@testset "Sparse function coordinates, Complex, n=$n" for n in ( (10,11), (10,11,12) )
220+
grid = Grid(shape=n, dtype=Float32)
221+
sf = SparseFunction(name="sf", npoint=10, grid=grid, dtype=Complex{Float32})
222+
@test typeof(coordinates(sf)) <: SubFunction{Float32,2}
223+
sf_coords = coordinates_data(sf)
224+
@test isa(sf_coords, Devito.DevitoArray)
225+
@test size(sf_coords) == (length(n),10)
226+
x = rand(length(n),10)
227+
sf_coords .= x
228+
_sf_coords = coordinates_data(sf)
229+
@test _sf_coords x
230+
end
231+
219232
@testset "SparseFunction from PyObject, T=$T, n=$n, npoint=$npoint" for T in (Float32, Float64), n in ((3,4),(3,4,5)), npoint in (1,5,10)
220233
g = Grid(shape=n, dtype=T)
221234
sf = SparseFunction(name="sf", grid=g, npoint=npoint)
@@ -520,11 +533,11 @@ end
520533
apply(op)
521534
for j in 1:div(size,factr)+1
522535
@test data(f)[j] == data(g)[(j-1)*factr+1]
523-
end
524-
if ENV["DEVITO_BRANCH"] in ("main", "devitopro")
525-
@test data(f)[end] == data(g)[end]
526-
else
527-
@test_broken data(f)[end] == data(g)[end]
536+
if get(ENV,"DEVITO_BRANCH","main") in ("main", "devitopro")
537+
@test data(f)[end] == data(g)[end]
538+
else
539+
@test_broken data(f)[end] == data(g)[end]
540+
end
528541
end
529542
end
530543

0 commit comments

Comments
 (0)