Skip to content

Commit 44871b8

Browse files
authored
EnzymeCore: remove rarray (#2409)
* EnzymeCore: remove rarray * more * fix * fix * fix * fixup * fix
1 parent 4ae6187 commit 44871b8

File tree

4 files changed

+70
-68
lines changed

4 files changed

+70
-68
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ EnzymeStaticArraysExt = "StaticArrays"
3838
BFloat16s = "0.2, 0.3, 0.4, 0.5"
3939
CEnum = "0.4, 0.5"
4040
ChainRulesCore = "1"
41-
EnzymeCore = "0.8.8"
41+
EnzymeCore = "0.8.9"
4242
Enzyme_jll = "0.0.180"
4343
GPUArraysCore = "0.1.6, 0.2"
4444
GPUCompiler = "1.3"

lib/EnzymeCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EnzymeCore"
22
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
4-
version = "0.8.8"
4+
version = "0.8.9"
55

66
[compat]
77
Adapt = "3, 4"

lib/EnzymeCore/src/EnzymeCore.jl

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -692,59 +692,4 @@ end
692692

693693
Combined(mode::ReverseMode) = mode
694694

695-
"""
696-
Primitive Type usable within Reactant. See Reactant.jl for more information.
697-
"""
698-
@static if isdefined(Core, :BFloat16)
699-
const ReactantPrimitive = Union{
700-
Bool,
701-
Int8,
702-
UInt8,
703-
Int16,
704-
UInt16,
705-
Int32,
706-
UInt32,
707-
Int64,
708-
UInt64,
709-
Float16,
710-
Core.BFloat16,
711-
Float32,
712-
Float64,
713-
Complex{Float32},
714-
Complex{Float64},
715-
}
716-
else
717-
const ReactantPrimitive = Union{
718-
Bool,
719-
Int8,
720-
UInt8,
721-
Int16,
722-
UInt16,
723-
Int32,
724-
UInt32,
725-
Int64,
726-
UInt64,
727-
Float16,
728-
Float32,
729-
Float64,
730-
Complex{Float32},
731-
Complex{Float64},
732-
}
733-
end
734-
735-
"""
736-
Abstract Reactant Array type. See Reactant.jl for more information
737-
"""
738-
abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end
739-
@inline Base.eltype(::RArray{T}) where T = T
740-
@inline Base.eltype(::Type{<:RArray{T}}) where T = T
741-
742-
"""
743-
Abstract Reactant Number type. See Reactant.jl for more information
744-
"""
745-
abstract type RNumber{T<:ReactantPrimitive} <: Number end
746-
@inline Base.eltype(::RNumber{T}) where T = T
747-
@inline Base.eltype(::Type{<:RNumber{T}}) where T = T
748-
749-
750695
end # module EnzymeCore

src/analyses/activity.jl

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,10 @@ end
7373
end
7474
end
7575

76-
@inline numbereltype(::Type{<:EnzymeCore.RNumber{T}}) where {T} = T
77-
@inline ptreltype(::Type{<:EnzymeCore.RArray{T}}) where {T} = T
7876
@inline ptreltype(::Type{Ptr{T}}) where {T} = T
7977
@inline ptreltype(::Type{Core.LLVMPtr{T,N}}) where {T,N} = T
8078
@inline ptreltype(::Type{Core.LLVMPtr{T} where N}) where {T} = T
8179
@inline ptreltype(::Type{Base.RefValue{T}}) where {T} = T
82-
@inline ptreltype(::Type{Array{T,N}}) where {T,N} = T
83-
@inline ptreltype(::Type{Array{T,N} where N}) where {T} = T
8480
@inline ptreltype(::Type{Complex{T}}) where {T} = T
8581
@inline ptreltype(::Type{Tuple{Vararg{T}}}) where {T} = T
8682
@inline ptreltype(::Type{IdDict{K,V}}) where {K,V} = V
@@ -92,8 +88,6 @@ else
9288
end
9389

9490
@inline is_arrayorvararg_ty(::Type) = false
95-
@inline is_arrayorvararg_ty(::Type{Array{T,N}}) where {T,N} = true
96-
@inline is_arrayorvararg_ty(::Type{Array{T,N} where N}) where {T} = true
9791
@inline is_arrayorvararg_ty(::Type{Tuple{Vararg{T2}}}) where {T2} = true
9892
@inline is_arrayorvararg_ty(::Type{Ptr{T}}) where {T} = true
9993
@inline is_arrayorvararg_ty(::Type{Core.LLVMPtr{T,N}}) where {T,N} = true
@@ -159,6 +153,48 @@ end
159153
@inline is_vararg_tup(x) = false
160154
@inline is_vararg_tup(::Type{Tuple{Vararg{T2}}}) where {T2} = true
161155

156+
Base.@nospecializeinfer @inline function is_mutable_array(@nospecialize(T::Type))
157+
if T <: Array
158+
return true
159+
end
160+
while T isa UnionAll
161+
T = T.body
162+
end
163+
if T isa DataType
164+
if hasproperty(T, :name) && hasproperty(T.name, :module)
165+
mod = T.name.module
166+
if string(mod) == "Reactant" && (T.name.name == :ConcretePJRTArray || T.name.name == :ConcreteIFRTArray || T.name.name == :TracedRArray)
167+
return true
168+
end
169+
end
170+
end
171+
return false
172+
end
173+
174+
Base.@nospecializeinfer @inline function is_wrapped_number(@nospecialize(T::Type))
175+
if T isa UnionAll
176+
return is_wrapped_number(T.body)
177+
end
178+
while T isa UnionAll
179+
T = T.body
180+
end
181+
if T isa DataType && hasproperty(T, :name) && hasproperty(T.name, :module)
182+
mod = T.name.module
183+
if string(mod) == "Reactant" && (T.name.name == :ConcretePJRTNumber || T.name.name == :ConcreteIFRTNumber || T.name.name == :TracedRNumber)
184+
return true
185+
end
186+
end
187+
return false
188+
end
189+
190+
Base.@nospecializeinfer @inline function unwrapped_number_type(@nospecialize(T::Type))
191+
while T isa UnionAll
192+
T = T.body
193+
end
194+
return T.parameters[1]
195+
end
196+
197+
162198
@inline function active_reg_inner(
163199
::Type{T},
164200
seen::ST,
@@ -198,9 +234,9 @@ end
198234
return ActiveState
199235
end
200236

201-
if T <: EnzymeCore.RNumber
237+
if is_wrapped_number(T)
202238
return active_reg_inner(
203-
numbereltype(T),
239+
unwrapped_number_type(T),
204240
seen,
205241
world,
206242
Val(justActive),
@@ -209,11 +245,32 @@ end
209245
)
210246
end
211247

248+
if is_mutable_array(T)
249+
if justActive
250+
return AnyState
251+
end
252+
253+
if active_reg_inner(
254+
eltype(T),
255+
seen,
256+
world,
257+
Val(justActive),
258+
Val(UnionSret),
259+
Val(AbstractIsMixed),
260+
) == AnyState
261+
return AnyState
262+
else
263+
if AbstractIsMixed
264+
return MixedState
265+
else
266+
return DupState
267+
end
268+
end
269+
end
270+
212271
if T <: Ptr ||
213272
T <: Core.LLVMPtr ||
214-
T <: Base.RefValue ||
215-
T <: Array || T <: EnzymeCore.RArray
216-
is_arrayorvararg_ty(T)
273+
T <: Base.RefValue || is_arrayorvararg_ty(T)
217274
if justActive
218275
return AnyState
219276
end

0 commit comments

Comments
 (0)