Skip to content

Commit 5c9cbfc

Browse files
committed
Add tests
1 parent c479d0b commit 5c9cbfc

File tree

5 files changed

+48
-35
lines changed

5 files changed

+48
-35
lines changed

test/array.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
STORAGEMODES = [Metal.PrivateStorage, Metal.SharedStorage, Metal.ManagedStorage]
22

3+
const FILL_TYPES = [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64,
4+
Float16, Float32]
5+
Metal.metal_support() >= v"3.1" && push!(FILL_TYPES, BFloat16)
6+
37
@testset "array" begin
48

59
let arr = MtlVector{Int}(undef, 1)
@@ -27,8 +31,7 @@ end
2731
@test mtl(1:3) === 1:3
2832

2933

30-
# Page 22 of https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
31-
# Only bfloat missing
34+
# Section 2.1 of https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
3235
supported_number_types = [Float16 => Float16,
3336
Float32 => Float32,
3437
Float64 => Float32,
@@ -41,6 +44,8 @@ end
4144
UInt32 => UInt32,
4245
UInt64 => UInt64,
4346
UInt8 => UInt8]
47+
Metal.metal_support() >= v"3.1" && push!(supported_number_types, BFloat16 => BFloat16)
48+
4449
# Test supported types and ensure only Float64 get converted to Float32
4550
for (SrcType, TargType) in supported_number_types
4651
@test mtl(SrcType[1]) isa MtlArray{TargType}
@@ -227,8 +232,7 @@ end
227232

228233
end
229234

230-
@testset "fill($T)" for T in [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64,
231-
Float16, Float32]
235+
@testset "fill($T)" for T in FILL_TYPES
232236

233237
b = rand(T)
234238

@@ -265,8 +269,7 @@ end
265269
end
266270
end
267271

268-
@testset "fill!($T)" for T in [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64,
269-
Float16, Float32]
272+
@testset "fill!($T)" for T in FILL_TYPES
270273

271274
b = rand(T)
272275

test/device/intrinsics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ end
276276

277277
@testset "parametrically typed" begin
278278
typs = [Int32, Int64, Float32]
279+
metal_support() >= v"3.1" && push!(types, BFloat16)
279280
@testset for typ in typs
280281
function kernel(d::MtlDeviceArray{T}, n) where {T}
281282
t = thread_position_in_threadgroup_1d()

test/mps/linalg.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -147,33 +147,36 @@ function cpu_topk(x::Matrix{T}, k; rev=true, dims=1) where {T}
147147
end
148148

149149
@testset "topk & topk!" begin
150-
for ftype in (Float16, Float32)
150+
ftypes = [Float16, Float32]
151+
152+
@testset "$ftype" for ftype in ftypes
151153
# Normal operation
152-
@testset "$ftype" begin
153-
for (shp,k) in [((3,1), 2), ((20,30), 5)]
154-
cpu_a = rand(ftype, shp...)
154+
@testset "$shp, k=$k" for (shp,k) in [((3,1), 2), ((20,30), 5)]
155+
cpu_a = rand(ftype, shp...)
155156

156-
#topk
157-
cpu_i, cpu_v = cpu_topk(cpu_a, k)
157+
#topk
158+
cpu_i, cpu_v = cpu_topk(cpu_a, k)
158159

159-
a = MtlMatrix(cpu_a)
160-
i, v = MPS.topk(a, k)
160+
a = MtlMatrix(cpu_a)
161+
i, v = MPS.topk(a, k)
161162

162-
@test Array(i) == cpu_i
163-
@test Array(v) == cpu_v
163+
@test Array(i) == cpu_i
164+
@test Array(v) == cpu_v
164165

165-
#topk!
166-
i = MtlMatrix{UInt32}(undef, (k, shp[2]))
167-
v = MtlMatrix{ftype}(undef, (k, shp[2]))
166+
#topk!
167+
i = MtlMatrix{UInt32}(undef, (k, shp[2]))
168+
v = MtlMatrix{ftype}(undef, (k, shp[2]))
168169

