@@ -291,18 +291,61 @@ Base.@nospecializeinfer function traced_type_inner(
291291 throw (" XLA $T array cannot be traced" )
292292end
293293
294+ Base. @nospecializeinfer function traced_type_inner (
295+ @nospecialize (A:: Type{AbstractArray} ),
296+ seen,
297+ mode:: TraceMode ,
298+ @nospecialize (track_numbers:: Type )
299+ )
300+ return A
301+ end
302+
303+ Base. @nospecializeinfer function traced_type_inner (
304+ @nospecialize (A:: Type{AbstractArray{T}} ),
305+ seen,
306+ mode:: TraceMode ,
307+ @nospecialize (track_numbers:: Type )
308+ ) where {T}
309+ if mode == ConcreteToTraced
310+ return AbstractArray{TracedRNumber{T}}
311+ else
312+ return A
313+ end
314+ end
315+
316+ Base. @nospecializeinfer function traced_type_inner (
317+ @nospecialize (A:: Type{AbstractArray{T,N}} ),
318+ seen,
319+ mode:: TraceMode ,
320+ @nospecialize (track_numbers:: Type )
321+ ) where {T,N}
322+ if mode == ConcreteToTraced
323+ return AbstractArray{TracedRNumber{T},N}
324+ else
325+ return A
326+ end
327+ end
328+
294329Base. @nospecializeinfer function traced_type_inner (
295330 @nospecialize (A:: Type{<:Array} ),
296331 seen,
297332 mode:: TraceMode ,
298333 @nospecialize (track_numbers:: Type )
299334)
300335 T = eltype (A)
301- N = ndims (A)
302- if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive
303- return ConcreteRArray{T,N}
336+ if A isa UnionAll
337+ if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive
338+ return ConcreteRArray{T}
339+ else
340+ return Array{traced_type_inner (T, seen, mode, track_numbers)}
341+ end
304342 else
305- return Array{traced_type_inner (T, seen, mode, track_numbers),N}
343+ N = ndims (A)
344+ if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive
345+ return ConcreteRArray{T,N}
346+ else
347+ return Array{traced_type_inner (T, seen, mode, track_numbers),N}
348+ end
306349 end
307350end
308351
@@ -365,6 +408,7 @@ Base.@nospecializeinfer function traced_type_inner(
365408 if isnothing (Base. datatype_fieldcount (aT))
366409 throw (TracedTypeError (" Unhandled type $T " ))
367410 end
411+ return T
368412 end
369413
370414 if T isa Union
@@ -457,7 +501,7 @@ Base.@nospecializeinfer function traced_type_inner(
457501 end
458502
459503 name = Symbol[]
460- throw (NoFieldMatchError (T, TT2))
504+ throw (NoFieldMatchError (T, TT2, subTys ))
461505end
462506
463507const traced_type_cache = Dict {Tuple{TraceMode,Type},Dict{Type,Type}} ()
@@ -580,13 +624,18 @@ end
580624struct NoFieldMatchError <: TracedTypeException
581625 origty
582626 besteffort
627+ subTys
583628end
584629function Base. showerror (io:: IO , err:: NoFieldMatchError )
585- print (io, " NoFieldMatchError: " )
586- return print (
630+ println (io, " NoFieldMatchError: " )
631+ println (
587632 io,
588633 " Cannot convert type $(err. origty) , best attempt $(err. besteffort) failed.\n This could be because the type does not capture the fieldtypes that should be converted in its type parameters." ,
589634 )
635+ for (i, subty) in zip (1 : fieldcount (err. origty), err. subTys)
636+ origty = fieldtype (err. origty, i)
637+ println (io, " idx=" , i, " Derived: " , subty, " Existing: " , origty)
638+ end
590639end
591640
592641function make_tracer (
0 commit comments