Skip to content

Commit 48bda9f

Browse files
committed
Add tests
1 parent 996f5b3 commit 48bda9f

File tree

5 files changed

+37
-21
lines changed

5 files changed

+37
-21
lines changed

test/array.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
STORAGEMODES = [Metal.PrivateStorage, Metal.SharedStorage, Metal.ManagedStorage]
22

3+
const FILL_TYPES = [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64,
4+
Float16, Float32]
5+
Metal.metal_support() >= v"3.1" && push!(FILL_TYPES, BFloat16)
6+
37
@testset "array" begin
48

59
let arr = MtlVector{Int}(undef, 1)
@@ -27,8 +31,7 @@ end
2731
@test mtl(1:3) === 1:3
2832

2933

30-
# Page 22 of https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
31-
# Only bfloat missing
34+
# Section 2.1 of https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
3235
supported_number_types = [Float16 => Float16,
3336
Float32 => Float32,
3437
Float64 => Float32,
@@ -41,6 +44,8 @@ end
4144
UInt32 => UInt32,
4245
UInt64 => UInt64,
4346
UInt8 => UInt8]
47+
Metal.metal_support() >= v"3.1" && push!(supported_number_types, BFloat16 => BFloat16)
48+
4449
# Test supported types and ensure only Float64 get converted to Float32
4550
for (SrcType, TargType) in supported_number_types
4651
@test mtl(SrcType[1]) isa MtlArray{TargType}
@@ -227,8 +232,7 @@ end
227232

228233
end
229234

230-
@testset "fill($T)" for T in [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64,
231-
Float16, Float32]
235+
@testset "fill($T)" for T in FILL_TYPES
232236
broken466a = T [Int8,UInt8]
233237
broken466b = (Base.JLOptions().check_bounds != 1 || shader_validation)
234238

@@ -267,8 +271,7 @@ end
267271
end
268272
end
269273

270-
@testset "fill!($T)" for T in [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64,
271-
Float16, Float32]
274+
@testset "fill!($T)" for T in FILL_TYPES
272275
broken466a = T [Int8,UInt8]
273276
broken466b = (Base.JLOptions().check_bounds != 1 || shader_validation)
274277

test/device/intrinsics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ end
276276

277277
@testset "parametrically typed" begin
278278
typs = [Int32, Int64, Float32]
279+
metal_support() >= v"3.1" && push!(types, BFloat16)
279280
@testset for typ in typs
280281
function kernel(d::MtlDeviceArray{T}, n) where {T}
281282
t = thread_position_in_threadgroup_1d()

test/mps/linalg.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,11 @@ end
140140
return perm, y
141141
end
142142
end
143-
@testset "$ftype" for ftype in (Float16, Float32)
143+
@testset "$ftype" ftypes = [Float16, Float32]
144+
145+
@testset "$ftype" for ftype in ftypes
144146
# Normal operation
145-
for (shp,k) in [((3,1), 2), ((20,30), 5)]
147+
@testset "$shp, k=$k" for (shp,k) in [((3,1), 2), ((20,30), 5)]
146148
cpu_a = rand(ftype, shp...)
147149

148150
#topk
@@ -163,11 +165,13 @@ end
163165
@test Array(i) == cpu_i
164166
@test Array(v) == cpu_v
165167
end
168+
169+
# test too big `k`
166170
shp = (20,30)
167171
k = 17
168-
169-
cpu_a = rand(ftype, shp...)
170-
cpu_i, cpu_v = cpu_topk(cpu_a, k)
172+
@testset "$shp, k=$k" begin
173+
cpu_a = rand(ftype, shp...)
174+
cpu_i, cpu_v = cpu_topk(cpu_a, k)
171175

172176
a = MtlMatrix(cpu_a)
173177
@test_throws "MPSMatrixFindTopK does not support values of k > 16" i, v = MPS.topk(a, k)
@@ -176,7 +180,8 @@ end
176180
i = MtlMatrix{UInt32}(undef, (k, shp[2]))
177181
v = MtlMatrix{ftype}(undef, (k, shp[2]))
178182

179-
@test_throws "MPSMatrixFindTopK does not support values of k > 16" i, v = MPS.topk!(a, i, v, k)
183+
@test_throws "MPSMatrixFindTopK does not support values of k > 16" i, v = MPS.topk!(a, i, v, k)
184+
end
180185
end
181186
end
182187

test/runtests.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,27 @@ for (rootpath, dirs, files) in walkdir(@__DIR__)
7373
test_runners[file] = ()->include("$(@__DIR__)/$file.jl")
7474
end
7575
end
76+
7677
## GPUArrays testsuite
78+
const gpuarr_eltypes = [Int16, Int32, Int64,
79+
Complex{Int16}, Complex{Int32}, Complex{Int64},
80+
Float16, Float32,
81+
ComplexF16, ComplexF32]
82+
const gpuarr_eltypes_nobf16 = copy(gpuarr_eltypes)
83+
84+
# Add BFloat16 for tests that use it
85+
Metal.metal_support() >= v"3.1" && push!(gpuarr_eltypes, BFloat16)
86+
7787
for name in keys(TestSuite.tests)
7888
if Metal.DefaultStorageMode != Metal.PrivateStorage && name == "indexing scalar"
7989
# GPUArrays' scalar indexing tests assume that indexing is not supported
8090
continue
8191
end
92+
93+
tmp_eltypes = name in ["random"] ? gpuarr_eltypes_nobf16 : gpuarr_eltypes
94+
8295
push!(tests, "gpuarrays$(Base.Filesystem.path_separator)$name")
83-
test_runners["gpuarrays$(Base.Filesystem.path_separator)$name"] = ()->TestSuite.tests[name](MtlArray)
96+
test_runners["gpuarrays$(Base.Filesystem.path_separator)$name"] = ()->TestSuite.tests[name](MtlArray;eltypes=tmp_eltypes)
8497
end
8598
unique!(tests)
8699

test/setup.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Distributed, Test, Metal, Adapt, ObjectiveC, ObjectiveC.Foundation
1+
using Distributed, Test, Metal, BFloat16s, Adapt, ObjectiveC, ObjectiveC.Foundation
22

33
Metal.functional() || error("Metal.jl is not functional on this system")
44

@@ -10,12 +10,6 @@ gpuarrays_root = dirname(dirname(gpuarrays))
1010
include(joinpath(gpuarrays_root, "test", "testsuite.jl"))
1111
testf(f, xs...; kwargs...) = TestSuite.compare(f, MtlArray, xs...; kwargs...)
1212

13-
const eltypes = [Int16, Int32, Int64,
14-
Complex{Int16}, Complex{Int32}, Complex{Int64},
15-
Float16, Float32,
16-
ComplexF16, ComplexF32]
17-
TestSuite.supported_eltypes(::Type{<:MtlArray}) = eltypes
18-
1913
const runtime_validation = get(ENV, "MTL_DEBUG_LAYER", "0") != "0"
2014
const shader_validation = get(ENV, "MTL_SHADER_VALIDATION", "0") != "0"
2115

@@ -32,7 +26,7 @@ function runtests(f, name)
3226
# generate a temporary module to execute the tests in
3327
mod_name = Symbol("Test", rand(1:100), "Main_", replace(name, '/' => '_'))
3428
mod = @eval(Main, module $mod_name end)
35-
@eval(mod, using Test, Random, Metal)
29+
@eval(mod, using Test, Random, Metal, BFloat16s)
3630

3731
let id = myid()
3832
wait(@spawnat 1 print_testworker_started(name, id))

0 commit comments

Comments
 (0)