169-
i, v = MPS.topk!(a, i, v, k)
170+
i, v = MPS.topk!(a, i, v, k)
170171

171-
@test Array(i) == cpu_i
172-
@test Array(v) == cpu_v
173-
end
174-
shp = (20,30)
175-
k = 17
172+
@test Array(i) == cpu_i
173+
@test Array(v) == cpu_v
174+
end
176175

176+
# test too big `k`
177+
shp = (20,30)
178+
k = 17
179+
@testset "$shp, k=$k" begin
177180
cpu_a = rand(ftype, shp...)
178181
cpu_i, cpu_v = cpu_topk(cpu_a, k)
179182

@@ -185,7 +188,6 @@ end
185188
v = MtlMatrix{ftype}(undef, (k, shp[2]))
186189

187190
@test_throws "MPSMatrixFindTopK does not support values of k > 16" i, v = MPS.topk!(a, i, v, k)
188-
189191
end
190192
end
191193
end

test/runtests.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,27 @@ for (rootpath, dirs, files) in walkdir(@__DIR__)
7373
test_runners[file] = ()->include("$(@__DIR__)/$file.jl")
7474
end
7575
end
76+
7677
## GPUArrays testsuite
78+
const gpuarr_eltypes = [Int16, Int32, Int64,
79+
Complex{Int16}, Complex{Int32}, Complex{Int64},
80+
Float16, Float32,
81+
ComplexF16, ComplexF32]
82+
const gpuarr_eltypes_nobf16 = copy(gpuarr_eltypes)
83+
84+
# Add BFloat16 for tests that use it
85+
Metal.metal_support() >= v"3.1" && push!(gpuarr_eltypes, BFloat16)
86+
7787
for name in keys(TestSuite.tests)
7888
if Metal.DefaultStorageMode != Metal.PrivateStorage && name == "indexing scalar"
7989
# GPUArrays' scalar indexing tests assume that indexing is not supported
8090
continue
8191
end
92+
93+
tmp_eltypes = name in ["random"] ? gpuarr_eltypes_nobf16 : gpuarr_eltypes
94+
8295
push!(tests, "gpuarrays$(Base.Filesystem.path_separator)$name")
83-
test_runners["gpuarrays$(Base.Filesystem.path_separator)$name"] = ()->TestSuite.tests[name](MtlArray)
96+
test_runners["gpuarrays$(Base.Filesystem.path_separator)$name"] = ()->TestSuite.tests[name](MtlArray;eltypes=tmp_eltypes)
8497
end
8598
unique!(tests)
8699

test/setup.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Distributed, Test, Metal, Adapt, ObjectiveC, ObjectiveC.Foundation
1+
using Distributed, Test, Metal, BFloat16s, Adapt, ObjectiveC, ObjectiveC.Foundation
22

33
Metal.functional() || error("Metal.jl is not functional on this system")
44

@@ -10,12 +10,6 @@ gpuarrays_root = dirname(dirname(gpuarrays))
1010
include(joinpath(gpuarrays_root, "test", "testsuite.jl"))
1111
testf(f, xs...; kwargs...) = TestSuite.compare(f, MtlArray, xs...; kwargs...)
1212

13-
const eltypes = [Int16, Int32, Int64,
14-
Complex{Int16}, Complex{Int32}, Complex{Int64},
15-
Float16, Float32,
16-
ComplexF16, ComplexF32]
17-
TestSuite.supported_eltypes(::Type{<:MtlArray}) = eltypes
18-
1913
const runtime_validation = get(ENV, "MTL_DEBUG_LAYER", "0") != "0"
2014

2115
using Random
@@ -31,7 +25,7 @@ function runtests(f, name)
3125
# generate a temporary module to execute the tests in
3226
mod_name = Symbol("Test", rand(1:100), "Main_", replace(name, '/' => '_'))
3327
mod = @eval(Main, module $mod_name end)
34-
@eval(mod, using Test, Random, Metal)
28+
@eval(mod, using Test, Random, Metal, BFloat16s)
3529

3630
let id = myid()
3731
wait(@spawnat 1 print_testworker_started(name, id))

0 commit comments

Comments
 (0)