Skip to content

Commit e59b18d

Browse files
committed
Move BFloat16 code out of extension
1 parent 60a9e34 commit e59b18d

File tree

4 files changed

+5
-18
lines changed

4 files changed

+5
-18
lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "1.4.0"
55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
8+
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
89
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
910
CodecBzip2 = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd"
1011
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
@@ -26,11 +27,9 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2627
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2728

2829
[weakdeps]
29-
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
3030
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3131

3232
[extensions]
33-
BFloat16sExt = "BFloat16s"
3433
SpecialFunctionsExt = "SpecialFunctions"
3534

3635
[compat]

ext/BFloat16sExt.jl

Lines changed: 0 additions & 14 deletions
This file was deleted.

lib/mps/MPS.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ using ObjectiveC, .Foundation
1616

1717
import GPUArrays
1818

19+
using BFloat16s
20+
1921
const MtlFloat = Union{Float32, Float16}
2022

2123
is_supported(dev::MTLDevice) = ccall(:MPSSupportsMTLDevice, Bool, (id{MTLDevice},), dev)

lib/mps/matrix.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)
4141

4242
# Conversions for MPSDataTypes with Julia equivalents
4343
const jl_mps_to_typ = Dict{MPSDataType, DataType}()
44-
for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,(ComplexF16,:MPSDataTypeComplexFloat16),(ComplexF32,:MPSDataTypeComplexFloat32),Bool]
45-
jltype, mpstype = if type isa Type
44+
for type in [:UInt8,:UInt16,:UInt32,:UInt64,:Int8,:Int16,:Int32,:Int64,:Float16,:BFloat16,:Float32,(:ComplexF16,:MPSDataTypeComplexFloat16),(:ComplexF32,:MPSDataTypeComplexFloat32),:Bool]
45+
jltype, mpstype = if type isa Symbol
4646
type, Symbol(:MPSDataType, type)
4747
else
4848
type

0 commit comments

Comments
 (0)