11"""
22 struct MeshArray{T,N,MT} <: AbstractMeshArray{T,N}
33
4- Multi-dimensional array that is defined on a mesh.
5- The mesh is a tuple of meshgrid objects.
6- The mesh is stored in the field `mesh` and the data is stored in the field `data`.
4+ Multi-dimensional array that is defined on a mesh.
5+ The mesh is a tuple of meshgrid objects.
6+ The mesh is stored in the field `mesh` and the data is stored in the field `data`.
77
88# Parameters:
99- `T`: type of data
@@ -12,16 +12,16 @@ The mesh is stored in the field `mesh` and the data is stored in the field `data
1212
1313# Members:
1414- `mesh` (`MT`): the mesh is a tuple of meshes.
15- The mesh should be an iterable object that contains an ordered list of grid points.
16- Examples are the
15+ The mesh should be an iterable object that contains an ordered list of grid points.
16+ Examples are the
1717 1. Meshes defined in the `MeshGrids` module.
1818 2. UnitRange such as `1:10`, etc.
1919 3. Product of meshes `MeshProduct` defined in the `MeshGrids` module.
2020
2121 If a mesh is defined on a continuous manifold and supports the following methods, then one can perform interpolation, derivatives, etc. on the mesh:
2222 - `locate(mesh, value)`: find the index of the closest grid point for given value;
2323 - `volume(mesh, index)`: find the volume of grid space near the point at griven index.
24- - `volume(mesh, gridpoint)`: locate the corresponding index of a given grid point and than find the volume spanned by the grid point.
24+ - `volume(mesh, gridpoint)`: locate the corresponding index of a given grid point and than find the volume spanned by the grid point.
2525
2626- `data` (`Array{T,N}`): the data.
2727- `dims`: dimension of the data
@@ -30,7 +30,7 @@ struct MeshArray{T,N,MT} <: AbstractMeshArray{T,N}
3030 mesh:: MT
3131 data:: Array{T,N}
3232 dims:: NTuple{N,Int}
33- function MeshArray {T,N,MT} (data:: AbstractArray{T,N} , mesh) where {T,N,MT}
33+ function MeshArray {T,N,MT} (data:: AbstractArray{T,N} , mesh:: MT ) where {T,N,MT}
3434 # do nothing constructor, so that it is fast with no additional allocation
3535 # but you need to make sure that the mesh and data make sense
3636
@@ -55,12 +55,56 @@ Alias for [`MeshArray{T,1,MT}`](@ref MeshArray).
5555"""
5656const MeshVector{T,MT} = MeshArray{T,1 ,MT}
5757
58+ # =====================================================
59+ # Type stabilization helper functions (Issue #78 fix)
60+ # =====================================================
61+
62+ # Type stabilization helper function (function barrier)
63+ function _stabilize_mesh_type (mesh:: Tuple )
64+ # Concretize the mesh tuple type
65+ if isconcretetype (typeof (mesh))
66+ return mesh
67+ else
68+ # Create a new tuple while preserving the type of each element
69+ return map (identity, mesh)
70+ end
71+ end
72+
73+ function _stabilize_mesh_type (mesh)
74+ # Convert non-tuple to tuple
75+ return tuple (mesh... )
76+ end
77+
78+ # Type-stable internal constructor (function barrier)
79+ @inline function _create_mesharray_typed (data:: AbstractArray{T,N} , mesh:: MT , :: Type{T} , :: Val{N} ) where {T,N,MT}
80+ return MeshArray {T,N,MT} (data, mesh)
81+ end
82+
83+ function _create_mesharray_typed (data:: AbstractArray{T,N} , mesh, dtype:: Type , n:: Int ) where {T,N}
84+ # Handle type conversion when necessary
85+ if T != dtype
86+ data_converted = convert (Array{dtype,N}, data)
87+ return _create_mesharray_typed (data_converted, mesh, dtype, Val (n))
88+ else
89+ return _create_mesharray_typed (data, mesh, T, Val (N))
90+ end
91+ end
92+
93+ # Generated function for improved type inference
94+ @generated function _infer_mesh_type (mesh:: MT ) where {MT}
95+ return :(MT)
96+ end
97+
98+ # =====================================================
99+ # Main constructor (type-stabilized)
100+ # =====================================================
101+
58102"""
59103 function MeshArray(;
60104 mesh...;
61105 dtype = Float64,
62106 data::Union{Nothing,AbstractArray}=nothing) where {T}
63-
107+
64108Create a Green struct. Its memeber `dims` is setted as the tuple consisting of the length of all meshes.
65109
66110# Arguments
@@ -75,6 +119,8 @@ function MeshArray(mesh...;
75119
76120 @assert all (x -> isiterable (typeof (x)), mesh) " all meshes should be iterable"
77121
122+ mesh = _stabilize_mesh_type (mesh)
123+
78124 if isconcretetype (typeof (mesh)) == false
79125 @warn " Mesh type $(typeof (mesh)) is not concrete, it may cause performance issue."
80126 end
@@ -91,7 +137,7 @@ function MeshArray(mesh...;
91137 if dtype != eltype (data)
92138 data = convert (Array{dtype,N}, data)
93139 end
94- return MeshArray {dtype,N,typeof(mesh)} ( data, mesh)
140+ return _create_mesharray_typed ( data, mesh, dtype, N )
95141end
96142function MeshArray (; mesh:: Union{Tuple,AbstractVector} ,
97143 dtype= Float64,
@@ -100,9 +146,11 @@ function MeshArray(; mesh::Union{Tuple,AbstractVector},
100146 @assert all (x -> isiterable (typeof (x)), mesh) " all meshes should be iterable"
101147
102148 if mesh isa AbstractVector
103- mesh = (m for m in mesh)
149+ mesh = tuple ( mesh... )
104150 end
105151
152+ mesh = _stabilize_mesh_type (mesh)
153+
106154 if isconcretetype (typeof (mesh)) == false
107155 @warn " Mesh type $(typeof (mesh)) is not concrete, it may cause performance issue."
108156 end
@@ -120,15 +168,15 @@ function MeshArray(; mesh::Union{Tuple,AbstractVector},
120168 if dtype != eltype (data)
121169 data = convert (Array{dtype,N}, data)
122170 end
123- return MeshArray {dtype,N,typeof(mesh)} ( data, mesh)
171+ return _create_mesharray_typed ( data, mesh, dtype, N )
124172end
125173
126174"""
127175 getindex(obj::MeshArray, inds...)
128176
129- Return a subset of `obj`'s data as specified by `inds`, where each `inds` may be, for example, an Int, an AbstractRange, or a Vector.
177+ Return a subset of `obj`'s data as specified by `inds`, where each `inds` may be, for example, an Int, an AbstractRange, or a Vector.
130178"""
131- Base. getindex (obj:: MeshArray{T,N,MT} , inds:: Vararg{Int,N} ) where {T,MT,N} = Base. getindex (obj. data, inds... )
179+ @inline Base. getindex (obj:: MeshArray{T,N,MT} , inds:: Vararg{Int,N} ) where {T,MT,N} = @inbounds Base. getindex (obj. data, inds... )
132180# Base.getindex(obj::MeshArray, I::Int) = Base.getindex(obj.data, I)
133181
134182"""
@@ -137,7 +185,7 @@ Base.getindex(obj::MeshArray{T,N,MT}, inds::Vararg{Int,N}) where {T,MT,N} = Base
137185
138186Store values from array `v` within some subset of `obj.data` as specified by `inds`.
139187"""
140- Base. setindex! (obj:: MeshArray{T,N,MT} , v, inds:: Vararg{Int,N} ) where {T,MT,N} = Base. setindex! (obj. data, v, inds... )
188+ @inline Base. setindex! (obj:: MeshArray{T,N,MT} , v, inds:: Vararg{Int,N} ) where {T,MT,N} = @inbounds Base. setindex! (obj. data, v, inds... )
141189# Base.setindex!(obj::MeshArray, v, I::Int) = Base.setindex!(obj.data, v, I)
142190
143191# IndexStyle(::Type{<:MeshArray}) = IndexCartesian() # by default, it is IndexCartesian
@@ -178,21 +226,32 @@ find_gf(::Tuple{}) = nothing
178226find_gf (a:: MeshArray , rest) = a
179227find_gf (:: Any , rest) = find_gf (rest)
180228
181- function Base. copyto! (dest, bc:: Base.Broadcast.Broadcasted{MeshArray{T,N,MT}} ) where {T,MT,N}
229+ # Type-stable broadcast implementation
230+ function Base. copyto! (dest:: MeshArray{T,N,MT} , bc:: Base.Broadcast.Broadcasted ) where {T,N,MT}
182231 # without this function, inplace operation like g1 .+= g2 will make a lot of allocations
183232 # Please refer to the following posts for more details:
184233 # 1. manual on the interface: https://docs.julialang.org/en/v1/manual/interfaces/#extending-in-place-broadcast-2
185234 # 2. see the post: https://discourse.julialang.org/t/help-implementing-copyto-for-broadcasting/51204/3
186235 # 3. example from DataFrames.jl: https://github.com/JuliaData/DataFrames.jl/blob/main/src/other/broadcasting.jl#L193
187236
188- # ####### approach 2: use materialize ########
189- bcf = Base. Broadcast. materialize (bc)
190- for I in CartesianIndices (dest)
191- dest[I] = bcf[I]
192- end
237+ # Type stabilization: @simd and inlining
238+ indices = CartesianIndices (dest. data)
239+ bcf = Base. Broadcast. flatten (bc)
240+
241+ # Call type-stable internal function (function barrier)
242+ _copyto_typed! (dest, bcf, indices)
243+
193244 return dest
194245end
195246
247+ @inline function _copyto_typed! (dest:: MeshArray{T,N,MT} , bcf, indices:: CartesianIndices{N} ) where {T,N,MT}
248+ # Fast copy with determined types
249+ @inbounds @simd for I in indices
250+ dest. data[I] = bcf[I]
251+ end
252+ return nothing
253+ end
254+
196255# ########## alternative approach ######################
197256# function Base.copyto!(dest::MeshArray{T, N, MT}, bc::Base.Broadcast.Broadcasted{Nothing}) where {T,MT,N}
198257# _bcf = Base.Broadcast.flatten(bc)
@@ -257,7 +316,7 @@ Check if the Green's functions `objL` and `objR` are on the same meshes. Throw a
257316"""
258317function _check (objL:: MeshArray , objR:: MeshArray )
259318 # KUN: check --> __check
260- # first: check typeof(objL.tgrid)==typeof(objR.tgrid)
319+ # first: check typeof(objL.tgrid)==typeof(objR.tgrid)
261320 # second: check length(objL.tgrid)
262321 # third: hasmethod(objL.tgrid, isequal) --> assert
263322 # @assert objL.innerstate == objR.innerstate "Green's function innerstates are not inconsistent: $(objL.innerstate) and $(objR.innerstate)"
@@ -267,4 +326,4 @@ function _check(objL::MeshArray, objR::MeshArray)
267326 return true
268327 # @assert objL.tgrid == objR.tgrid "Green's function time grids are not compatible:\n $(objL.tgrid)\nand\n $(objR.tgrid)"
269328 # @assert objL.mesh == objR.mesh "Green's function meshes are not compatible:\n $(objL.mesh)\nand\n $(objR.mesh)"
270- end
329+ end
0 commit comments