diff --git a/src/mesharrays/dense.jl b/src/mesharrays/dense.jl index 84ec161..5403d67 100644 --- a/src/mesharrays/dense.jl +++ b/src/mesharrays/dense.jl @@ -1,9 +1,9 @@ """ struct MeshArray{T,N,MT} <: AbstractMeshArray{T,N} -Multi-dimensional array that is defined on a mesh. -The mesh is a tuple of meshgrid objects. -The mesh is stored in the field `mesh` and the data is stored in the field `data`. +Multi-dimensional array that is defined on a mesh. +The mesh is a tuple of meshgrid objects. +The mesh is stored in the field `mesh` and the data is stored in the field `data`. # Parameters: - `T`: type of data @@ -12,8 +12,8 @@ The mesh is stored in the field `mesh` and the data is stored in the field `data # Members: - `mesh` (`MT`): the mesh is a tuple of meshes. - The mesh should be an iterable object that contains an ordered list of grid points. - Examples are the + The mesh should be an iterable object that contains an ordered list of grid points. + Examples are the 1. Meshes defined in the `MeshGrids` module. 2. UnitRange such as `1:10`, etc. 3. Product of meshes `MeshProduct` defined in the `MeshGrids` module. @@ -21,7 +21,7 @@ The mesh is stored in the field `mesh` and the data is stored in the field `data If a mesh is defined on a continuous manifold and supports the following methods, then one can perform interpolation, derivatives, etc. on the mesh: - `locate(mesh, value)`: find the index of the closest grid point for given value; - `volume(mesh, index)`: find the volume of grid space near the point at griven index. - - `volume(mesh, gridpoint)`: locate the corresponding index of a given grid point and than find the volume spanned by the grid point. + - `volume(mesh, gridpoint)`: locate the corresponding index of a given grid point and then find the volume spanned by the grid point. - `data` (`Array{T,N}`): the data.
- `dims`: dimension of the data @@ -30,7 +30,7 @@ struct MeshArray{T,N,MT} <: AbstractMeshArray{T,N} mesh::MT data::Array{T,N} dims::NTuple{N,Int} - function MeshArray{T,N,MT}(data::AbstractArray{T,N}, mesh) where {T,N,MT} + function MeshArray{T,N,MT}(data::AbstractArray{T,N}, mesh::MT) where {T,N,MT} # do nothing constructor, so that it is fast with no additional allocation # but you need to make sure that the mesh and data make sense @@ -55,12 +55,58 @@ Alias for [`MeshArray{T,1,MT}`](@ref MeshArray). """ const MeshVector{T,MT} = MeshArray{T,1,MT} +# ===================================================== +# Type stabilization helper functions (Issue #78 fix) +# ===================================================== + +# Type stabilization helper function (function barrier) +function _stabilize_mesh_type(mesh::Tuple) + # A tuple value always has a concrete runtime type (typeof() never returns an + # abstract type), so no conversion is needed and `mesh` is returned as is. + # + # The earlier `if isconcretetype(typeof(mesh))` guard had no `else` branch, + # so a false condition would have fallen through and returned `nothing`; + # returning unconditionally removes that hidden path. + # + # Always returns the very same object: `_stabilize_mesh_type(t) === t`. + return mesh +end + +function _stabilize_mesh_type(mesh) + # Convert non-tuple to tuple + return tuple(mesh...)
+end + +# Type-stable internal constructor (function barrier) +@inline function _create_mesharray_typed(data::AbstractArray{T,N}, mesh::MT, ::Type{T}, ::Val{N}) where {T,N,MT} + return MeshArray{T,N,MT}(data, mesh) +end + +function _create_mesharray_typed(data::AbstractArray{T,N}, mesh, dtype::Type, n::Int) where {T,N} + # Dimensionality comes from the static parameter N, so Val(N) is a compile-time constant; `n` is kept only for call compatibility + if T != dtype + data_converted = convert(Array{dtype,N}, data) + return _create_mesharray_typed(data_converted, mesh, dtype, Val(N)) + else + return _create_mesharray_typed(data, mesh, T, Val(N)) + end +end + +# Return the concrete type of `mesh`; the static parameter MT already carries it, so no @generated function is needed +function _infer_mesh_type(mesh::MT) where {MT} + return MT +end + +# ===================================================== +# Main constructor (type-stabilized) +# ===================================================== + """ function MeshArray(; mesh...; dtype = Float64, data::Union{Nothing,AbstractArray}=nothing) where {T} - + Create a Green struct. Its memeber `dims` is setted as the tuple consisting of the length of all meshes. # Arguments @@ -75,9 +121,13 @@ function MeshArray(mesh...; @assert all(x -> isiterable(typeof(x)), mesh) "all meshes should be iterable" - if isconcretetype(typeof(mesh)) == false - @warn "Mesh type $(typeof(mesh)) is not concrete, it may cause performance issue." - end + mesh = _stabilize_mesh_type(mesh) + + # NOTE: This check is unreachable in practice because typeof() always returns + # the actual runtime type, which is concrete. Kept for defensive programming. + # if isconcretetype(typeof(mesh)) == false + # @warn "Mesh type $(typeof(mesh)) is not concrete, it may cause performance issue." + # end N = length(mesh) dims = tuple([length(v) for v in mesh]...)
@@ -91,7 +141,7 @@ function MeshArray(mesh...; if dtype != eltype(data) data = convert(Array{dtype,N}, data) end - return MeshArray{dtype,N,typeof(mesh)}(data, mesh) + return _create_mesharray_typed(data, mesh, dtype, N) end function MeshArray(; mesh::Union{Tuple,AbstractVector}, dtype=Float64, @@ -100,12 +150,16 @@ function MeshArray(; mesh::Union{Tuple,AbstractVector}, @assert all(x -> isiterable(typeof(x)), mesh) "all meshes should be iterable" if mesh isa AbstractVector - mesh = (m for m in mesh) + mesh = tuple(mesh...) end - if isconcretetype(typeof(mesh)) == false - @warn "Mesh type $(typeof(mesh)) is not concrete, it may cause performance issue." - end + mesh = _stabilize_mesh_type(mesh) + + # NOTE: This check is unreachable in practice because typeof() always returns + # the actual runtime type, which is concrete. Kept for defensive programming. + # if isconcretetype(typeof(mesh)) == false + # @warn "Mesh type $(typeof(mesh)) is not concrete, it may cause performance issue." + # end @assert mesh isa Tuple "mesh should be a tuple, now get $(typeof(mesh))" N = length(mesh) @@ -120,15 +174,15 @@ function MeshArray(; mesh::Union{Tuple,AbstractVector}, if dtype != eltype(data) data = convert(Array{dtype,N}, data) end - return MeshArray{dtype,N,typeof(mesh)}(data, mesh) + return _create_mesharray_typed(data, mesh, dtype, N) end """ getindex(obj::MeshArray, inds...) -Return a subset of `obj`'s data as specified by `inds`, where each `inds` may be, for example, an Int, an AbstractRange, or a Vector. +Return a subset of `obj`'s data as specified by `inds`, where each `inds` may be, for example, an Int, an AbstractRange, or a Vector. """ -Base.getindex(obj::MeshArray{T,N,MT}, inds::Vararg{Int,N}) where {T,MT,N} = Base.getindex(obj.data, inds...) +Base.@propagate_inbounds Base.getindex(obj::MeshArray{T,N,MT}, inds::Vararg{Int,N}) where {T,MT,N} = Base.getindex(obj.data, inds...)
# Base.getindex(obj::MeshArray, I::Int) = Base.getindex(obj.data, I) """ @@ -137,7 +191,7 @@ Base.getindex(obj::MeshArray{T,N,MT}, inds::Vararg{Int,N}) where {T,MT,N} = Base Store values from array `v` within some subset of `obj.data` as specified by `inds`. """ -Base.setindex!(obj::MeshArray{T,N,MT}, v, inds::Vararg{Int,N}) where {T,MT,N} = Base.setindex!(obj.data, v, inds...) +Base.@propagate_inbounds Base.setindex!(obj::MeshArray{T,N,MT}, v, inds::Vararg{Int,N}) where {T,MT,N} = Base.setindex!(obj.data, v, inds...) # Base.setindex!(obj::MeshArray, v, I::Int) = Base.setindex!(obj.data, v, I) # IndexStyle(::Type{<:MeshArray}) = IndexCartesian() # by default, it is IndexCartesian @@ -178,21 +232,32 @@ find_gf(::Tuple{}) = nothing find_gf(a::MeshArray, rest) = a find_gf(::Any, rest) = find_gf(rest) -function Base.copyto!(dest, bc::Base.Broadcast.Broadcasted{MeshArray{T,N,MT}}) where {T,MT,N} +# In-place broadcast into a MeshArray; dispatches on Broadcasted{Nothing} so it cannot be ambiguous with Base's scalar fast path (`g .= 1.0`) +function Base.copyto!(dest::MeshArray{T,N,MT}, bc::Base.Broadcast.Broadcasted{Nothing}) where {T,N,MT} # without this function, inplace operation like g1 .+= g2 will make a lot of allocations # Please refer to the following posts for more details: # 1. manual on the interface: https://docs.julialang.org/en/v1/manual/interfaces/#extending-in-place-broadcast-2 # 2. see the post: https://discourse.julialang.org/t/help-implementing-copyto-for-broadcasting/51204/3 # 3.
example from DataFrames.jl: https://github.com/JuliaData/DataFrames.jl/blob/main/src/other/broadcasting.jl#L193 - ######## approach 2: use materialize ######## - bcf = Base.Broadcast.materialize(bc) - for I in CartesianIndices(dest) - dest[I] = bcf[I] - end + axes(dest.data) == axes(bc) || throw(DimensionMismatch("destination axes $(axes(dest.data)) do not match broadcast axes $(axes(bc))")) + indices = CartesianIndices(dest.data) + bcf = Base.Broadcast.flatten(bc) + + # Axes were verified above, so the bounds-check-free kernel (function barrier) is safe + _copyto_typed!(dest, bcf, indices) + return dest end +@inline function _copyto_typed!(dest::MeshArray{T,N,MT}, bcf, indices::CartesianIndices{N}) where {T,N,MT} + # Bounds checks are elided: the caller must guarantee `indices` is valid for both `dest.data` and `bcf` + @inbounds @simd for I in indices + dest.data[I] = bcf[I] + end + return nothing +end + ########### alternative approach ###################### # function Base.copyto!(dest::MeshArray{T, N, MT}, bc::Base.Broadcast.Broadcasted{Nothing}) where {T,MT,N} # _bcf = Base.Broadcast.flatten(bc) @@ -257,7 +322,7 @@ Check if the Green's functions `objL` and `objR` are on the same meshes.
Throw a """ function _check(objL::MeshArray, objR::MeshArray) # KUN: check --> __check - # first: check typeof(objL.tgrid)==typeof(objR.tgrid) + # first: check typeof(objL.tgrid)==typeof(objR.tgrid) # second: check length(objL.tgrid) # third: hasmethod(objL.tgrid, isequal) --> assert # @assert objL.innerstate == objR.innerstate "Green's function innerstates are not inconsistent: $(objL.innerstate) and $(objR.innerstate)" diff --git a/test/test_MeshArrays.jl b/test/test_MeshArrays.jl index 7c55a6b..1ccc3bc 100644 --- a/test/test_MeshArrays.jl +++ b/test/test_MeshArrays.jl @@ -94,3 +94,97 @@ end g = MeshArray(mesh1, mesh1, mesh1, mesh1, mesh1, mesh1) @test isconcretetype(typeof(g.mesh)) end + +@testset "MeshArray Type Stability Helpers (Issue #78)" begin + N1, N2 = 8, 6 + mesh1 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N1) + mesh2 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N2) + + mesh_tuple = (mesh1, mesh2) + stabilized = MeshArrays._stabilize_mesh_type(mesh_tuple) + @test isconcretetype(typeof(stabilized)) + @test stabilized === mesh_tuple + + # NOTE: A tuple with a non-concrete runtime type cannot be constructed, because + # typeof() always returns the actual runtime type, which is concrete. The type + # annotation ::Tuple{Any, Any} does not change the actual type of the tuple.
+ # non_concrete_tuple = tuple(mesh1, mesh2)::Tuple{Any, Any} + # stabilized_non_concrete = MeshArrays._stabilize_mesh_type(non_concrete_tuple) + # @test isconcretetype(typeof(stabilized_non_concrete)) + + non_concrete_mesh = Any[mesh1, mesh2] # Vector{Any} input: exercises the non-tuple method + stabilized_from_vec = MeshArrays._stabilize_mesh_type(non_concrete_mesh) + @test isconcretetype(typeof(stabilized_from_vec)) + + data = rand(N1, N2) + result = MeshArrays._create_mesharray_typed(data, mesh_tuple, Float64, 2) + @test result isa MeshArray{Float64, 2} + @test result.data === data # same array object: no copy when the eltype already matches + + data_float = ones(Float64, N1, N2) + result_same_type = MeshArrays._create_mesharray_typed(data_float, mesh_tuple, Float64, 2) + @test result_same_type isa MeshArray{Float64, 2} + @test result_same_type.data === data_float + + data_int = ones(Int, N1, N2) # Int data with dtype Float64 takes the conversion branch + result_converted = MeshArrays._create_mesharray_typed(data_int, mesh_tuple, Float64, 2) + @test result_converted isa MeshArray{Float64, 2} + @test eltype(result_converted.data) == Float64 + + MT = typeof(mesh_tuple) + @test MeshArrays._infer_mesh_type(mesh_tuple) == MT +end + +@testset "MeshArray Broadcast Type Stability" begin + N1, N2 = 10, 12 + mesh1 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N1) + mesh2 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N2) + + g1 = MeshArray(mesh1, mesh2; data=rand(N1, N2)) + g2 = MeshArray(mesh1, mesh2; data=rand(N1, N2)) + + g3 = similar(g1) + g3 .= g1 .+ g2 # dot-syntax in-place broadcast + @test g3.data ≈ g1.data .+ g2.data + + g4 = similar(g1) + Base.copyto!(g4, Base.Broadcast.broadcasted(+, g1, g2)) # explicit copyto! call on a lazy Broadcasted object + @test g4.data ≈ g1.data .+ g2.data + + indices = CartesianIndices(g1.data) + bcf = Base.Broadcast.flatten(Base.Broadcast.broadcasted(*, g1, 2.0)) + g5 = similar(g1) + MeshArrays._copyto_typed!(g5, bcf, indices) # call the internal kernel directly + @test g5.data ≈ g1.data .* 2.0 +end + +@testset "MeshArray with AbstractVector mesh" begin + N1, N2 = 8, 6 + mesh1 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N1) + mesh2 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N2) + + mesh_vector = [mesh1, mesh2] + g = MeshArray(;
mesh=mesh_vector, dtype=Float64) + + @test size(g) == (N1, N2) + @test eltype(g.data) == Float64 + @test isconcretetype(typeof(g.mesh)) # the Vector mesh ends up stored with a concrete type +end + +@testset "MeshArray with type conversion" begin + N1, N2 = 8, 6 + mesh1 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N1) + mesh2 = SimpleGrid.Uniform{Float64}([0.0, 1.0], N2) + + # Keyword-constructor path: Int data is converted to the requested Float64 eltype + data_int = ones(Int, N1, N2) + g1 = MeshArray(; mesh=(mesh1, mesh2), dtype=Float64, data=data_int) + @test eltype(g1.data) == Float64 + @test g1.data ≈ ones(Float64, N1, N2) + + # Positional-constructor path: Float64 data is converted to the requested ComplexF64 eltype + data_float = ones(Float64, N1, N2) + g2 = MeshArray(mesh1, mesh2; dtype=ComplexF64, data=data_float) + @test eltype(g2.data) == ComplexF64 + @test g2.data ≈ ones(ComplexF64, N1, N2) +end