7373 end
7474end
7575
76- @inline numbereltype (:: Type{<:EnzymeCore.RNumber{T}} ) where {T} = T
77- @inline ptreltype (:: Type{<:EnzymeCore.RArray{T}} ) where {T} = T
7876@inline ptreltype (:: Type{Ptr{T}} ) where {T} = T
7977@inline ptreltype (:: Type{Core.LLVMPtr{T,N}} ) where {T,N} = T
8078@inline ptreltype (:: Type{Core.LLVMPtr{T} where N} ) where {T} = T
8179@inline ptreltype (:: Type{Base.RefValue{T}} ) where {T} = T
82- @inline ptreltype (:: Type{Array{T,N}} ) where {T,N} = T
83- @inline ptreltype (:: Type{Array{T,N} where N} ) where {T} = T
8480@inline ptreltype (:: Type{Complex{T}} ) where {T} = T
8581@inline ptreltype (:: Type{Tuple{Vararg{T}}} ) where {T} = T
8682@inline ptreltype (:: Type{IdDict{K,V}} ) where {K,V} = V
9288end
9389
9490@inline is_arrayorvararg_ty (:: Type ) = false
95- @inline is_arrayorvararg_ty (:: Type{Array{T,N}} ) where {T,N} = true
96- @inline is_arrayorvararg_ty (:: Type{Array{T,N} where N} ) where {T} = true
9791@inline is_arrayorvararg_ty (:: Type{Tuple{Vararg{T2}}} ) where {T2} = true
9892@inline is_arrayorvararg_ty (:: Type{Ptr{T}} ) where {T} = true
9993@inline is_arrayorvararg_ty (:: Type{Core.LLVMPtr{T,N}} ) where {T,N} = true
159153@inline is_vararg_tup (x) = false
160154@inline is_vararg_tup (:: Type{Tuple{Vararg{T2}}} ) where {T2} = true
161155
156+ Base. @nospecializeinfer @inline function is_mutable_array (@nospecialize (T:: Type ))
157+ if T <: Array
158+ return true
159+ end
160+ while T isa UnionAll
161+ T = T. body
162+ end
163+ if T isa DataType
164+ if hasproperty (T, :name ) && hasproperty (T. name, :module )
165+ mod = T. name. module
166+ if string (mod) == " Reactant" && (T. name. name == :ConcretePJRTArray || T. name. name == :ConcreteIFRTArray || T. name. name == :TracedRArray )
167+ return true
168+ end
169+ end
170+ end
171+ return false
172+ end
173+
174+ Base. @nospecializeinfer @inline function is_wrapped_number (@nospecialize (T:: Type ))
175+ if T isa UnionAll
176+ return is_wrapped_number (T. body)
177+ end
178+ while T isa UnionAll
179+ T = T. body
180+ end
181+ if T isa DataType && hasproperty (T, :name ) && hasproperty (T. name, :module )
182+ mod = T. name. module
183+ if string (mod) == " Reactant" && (T. name. name == :ConcretePJRTNumber || T. name. name == :ConcreteIFRTNumber || T. name. name == :TracedRNumber )
184+ return true
185+ end
186+ end
187+ return false
188+ end
189+
190+ Base. @nospecializeinfer @inline function unwrapped_number_type (@nospecialize (T:: Type ))
191+ while T isa UnionAll
192+ T = T. body
193+ end
194+ return T. parameters[1 ]
195+ end
196+
197+
162198@inline function active_reg_inner (
163199 :: Type{T} ,
164200 seen:: ST ,
198234 return ActiveState
199235 end
200236
201- if T <: EnzymeCore.RNumber
237+ if is_wrapped_number (T)
202238 return active_reg_inner (
203- numbereltype (T),
239+ unwrapped_number_type (T),
204240 seen,
205241 world,
206242 Val (justActive),
@@ -209,11 +245,32 @@ end
209245 )
210246 end
211247
248+ if is_mutable_array (T)
249+ if justActive
250+ return AnyState
251+ end
252+
253+ if active_reg_inner (
254+ eltype (T),
255+ seen,
256+ world,
257+ Val (justActive),
258+ Val (UnionSret),
259+ Val (AbstractIsMixed),
260+ ) == AnyState
261+ return AnyState
262+ else
263+ if AbstractIsMixed
264+ return MixedState
265+ else
266+ return DupState
267+ end
268+ end
269+ end
270+
212271 if T <: Ptr ||
213272 T <: Core.LLVMPtr ||
214- T <: Base.RefValue ||
215- T <: Array || T <: EnzymeCore.RArray
216- is_arrayorvararg_ty (T)
273+ T <: Base.RefValue || is_arrayorvararg_ty (T)
217274 if justActive
218275 return AnyState
219276 end
0 commit comments