@@ -147,33 +147,36 @@ function cpu_topk(x::Matrix{T}, k; rev=true, dims=1) where {T}
147147end 
148148
149149@testset  " topk & topk!" begin 
150-     for  ftype in  (Float16, Float32)
150+     ftypes =  [Float16, Float32]
151+ 
152+     @testset  " $ftype " for  ftype in  ftypes
151153        #  Normal operation
152-         @testset  " $ftype " begin 
153-             for  (shp,k) in  [((3 ,1 ), 2 ), ((20 ,30 ), 5 )]
154-                 cpu_a =  rand (ftype, shp... )
154+         @testset  " $shp , k=$k " for  (shp,k) in  [((3 ,1 ), 2 ), ((20 ,30 ), 5 )]
155+             cpu_a =  rand (ftype, shp... )
155156
156-                  # topk
157-                  cpu_i, cpu_v =  cpu_topk (cpu_a, k)
157+             # topk
158+             cpu_i, cpu_v =  cpu_topk (cpu_a, k)
158159
159-                  a =  MtlMatrix (cpu_a)
160-                  i, v =  MPS. topk (a, k)
160+             a =  MtlMatrix (cpu_a)
161+             i, v =  MPS. topk (a, k)
161162
162-                  @test  Array (i) ==  cpu_i
163-                  @test  Array (v) ==  cpu_v
163+             @test  Array (i) ==  cpu_i
164+             @test  Array (v) ==  cpu_v
164165
165-                  # topk!
166-                  i =  MtlMatrix {UInt32} (undef, (k, shp[2 ]))
167-                  v =  MtlMatrix {ftype} (undef, (k, shp[2 ]))
166+             # topk!
167+             i =  MtlMatrix {UInt32} (undef, (k, shp[2 ]))
168+             v =  MtlMatrix {ftype} (undef, (k, shp[2 ]))
168169
169-                  i, v =  MPS. topk! (a, i, v, k)
170+             i, v =  MPS. topk! (a, i, v, k)
170171
171-                 @test  Array (i) ==  cpu_i
172-                 @test  Array (v) ==  cpu_v
173-             end 
174-             shp =  (20 ,30 )
175-             k =  17 
172+             @test  Array (i) ==  cpu_i
173+             @test  Array (v) ==  cpu_v
174+         end 
176175
176+         #  test too big `k`
177+         shp =  (20 ,30 )
178+         k =  17 
179+         @testset  " $shp , k=$k " begin 
177180            cpu_a =  rand (ftype, shp... )
178181            cpu_i, cpu_v =  cpu_topk (cpu_a, k)
179182
185188            v =  MtlMatrix {ftype} (undef, (k, shp[2 ]))
186189
187190            @test_throws  " MPSMatrixFindTopK does not support values of k > 16" =  MPS. topk! (a, i, v, k)
188- 
189191        end 
190192    end 
191193end 
0 commit comments