Skip to content

Commit d82fb52

Browse files
committed
feat: support BFloat16 from Core (if available)
1 parent 91a4a00 commit d82fb52

File tree

2 files changed

+39
-18
lines changed

2 files changed

+39
-18
lines changed

src/Reactant.jl

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,42 @@ include("OrderedIdDict.jl")
77

88
using Enzyme
99

10-
const ReactantPrimitives = Union{
11-
Bool,
12-
Int8,
13-
UInt8,
14-
Int16,
15-
UInt16,
16-
Int32,
17-
UInt32,
18-
Int64,
19-
UInt64,
20-
Float16,
21-
Float32,
22-
# BFloat16,
23-
Float64,
24-
Complex{Float32},
25-
Complex{Float64},
26-
}
10+
@static if isdefined(Core, :BFloat16)
11+
const ReactantPrimitives = Union{
12+
Bool,
13+
Int8,
14+
UInt8,
15+
Int16,
16+
UInt16,
17+
Int32,
18+
UInt32,
19+
Int64,
20+
UInt64,
21+
Float16,
22+
Core.BFloat16,
23+
Float32,
24+
Float64,
25+
Complex{Float32},
26+
Complex{Float64},
27+
}
28+
else
29+
const ReactantPrimitives = Union{
30+
Bool,
31+
Int8,
32+
UInt8,
33+
Int16,
34+
UInt16,
35+
Int32,
36+
UInt32,
37+
Int64,
38+
UInt64,
39+
Float16,
40+
Float32,
41+
Float64,
42+
Complex{Float32},
43+
Complex{Float64},
44+
}
45+
end
2746

2847
abstract type RArray{T<:ReactantPrimitives,N} <: AbstractArray{T,N} end
2948
abstract type RNumber{T<:ReactantPrimitives} <: Number end

src/XLA.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,9 @@ end
227227
@inline primitive_type(::Type{Float16}) = 10
228228
@inline primitive_type(::Type{Float32}) = 11
229229

230-
# @inline primitive_type(::Type{BFloat16}) = 16
230+
@static if isdefined(Core, :BFloat16)
231+
@inline primitive_type(::Type{BFloat16}) = 16
232+
end
231233

232234
@inline primitive_type(::Type{Float64}) = 12
233235

0 commit comments

Comments
 (0)