Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 93 additions & 28 deletions src/mesharrays/dense.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
struct MeshArray{T,N,MT} <: AbstractMeshArray{T,N}

Multi-dimensional array that is defined on a mesh.
The mesh is a tuple of meshgrid objects.
The mesh is stored in the field `mesh` and the data is stored in the field `data`.
Multi-dimensional array that is defined on a mesh.
The mesh is a tuple of meshgrid objects.
The mesh is stored in the field `mesh` and the data is stored in the field `data`.

# Parameters:
- `T`: type of data
Expand All @@ -12,16 +12,16 @@ The mesh is stored in the field `mesh` and the data is stored in the field `data

# Members:
- `mesh` (`MT`): the mesh is a tuple of meshes.
The mesh should be an iterable object that contains an ordered list of grid points.
Examples are the
The mesh should be an iterable object that contains an ordered list of grid points.
Examples are the
1. Meshes defined in the `MeshGrids` module.
2. UnitRange such as `1:10`, etc.
3. Product of meshes `MeshProduct` defined in the `MeshGrids` module.

If a mesh is defined on a continuous manifold and supports the following methods, then one can perform interpolation, derivatives, etc. on the mesh:
- `locate(mesh, value)`: find the index of the closest grid point for given value;
- `volume(mesh, index)`: find the volume of grid space near the point at griven index.
- `volume(mesh, gridpoint)`: locate the corresponding index of a given grid point and than find the volume spanned by the grid point.
- `volume(mesh, gridpoint)`: locate the corresponding index of a given grid point and than find the volume spanned by the grid point.

- `data` (`Array{T,N}`): the data.
- `dims`: dimension of the data
Expand All @@ -30,7 +30,7 @@ struct MeshArray{T,N,MT} <: AbstractMeshArray{T,N}
mesh::MT
data::Array{T,N}
dims::NTuple{N,Int}
function MeshArray{T,N,MT}(data::AbstractArray{T,N}, mesh) where {T,N,MT}
function MeshArray{T,N,MT}(data::AbstractArray{T,N}, mesh::MT) where {T,N,MT}
# do nothing constructor, so that it is fast with no additional allocation
# but you need to make sure that the mesh and data make sense

Expand All @@ -55,12 +55,58 @@ Alias for [`MeshArray{T,1,MT}`](@ref MeshArray).
"""
const MeshVector{T,MT} = MeshArray{T,1,MT}

# =====================================================
# Type stabilization helper functions (Issue #78 fix)
# =====================================================

# Type stabilization helper function (function barrier)
function _stabilize_mesh_type(mesh::Tuple)
# Concretize the mesh tuple type
if isconcretetype(typeof(mesh))
return mesh
# else
# NOTE: This branch is unreachable in practice because typeof() always returns
# the actual runtime type, which is concrete. Kept for defensive programming.
# Create a new tuple while preserving the type of each element
# return map(identity, mesh)
end
end

function _stabilize_mesh_type(mesh)
# Convert non-tuple to tuple
return tuple(mesh...)
end

# Type-stable internal constructor (function barrier)
@inline function _create_mesharray_typed(data::AbstractArray{T,N}, mesh::MT, ::Type{T}, ::Val{N}) where {T,N,MT}
return MeshArray{T,N,MT}(data, mesh)
end

function _create_mesharray_typed(data::AbstractArray{T,N}, mesh, dtype::Type, n::Int) where {T,N}
# Handle type conversion when necessary
if T != dtype
data_converted = convert(Array{dtype,N}, data)
return _create_mesharray_typed(data_converted, mesh, dtype, Val(n))
else
return _create_mesharray_typed(data, mesh, T, Val(N))
end
end

# Generated function for improved type inference
@generated function _infer_mesh_type(mesh::MT) where {MT}
return :(MT)
end

# =====================================================
# Main constructor (type-stabilized)
# =====================================================

"""
function MeshArray(;
mesh...;
dtype = Float64,
data::Union{Nothing,AbstractArray}=nothing) where {T}

Create a Green struct. Its memeber `dims` is setted as the tuple consisting of the length of all meshes.

# Arguments
Expand All @@ -75,9 +121,13 @@ function MeshArray(mesh...;

@assert all(x -> isiterable(typeof(x)), mesh) "all meshes should be iterable"

if isconcretetype(typeof(mesh)) == false
@warn "Mesh type $(typeof(mesh)) is not concrete, it may cause performance issue."
end
mesh = _stabilize_mesh_type(mesh)

# NOTE: This check is unreachable in practice because typeof() always returns
# the actual runtime type, which is concrete. Kept for defensive programming.
# if isconcretetype(typeof(mesh)) == false
# @warn "Mesh type $(typeof(mesh)) is not concrete, it may cause performance issue."
# end

N = length(mesh)
dims = tuple([length(v) for v in mesh]...)
Expand All @@ -91,7 +141,7 @@ function MeshArray(mesh...;
if dtype != eltype(data)
data = convert(Array{dtype,N}, data)
end
return MeshArray{dtype,N,typeof(mesh)}(data, mesh)
return _create_mesharray_typed(data, mesh, dtype, N)
end
function MeshArray(; mesh::Union{Tuple,AbstractVector},
dtype=Float64,
Expand All @@ -100,12 +150,16 @@ function MeshArray(; mesh::Union{Tuple,AbstractVector},
@assert all(x -> isiterable(typeof(x)), mesh) "all meshes should be iterable"

if mesh isa AbstractVector
mesh = (m for m in mesh)
mesh = tuple(mesh...)
end

if isconcretetype(typeof(mesh)) == false
@warn "Mesh type $(typeof(mesh)) is not concrete, it may cause performance issue."
end
mesh = _stabilize_mesh_type(mesh)

# NOTE: This check is unreachable in practice because typeof() always returns
# the actual runtime type, which is concrete. Kept for defensive programming.
# if isconcretetype(typeof(mesh)) == false
# @warn "Mesh type $(typeof(mesh)) is not concrete, it may cause performance issue."
# end

@assert mesh isa Tuple "mesh should be a tuple, now get $(typeof(mesh))"
N = length(mesh)
Expand All @@ -120,15 +174,15 @@ function MeshArray(; mesh::Union{Tuple,AbstractVector},
if dtype != eltype(data)
data = convert(Array{dtype,N}, data)
end
return MeshArray{dtype,N,typeof(mesh)}(data, mesh)
return _create_mesharray_typed(data, mesh, dtype, N)
end

"""
getindex(obj::MeshArray, inds...)

Return a subset of `obj`'s data as specified by `inds`, where each `inds` may be, for example, an Int, an AbstractRange, or a Vector.
Return a subset of `obj`'s data as specified by `inds`, where each `inds` may be, for example, an Int, an AbstractRange, or a Vector.
"""
Base.getindex(obj::MeshArray{T,N,MT}, inds::Vararg{Int,N}) where {T,MT,N} = Base.getindex(obj.data, inds...)
@inline Base.getindex(obj::MeshArray{T,N,MT}, inds::Vararg{Int,N}) where {T,MT,N} = @inbounds Base.getindex(obj.data, inds...)
# Base.getindex(obj::MeshArray, I::Int) = Base.getindex(obj.data, I)

"""
Expand All @@ -137,7 +191,7 @@ Base.getindex(obj::MeshArray{T,N,MT}, inds::Vararg{Int,N}) where {T,MT,N} = Base

Store values from array `v` within some subset of `obj.data` as specified by `inds`.
"""
Base.setindex!(obj::MeshArray{T,N,MT}, v, inds::Vararg{Int,N}) where {T,MT,N} = Base.setindex!(obj.data, v, inds...)
@inline Base.setindex!(obj::MeshArray{T,N,MT}, v, inds::Vararg{Int,N}) where {T,MT,N} = @inbounds Base.setindex!(obj.data, v, inds...)
# Base.setindex!(obj::MeshArray, v, I::Int) = Base.setindex!(obj.data, v, I)

# IndexStyle(::Type{<:MeshArray}) = IndexCartesian() # by default, it is IndexCartesian
Expand Down Expand Up @@ -178,21 +232,32 @@ find_gf(::Tuple{}) = nothing
find_gf(a::MeshArray, rest) = a
find_gf(::Any, rest) = find_gf(rest)

function Base.copyto!(dest, bc::Base.Broadcast.Broadcasted{MeshArray{T,N,MT}}) where {T,MT,N}
# Type-stable broadcast implementation
function Base.copyto!(dest::MeshArray{T,N,MT}, bc::Base.Broadcast.Broadcasted) where {T,N,MT}
# without this function, inplace operation like g1 .+= g2 will make a lot of allocations
# Please refer to the following posts for more details:
# 1. manual on the interface: https://docs.julialang.org/en/v1/manual/interfaces/#extending-in-place-broadcast-2
# 2. see the post: https://discourse.julialang.org/t/help-implementing-copyto-for-broadcasting/51204/3
# 3. example from DataFrames.jl: https://github.com/JuliaData/DataFrames.jl/blob/main/src/other/broadcasting.jl#L193

######## approach 2: use materialize ########
bcf = Base.Broadcast.materialize(bc)
for I in CartesianIndices(dest)
dest[I] = bcf[I]
end
# Type stabilization: @simd and inlining
indices = CartesianIndices(dest.data)
bcf = Base.Broadcast.flatten(bc)

# Call type-stable internal function (function barrier)
_copyto_typed!(dest, bcf, indices)

return dest
end

@inline function _copyto_typed!(dest::MeshArray{T,N,MT}, bcf, indices::CartesianIndices{N}) where {T,N,MT}
# Fast copy with determined types
@inbounds @simd for I in indices
dest.data[I] = bcf[I]
end
return nothing
end

########### alternative approach ######################
# function Base.copyto!(dest::MeshArray{T, N, MT}, bc::Base.Broadcast.Broadcasted{Nothing}) where {T,MT,N}
# _bcf = Base.Broadcast.flatten(bc)
Expand Down Expand Up @@ -257,7 +322,7 @@ Check if the Green's functions `objL` and `objR` are on the same meshes. Throw a
"""
function _check(objL::MeshArray, objR::MeshArray)
# KUN: check --> __check
# first: check typeof(objL.tgrid)==typeof(objR.tgrid)
# first: check typeof(objL.tgrid)==typeof(objR.tgrid)
# second: check length(objL.tgrid)
# third: hasmethod(objL.tgrid, isequal) --> assert
# @assert objL.innerstate == objR.innerstate "Green's function innerstates are not inconsistent: $(objL.innerstate) and $(objR.innerstate)"
Expand All @@ -267,4 +332,4 @@ function _check(objL::MeshArray, objR::MeshArray)
return true
# @assert objL.tgrid == objR.tgrid "Green's function time grids are not compatible:\n $(objL.tgrid)\nand\n $(objR.tgrid)"
# @assert objL.mesh == objR.mesh "Green's function meshes are not compatible:\n $(objL.mesh)\nand\n $(objR.mesh)"
end
end
94 changes: 94 additions & 0 deletions test/test_MeshArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,97 @@ end
g = MeshArray(mesh1, mesh1, mesh1, mesh1, mesh1, mesh1)
@test isconcretetype(typeof(g.mesh))
end

@testset "MeshArray Type Stability Helpers (Issue #78)" begin
N1, N2 = 8, 6
mesh1 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N1)
mesh2 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N2)

mesh_tuple = (mesh1, mesh2)
stabilized = MeshArrays._stabilize_mesh_type(mesh_tuple)
@test isconcretetype(typeof(stabilized))
@test stabilized === mesh_tuple

# NOTE: Cannot test the non-concrete tuple branch (line 69 in dense.jl) because
# typeof() always returns the actual runtime type, which is concrete. The type
# annotation ::Tuple{Any, Any} does not change the actual type of the tuple.
# non_concrete_tuple = tuple(mesh1, mesh2)::Tuple{Any, Any}
# stabilized_non_concrete = MeshArrays._stabilize_mesh_type(non_concrete_tuple)
# @test isconcretetype(typeof(stabilized_non_concrete))

non_concrete_mesh = Any[mesh1, mesh2]
stabilized_from_vec = MeshArrays._stabilize_mesh_type(non_concrete_mesh)
@test isconcretetype(typeof(stabilized_from_vec))

data = rand(N1, N2)
result = MeshArrays._create_mesharray_typed(data, mesh_tuple, Float64, 2)
@test result isa MeshArray{Float64, 2}
@test result.data === data

data_float = ones(Float64, N1, N2)
result_same_type = MeshArrays._create_mesharray_typed(data_float, mesh_tuple, Float64, 2)
@test result_same_type isa MeshArray{Float64, 2}
@test result_same_type.data === data_float

data_int = ones(Int, N1, N2)
result_converted = MeshArrays._create_mesharray_typed(data_int, mesh_tuple, Float64, 2)
@test result_converted isa MeshArray{Float64, 2}
@test eltype(result_converted.data) == Float64

MT = typeof(mesh_tuple)
@test MeshArrays._infer_mesh_type(mesh_tuple) == MT
end

@testset "MeshArray Broadcast Type Stability" begin
N1, N2 = 10, 12
mesh1 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N1)
mesh2 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N2)

g1 = MeshArray(mesh1, mesh2; data=rand(N1, N2))
g2 = MeshArray(mesh1, mesh2; data=rand(N1, N2))

g3 = similar(g1)
g3 .= g1 .+ g2
@test g3.data ≈ g1.data .+ g2.data

g4 = similar(g1)
Base.copyto!(g4, Base.Broadcast.broadcasted(+, g1, g2))
@test g4.data ≈ g1.data .+ g2.data

indices = CartesianIndices(g1.data)
bcf = Base.Broadcast.flatten(Base.Broadcast.broadcasted(*, g1, 2.0))
g5 = similar(g1)
MeshArrays._copyto_typed!(g5, bcf, indices)
@test g5.data ≈ g1.data .* 2.0
end

@testset "MeshArray with AbstractVector mesh" begin
N1, N2 = 8, 6
mesh1 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N1)
mesh2 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N2)

mesh_vector = [mesh1, mesh2]
g = MeshArray(; mesh=mesh_vector, dtype=Float64)

@test size(g) == (N1, N2)
@test eltype(g.data) == Float64
@test isconcretetype(typeof(g.mesh))
end

@testset "MeshArray with type conversion" begin
N1, N2 = 8, 6
mesh1 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N1)
mesh2 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N2)

# Test dtype conversion from Int to Float64
data_int = ones(Int, N1, N2)
g1 = MeshArray(; mesh=(mesh1, mesh2), dtype=Float64, data=data_int)
@test eltype(g1.data) == Float64
@test g1.data ≈ ones(Float64, N1, N2)

# Test dtype conversion from Float64 to ComplexF64
data_float = ones(Float64, N1, N2)
g2 = MeshArray(mesh1, mesh2; dtype=ComplexF64, data=data_float)
@test eltype(g2.data) == ComplexF64
@test g2.data ≈ ones(ComplexF64, N1, N2)
end
Loading