Skip to content

Commit ba4405d

Browse files
wsmosesswilliamson7giordanogithub-actions[bot]
authored
Misc fixes (#687)
* Misc fixes * Apply formatting suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: Sarah Williamson <[email protected]> Co-authored-by: Mosè Giordano <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent bf0fb61 commit ba4405d

File tree

2 files changed

+58
-8
lines changed

2 files changed

+58
-8
lines changed

src/TracedUtils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,8 @@ function __lookup_unique_name_in_module(mod, name)
299299
new_name = i == 0 ? name : name * "_" * string(i)
300300
MLIR.IR.mlirIsNull(MLIR.API.mlirSymbolTableLookup(tab, new_name)) && return new_name
301301
end
302-
return error("Could not find unique name for $name")
302+
modstr = string(mod)
303+
return error("Mod\n$modstr\nCould not find unique name for $name")
303304
end
304305

305306
function __take_region(compiled_fn)

src/Tracing.jl

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -291,18 +291,61 @@ Base.@nospecializeinfer function traced_type_inner(
291291
throw("XLA $T array cannot be traced")
292292
end
293293

294+
Base.@nospecializeinfer function traced_type_inner(
295+
@nospecialize(A::Type{AbstractArray}),
296+
seen,
297+
mode::TraceMode,
298+
@nospecialize(track_numbers::Type)
299+
)
300+
return A
301+
end
302+
303+
Base.@nospecializeinfer function traced_type_inner(
304+
@nospecialize(A::Type{AbstractArray{T}}),
305+
seen,
306+
mode::TraceMode,
307+
@nospecialize(track_numbers::Type)
308+
) where {T}
309+
if mode == ConcreteToTraced
310+
return AbstractArray{TracedRNumber{T}}
311+
else
312+
return A
313+
end
314+
end
315+
316+
Base.@nospecializeinfer function traced_type_inner(
317+
@nospecialize(A::Type{AbstractArray{T,N}}),
318+
seen,
319+
mode::TraceMode,
320+
@nospecialize(track_numbers::Type)
321+
) where {T,N}
322+
if mode == ConcreteToTraced
323+
return AbstractArray{TracedRNumber{T},N}
324+
else
325+
return A
326+
end
327+
end
328+
294329
Base.@nospecializeinfer function traced_type_inner(
295330
@nospecialize(A::Type{<:Array}),
296331
seen,
297332
mode::TraceMode,
298333
@nospecialize(track_numbers::Type)
299334
)
300335
T = eltype(A)
301-
N = ndims(A)
302-
if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive
303-
return ConcreteRArray{T,N}
336+
if A isa UnionAll
337+
if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive
338+
return ConcreteRArray{T}
339+
else
340+
return Array{traced_type_inner(T, seen, mode, track_numbers)}
341+
end
304342
else
305-
return Array{traced_type_inner(T, seen, mode, track_numbers),N}
343+
N = ndims(A)
344+
if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive
345+
return ConcreteRArray{T,N}
346+
else
347+
return Array{traced_type_inner(T, seen, mode, track_numbers),N}
348+
end
306349
end
307350
end
308351

@@ -365,6 +408,7 @@ Base.@nospecializeinfer function traced_type_inner(
365408
if isnothing(Base.datatype_fieldcount(aT))
366409
throw(TracedTypeError("Unhandled type $T"))
367410
end
411+
return T
368412
end
369413

370414
if T isa Union
@@ -457,7 +501,7 @@ Base.@nospecializeinfer function traced_type_inner(
457501
end
458502

459503
name = Symbol[]
460-
throw(NoFieldMatchError(T, TT2))
504+
throw(NoFieldMatchError(T, TT2, subTys))
461505
end
462506

463507
const traced_type_cache = Dict{Tuple{TraceMode,Type},Dict{Type,Type}}()
@@ -580,13 +624,18 @@ end
580624
struct NoFieldMatchError <: TracedTypeException
581625
origty
582626
besteffort
627+
subTys
583628
end
584629
function Base.showerror(io::IO, err::NoFieldMatchError)
585-
print(io, "NoFieldMatchError: ")
586-
return print(
630+
println(io, "NoFieldMatchError: ")
631+
println(
587632
io,
588633
"Cannot convert type $(err.origty), best attempt $(err.besteffort) failed.\nThis could be because the type does not capture the fieldtypes that should be converted in its type parameters.",
589634
)
635+
for (i, subty) in zip(1:fieldcount(err.origty), err.subTys)
636+
origty = fieldtype(err.origty, i)
637+
println(io, "idx=", i, " Derived: ", subty, " Existing: ", origty)
638+
end
590639
end
591640

592641
function make_tracer(

0 commit comments

Comments
 (0)