Skip to content

Commit 2839d3f

Browse files
vchuravywsmoses
andauthored
Don't call specialize_method again (#2148)
* Don't call specialize_method again * fixup * fix * fix * fix * fix * fixups * more cleaning * fix * cleanup * fix * fix * fix * fix * fixup * fix * fixup * fewer calls in custom rules * more cleanup * fix * fix * fix * fix * fix * fix * ar --------- Co-authored-by: William S. Moses <[email protected]>
1 parent e69f3c2 commit 2839d3f

File tree

20 files changed

+3290
-3441
lines changed

20 files changed

+3290
-3441
lines changed

src/absint.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,14 +740,21 @@ function abs_typeof(
740740
return (false, nothing, nothing)
741741
end
742742

743+
@inline function is_zero(@nospecialize(x::LLVM.Value))::Bool
744+
if x isa LLVM.ConstantInt
745+
return convert(UInt, x) == 0
746+
end
747+
return false
748+
end
749+
743750
function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String}
744751
if isa(arg, ConstantExpr)
745752
ce = arg
746753
while isa(ce, ConstantExpr)
747754
if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr
748755
ce = operands(ce)[1]
749756
elseif opcode(ce) == LLVM.API.LLVMGetElementPtr
750-
if all(x -> x isa LLVM.ConstantInt && convert(UInt, x) == 0, operands(ce)[2:end])
757+
if all(is_zero, operands(ce)[2:end])
751758
ce = operands(ce)[1]
752759
else
753760
break

src/analyses/activity.jl

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ end
6262

6363
@inline forcefold(::Val{RT}) where {RT} = RT
6464

65-
@inline function forcefold(::Val{ty}, ::Val{sty}, C::Vararg{Any,N}) where {ty,sty,N}
65+
@inline function forcefold(::Val{ty}, ::Val{sty}, C::Vararg{Any,N})::ActivityState where {ty,sty,N}
6666
if sty == AnyState || sty == ty
6767
return forcefold(Val(ty), C...)
6868
end
@@ -107,11 +107,7 @@ else
107107
@inline is_arrayorvararg_ty(::Type{Memory{T}}) where T = true
108108
end
109109

110-
@inline function datatype_fieldcount(t::Type{T}) where {T}
111-
return Base.datatype_fieldcount(t)
112-
end
113-
114-
@inline function staticInTup(::Val{T}, tup::NTuple{N,Val}) where {T,N}
110+
Base.@assume_effects :removable :foldable :nothrow @inline function staticInTup(::Val{T}, tup::NTuple{N,Val})::Bool where {T,N}
115111
any(ntuple(Val(N)) do i
116112
Base.@_inline_meta
117113
Val(T) == tup[i]
@@ -125,7 +121,7 @@ end
125121
::Val{justActive},
126122
::Val{UnionSret},
127123
::Val{AbstractIsMixed},
128-
) where {ST,Seen,justActive,UnionSret,AbstractIsMixed}
124+
)::ActivityState where {ST,Seen,justActive,UnionSret,AbstractIsMixed}
129125
if ST isa Union
130126
return forcefold(
131127
Val(
@@ -285,7 +281,7 @@ end
285281
return DupState
286282
end
287283
end
288-
if datatype_fieldcount(aT) === nothing
284+
if Base.datatype_fieldcount(aT) === nothing
289285
if AbstractIsMixed
290286
return MixedState
291287
else
@@ -383,11 +379,11 @@ end
383379
return ty
384380
end
385381

386-
@inline @generated function active_reg_nothrow(::Type{T}, ::Val{world}) where {T,world}
382+
Base.@assume_effects :removable :foldable @inline @generated function active_reg_nothrow(::Type{T}, ::Val{world})::ActivityState where {T,world}
387383
return active_reg_inner(T, (), world)
388384
end
389385

390-
Base.@pure @inline function active_reg(
386+
Base.@assume_effects :removable :foldable @inline function active_reg(
391387
::Type{T},
392388
world::Union{Nothing,UInt} = nothing,
393389
)::Bool where {T}
@@ -411,21 +407,21 @@ Base.@pure @inline function active_reg(
411407
end
412408
end
413409

414-
@inline function guaranteed_const(::Type{T}) where {T}
410+
Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const(::Type{T})::Bool where {T}
415411
rt = active_reg_nothrow(T, Val(nothing))
416412
res = rt == AnyState
417413
return res
418414
end
419415

420-
@inline function guaranteed_const_nongen(::Type{T}, world) where {T}
416+
Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const_nongen(::Type{T}, world)::Bool where {T}
421417
rt = active_reg_inner(T, (), world)
422418
res = rt == AnyState
423419
return res
424420
end
425421

426422
# check if a value is guaranteed to be not contain active[register] data
427423
# (aka not either mixed or active)
428-
@inline function guaranteed_nonactive(::Type{T}) where {T}
424+
Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive(::Type{T})::Bool where {T}
429425
rt = Enzyme.Compiler.active_reg_nothrow(T, Val(nothing))
430426
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
431427
end
@@ -435,10 +431,10 @@ end
435431
436432
Try to guess the most appropriate [`Annotation`](@ref) for arguments of type `T` passed to [`autodiff`](@ref) with a given `mode`.
437433
"""
438-
@inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} =
434+
Base.@assume_effects :removable :foldable :nothrow @inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} =
439435
guess_activity(T, convert(API.CDerivativeMode, mode))
440436

441-
@inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T}
437+
Base.@assume_effects :removable :foldable :nothrow @inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T}
442438
ActReg = active_reg_nothrow(T, Val(nothing))
443439
if ActReg == AnyState
444440
return Const{T}

src/api.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module API
22

33
import LLVM.API: LLVMValueRef, LLVMModuleRef, LLVMTypeRef, LLVMContextRef
44
using Enzyme_jll
5+
using EnzymeCore
56
using Libdl
67
using LLVM
78
using CEnum
@@ -208,6 +209,20 @@ end
208209
# but don't need the forward
209210
)
210211

212+
@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:EnzymeCore.Const} = API.DFT_CONSTANT
213+
@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:EnzymeCore.Active} =
214+
API.DFT_OUT_DIFF
215+
@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:EnzymeCore.Duplicated} =
216+
API.DFT_DUP_ARG
217+
@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:EnzymeCore.BatchDuplicated} =
218+
API.DFT_DUP_ARG
219+
@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:EnzymeCore.BatchDuplicatedFunc} =
220+
API.DFT_DUP_ARG
221+
@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:EnzymeCore.DuplicatedNoNeed} =
222+
API.DFT_DUP_NONEED
223+
@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:EnzymeCore.BatchDuplicatedNoNeed} =
224+
API.DFT_DUP_NONEED
225+
211226
@cenum(
212227
CDerivativeMode,
213228
DEM_ForwardMode = 0,

0 commit comments

Comments
 (0)