Skip to content

Commit f16795c

Browse files
authored
Cleanup solve rules and bigfloat (#1641)
* Cleanup solve rules and bigfloat * fix * Fix aug fwd msg * fix cache ty
1 parent 674fd0d commit f16795c

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

src/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1999,7 +1999,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err
19991999
end
20002000

20012001
if errtype == API.ET_NoDerivative
2002-
if occursin("No create nofree of empty function", msg) || occursin("No forward mode derivative found for", msg) || occursin("No augmented forward mode derivative found for", msg) || occursin("No reverse pass found", msg)
2002+
if occursin("No create nofree of empty function", msg) || occursin("No forward mode derivative found for", msg) || occursin("No augmented forward pass", msg) || occursin("No reverse pass found", msg)
20032003
ir = nothing
20042004
end
20052005
exc = NoDerivativeException(msg, ir, bt)

src/internal_rules.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,11 +390,20 @@ else
390390
}
391391
end
392392

393-
cache = NamedTuple{(Symbol("1"),Symbol("2"), Symbol("3"), Symbol("4")), Tuple{typeof(res), typeof(dres), UT, typeof(cache_b)}}(
393+
cache = NamedTuple{(Symbol("1"),Symbol("2"), Symbol("3"), Symbol("4")), Tuple{
394+
eltype(RT),
395+
EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing,
396+
UT,
397+
typeof(cache_b)
398+
}}(
394399
(cache_res, dres, cache_A, cache_b)
395400
)
396401

397-
return EnzymeRules.AugmentedReturn{typeof(retres), typeof(dres), typeof(cache)}(retres, dres, cache)
402+
return EnzymeRules.AugmentedReturn{
403+
EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing,
404+
EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing,
405+
typeof(cache)
406+
}(retres, dres, cache)
398407
end
399408

400409
function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, A::Annotation{<:Array}, b::Annotation{<:Array}) where RT

src/typetree.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ function typetree_inner(::Type{Float64}, ctx, dl, seen::TypeTreeTable)
111111
return TypeTree(API.DT_Double, -1, ctx)
112112
end
113113

114+
function typetree_inner(::Type{BigFloat}, ctx, dl, seen::TypeTreeTable)
115+
return TypeTree()
116+
end
117+
114118
function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where {T<:AbstractFloat}
115119
GPUCompiler.@safe_warn "Unknown floating point type" T
116120
return TypeTree()

0 commit comments

Comments
 (0)