diff --git a/lib/mps/MPS.jl b/lib/mps/MPS.jl index 0ef1f2d0e..df82c868b 100644 --- a/lib/mps/MPS.jl +++ b/lib/mps/MPS.jl @@ -16,9 +16,9 @@ using ObjectiveC, .Foundation import GPUArrays -using BFloat16s +using BFloat16s: BFloat16 -const MtlFloat = Union{Float32, Float16} +const MtlFloat = Union{Float32, Float16, BFloat16} const MPSShape = NSArray#{NSNumber} Base.convert(::Type{MPSShape}, tuple::Union{Vector{T},NTuple{T, <:Integer}}) where T = NSArray(NSNumber.(collect(tuple))) diff --git a/src/Metal.jl b/src/Metal.jl index b6c974588..585a6f748 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -12,6 +12,7 @@ using ExprTools: splitdef, combinedef using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS import ObjectiveC: is_macos, darwin_version, macos_version import KernelAbstractions +using BFloat16s: BFloat16 using ScopedValues include("version.jl") diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl index d1c3019e1..0210eac06 100644 --- a/src/compiler/compilation.jl +++ b/src/compiler/compilation.jl @@ -18,7 +18,8 @@ function GPUCompiler.finish_ir!(@nospecialize(job::MetalCompilerJob), # pointer type information for typed intrinsics # (this is consumed by the LLVM IR downgrader) for (jltyp, llvmtyp) in (Int32 => :i32, Int64 => :i64, - Float16 => :f16, Float32 => :f32), + Float16 => :f16, Float32 => :f32, + BFloat16 => :bf16), (as, asname) in (AS.Device => "global", AS.ThreadGroup => "local") # map of intrinsics to pointer operand indices and eltypes diff --git a/src/device/intrinsics/simd.jl b/src/device/intrinsics/simd.jl index 9274cecab..8d0f14cb9 100644 --- a/src/device/intrinsics/simd.jl +++ b/src/device/intrinsics/simd.jl @@ -7,7 +7,7 @@ function convert_origin(origin::NTuple{2, Int64}) return (VecElement{Int64}(origin[1]-1), VecElement{Int64}(origin[2]-1)) end -for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32")) +for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"), (:BFloat16, "bf16")) for as in (AS.Device, AS.ThreadGroup) @eval begin @device_function simdgroup_load( @@ -55,7 +55,7 @@ end simdgroup_load(data::MtlDeviceArray{T}, matrix_origin=(1, 1)) Loads data from device or threadgroup memory into an 8x8 SIMD-group matrix -and returns it. `T` must be either `Float16` or `Float32`. +and returns it. `T` must be either `Float16`, `Float32`, or `BFloat16`. # Arguments - `matrix_origin::NTuple{2, Int64}=(1, 1)`: origin in the source memory to load from. @@ -65,7 +65,7 @@ and returns it. `T` must be either `Float16` or `Float32`. simdgroup_store(src, dest::MtlDeviceArray{T}, matrix_origin=(1, 1)) Stores data from an 8x8 SIMD-group matrix into device or threadgroup memory. -`T` must be either `Float16` or `Float32`. +`T` must be either `Float16`, `Float32`, or `BFloat16`. # Arguments - `matrix_origin::NTuple{2, Int64}=(1, 1)`: origin in the destination memory to store to. @@ -88,6 +88,7 @@ Returns `a * b + c`. simd_shuffle_map = ((Float32, "f32"), (Float16, "f16"), + (BFloat16,"bf16"), (Int32, "s.i32"), (UInt32, "u.i32"), (Int16, "s.i16"), @@ -118,7 +119,7 @@ The value for `delta` must be the same for all threads in the SIMD-group. This f doesn't modify the upper `delta` lanes of `data` because it doesn't wrap values around the SIMD-group. -T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8 +T must be one of the following: Float32, Float16, BFloat16, Int32, UInt32, Int16, UInt16, Int8, or UInt8 """ simd_shuffle_down @@ -131,6 +132,6 @@ lane ID minus `delta`. The value of `delta` must be the same for all threads in a SIMD-group. This function doesn't modify the lower `delta` lanes of `data` because it doesn't wrap values around the SIMD-group. -T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8 +T must be one of the following: Float32, Float16, BFloat16, Int32, UInt32, Int16, UInt16, Int8, or UInt8 """ simd_shuffle_up diff --git a/test/device/intrinsics/simd.jl b/test/device/intrinsics/simd.jl index 5107f60a0..f696efb7b 100644 --- a/test/device/intrinsics/simd.jl +++ b/test/device/intrinsics/simd.jl @@ -1,3 +1,5 @@ +using Metal: metal_support + @testset "simd intrinsics" begin @testset "$f($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8], (f,res_idx) in [(simd_shuffle_down, 1), (simd_shuffle_up, 32)] @@ -36,7 +38,9 @@ end @testset "matrix functions" begin - @testset "load_store($typ)" for typ in [Float16, Float32] + simdgroup_types = [Float16, Float32] + metal_support() >= v"3.1" && push!(simdgroup_types, BFloat16) + @testset "load_store($typ)" for typ in simdgroup_types function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, origin_a=(1, 1), origin_b=(1, 1)) where {T} sg_a = simdgroup_load(a, origin_a) @@ -59,7 +63,7 @@ end end end - @testset "load_store_tg($typ)" for typ in [Float16, Float32] + @testset "load_store_tg($typ)" for typ in simdgroup_types function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}) where {T} pos = thread_position_in_threadgroup_2d() @@ -83,7 +87,7 @@ end @test Array(a) == Array(b) end - @testset "mul($typ)" for typ in [Float16, Float32] + @testset "mul($typ)" for typ in simdgroup_types function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T}) where {T} sg_a = simdgroup_load(a) sg_b = simdgroup_load(b) @@ -99,7 +103,7 @@ end @test Array(a) * Array(b) ≈ Array(c) end - @testset "mad($typ)" for typ in [Float16, Float32] + @testset "mad($typ)" for typ in simdgroup_types function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T}, d::MtlDeviceArray{T}) where {T} sg_a = simdgroup_load(a)