Skip to content

Commit f5a52c8

Browse files
committed
Move BFloat16 code out of extension
1 parent dbed750 commit f5a52c8

File tree

4 files changed

+4
-17
lines changed

4 files changed

+4
-17
lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "1.4.0"
55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
8+
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
89
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
910
CodecBzip2 = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd"
1011
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
@@ -26,11 +27,9 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2627
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2728

2829
[weakdeps]
29-
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
3030
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3131

3232
[extensions]
33-
BFloat16sExt = "BFloat16s"
3433
SpecialFunctionsExt = "SpecialFunctions"
3534

3635
[compat]

ext/BFloat16sExt.jl

Lines changed: 0 additions & 14 deletions
This file was deleted.

lib/mps/MPS.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ using ObjectiveC, .Foundation
1616

1717
import GPUArrays
1818

19+
using BFloat16s
20+
1921
const MtlFloat = Union{Float32, Float16}
2022

2123
is_supported(dev::MTLDevice) = ccall(:MPSSupportsMTLDevice, Bool, (id{MTLDevice},), dev)

lib/mps/matrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)
3838

3939
# Conversions for MPSDataTypes with Julia equivalents
4040
const jl_mps_to_typ = Dict{MPSDataType, DataType}()
41-
for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,ComplexF16,ComplexF32,Bool]
41+
for type in [:UInt8,:UInt16,:UInt32,:UInt64,:Int8,:Int16,:Int32,:Int64,:Float16,:BFloat16,:Float32,:ComplexF16,:ComplexF32,:Bool]
4242
@eval Base.convert(::Type{MPSDataType}, ::Type{$type}) = $(Symbol(:MPSDataType, type))
4343
@eval jl_mps_to_typ[$(Symbol(:MPSDataType, type))] = $type
4444
end

0 commit comments

Comments
 (0)