Skip to content

Commit 939ee6e

Browse files
committed
Move BFloat16 code out of extension
1 parent b8ab3b6 commit 939ee6e

File tree

4 files changed

+5
-18
lines changed

4 files changed

+5
-18
lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "1.5.1"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
7+
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
78
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
89
CodecBzip2 = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd"
910
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
@@ -23,11 +24,9 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2324
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2425

2526
[weakdeps]
26-
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
2727
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2828

2929
[extensions]
30-
BFloat16sExt = "BFloat16s"
3130
SpecialFunctionsExt = "SpecialFunctions"
3231

3332
[compat]

ext/BFloat16sExt.jl

Lines changed: 0 additions & 14 deletions
This file was deleted.

lib/mps/MPS.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ using ObjectiveC, .Foundation
1616

1717
import GPUArrays
1818

19+
using BFloat16s
20+
1921
const MtlFloat = Union{Float32, Float16}
2022

2123
const MPSShape = NSArray#{NSNumber}

lib/mps/matrix.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)
55

66
# Conversions for MPSDataTypes with Julia equivalents
77
const jl_mps_to_typ = Dict{MPSDataType, DataType}()
8-
for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,(ComplexF16,:MPSDataTypeComplexFloat16),(ComplexF32,:MPSDataTypeComplexFloat32),Bool]
9-
jltype, mpstype = if type isa Type
8+
for type in [:UInt8,:UInt16,:UInt32,:UInt64,:Int8,:Int16,:Int32,:Int64,:Float16,:BFloat16,:Float32,(:ComplexF16,:MPSDataTypeComplexFloat16),(:ComplexF32,:MPSDataTypeComplexFloat32),:Bool]
9+
jltype, mpstype = if type isa Symbol
1010
type, Symbol(:MPSDataType, type)
1111
else
1212
type

0 commit comments

Comments
 (0)