@@ -147,33 +147,36 @@ function cpu_topk(x::Matrix{T}, k; rev=true, dims=1) where {T}
147147end
148148
149149@testset " topk & topk!" begin
150- for ftype in (Float16, Float32)
150+ ftypes = [Float16, Float32]
151+
152+ @testset " $ftype " for ftype in ftypes
151153 # Normal operation
152- @testset " $ftype " begin
153- for (shp,k) in [((3 ,1 ), 2 ), ((20 ,30 ), 5 )]
154- cpu_a = rand (ftype, shp... )
154+ @testset " $shp , k=$k " for (shp,k) in [((3 ,1 ), 2 ), ((20 ,30 ), 5 )]
155+ cpu_a = rand (ftype, shp... )
155156
156- # topk
157- cpu_i, cpu_v = cpu_topk (cpu_a, k)
157+ # topk
158+ cpu_i, cpu_v = cpu_topk (cpu_a, k)
158159
159- a = MtlMatrix (cpu_a)
160- i, v = MPS. topk (a, k)
160+ a = MtlMatrix (cpu_a)
161+ i, v = MPS. topk (a, k)
161162
162- @test Array (i) == cpu_i
163- @test Array (v) == cpu_v
163+ @test Array (i) == cpu_i
164+ @test Array (v) == cpu_v
164165
165- # topk!
166- i = MtlMatrix {UInt32} (undef, (k, shp[2 ]))
167- v = MtlMatrix {ftype} (undef, (k, shp[2 ]))
166+ # topk!
167+ i = MtlMatrix {UInt32} (undef, (k, shp[2 ]))
168+ v = MtlMatrix {ftype} (undef, (k, shp[2 ]))
168169
169- i, v = MPS. topk! (a, i, v, k)
170+ i, v = MPS. topk! (a, i, v, k)
170171
171- @test Array (i) == cpu_i
172- @test Array (v) == cpu_v
173- end
174- shp = (20 ,30 )
175- k = 17
172+ @test Array (i) == cpu_i
173+ @test Array (v) == cpu_v
174+ end
176175
176+ # test too big `k`
177+ shp = (20 ,30 )
178+ k = 17
179+ @testset " $shp , k=$k " begin
177180 cpu_a = rand (ftype, shp... )
178181 cpu_i, cpu_v = cpu_topk (cpu_a, k)
179182
185188 v = MtlMatrix {ftype} (undef, (k, shp[2 ]))
186189
187190 @test_throws " MPSMatrixFindTopK does not support values of k > 16" i, v = MPS. topk! (a, i, v, k)
188-
189191 end
190192 end
191193end
0 commit comments