@@ -173,3 +173,79 @@ TEST_CASE("Test gemm generation (1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128],ld
173173 generatorTest.SetKernel (kernel);
174174 generatorTest.RunTest (lda, ldb, ldc, lda * K, ldb * N);
175175}
176+
177+ TEST_CASE (" Test gemm generation (1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128], 1≤BatchSize≤16, lda=M, ldb=K, and ldc=M) on random data" ,
178+ " [generation][correctness][gemm]" )
179+ {
180+ auto M = GENERATE (range (1u , 64u + 1u , 1u ));
181+ auto N = GENERATE (range (1u , 64u + 1u , 1u ));
182+ auto K = GENERATE (1u , 16u , 32u , 64u , 128u );
183+ auto BatchSize = GENERATE (range (1u , 16u + 1u , 1u ));
184+
185+ CAPTURE (M, N, K, BatchSize);
186+
187+ mini_jit::Brgemm gemm;
188+ mini_jit::Brgemm::error_t error = gemm.generate (M, N, K, BatchSize, 0 , 0 , 0 , mini_jit::Brgemm::dtype_t ::fp32);
189+
190+ switch (error)
191+ {
192+ case mini_jit::Brgemm::error_t ::success:
193+ break ;
194+ case mini_jit::Brgemm::error_t ::err_batch_reduce_size_not_supported:
195+ FAIL (" Error batch reduce size not supported." );
196+ break ;
197+ case mini_jit::Brgemm::error_t ::err_row_major_order_not_supported:
198+ FAIL (" Error row major order not supported." );
199+ break ;
200+ case mini_jit::Brgemm::error_t ::err_wrong_dimension:
201+ FAIL (" Error err wrong dimension." );
202+ break ;
203+ case mini_jit::Brgemm::error_t ::err_wrong_dtype:
204+ FAIL (" Error wrong dtype." );
205+ break ;
206+ default :
207+ FAIL (" Found unprocessed error type" );
208+ break ;
209+ }
210+
211+ mini_jit::Brgemm::kernel_t kernel = gemm.get_kernel ();
212+ REQUIRE (kernel != nullptr );
213+ }
214+
215+ TEST_CASE (" Test gemm generation (1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128], 1≤BatchSize≤16, lda=M, ldb=K, and ldc=M) on counting data" ,
216+ " [generation][correctness][gemm]" )
217+ {
218+ auto M = GENERATE (range (1u , 64u + 1u , 1u ));
219+ auto N = GENERATE (range (1u , 64u + 1u , 1u ));
220+ auto K = GENERATE (1u , 16u , 32u , 64u , 128u );
221+ auto BatchSize = GENERATE (range (1u , 16u + 1u , 1u ));
222+
223+ CAPTURE (M, N, K, BatchSize);
224+
225+ mini_jit::Brgemm gemm;
226+ mini_jit::Brgemm::error_t error = gemm.generate (M, N, K, BatchSize, 0 , 0 , 0 , mini_jit::Brgemm::dtype_t ::fp32);
227+
228+ switch (error)
229+ {
230+ case mini_jit::Brgemm::error_t ::success:
231+ break ;
232+ case mini_jit::Brgemm::error_t ::err_batch_reduce_size_not_supported:
233+ FAIL (" Error batch reduce size not supported." );
234+ break ;
235+ case mini_jit::Brgemm::error_t ::err_row_major_order_not_supported:
236+ FAIL (" Error row major order not supported." );
237+ break ;
238+ case mini_jit::Brgemm::error_t ::err_wrong_dimension:
239+ FAIL (" Error err wrong dimension." );
240+ break ;
241+ case mini_jit::Brgemm::error_t ::err_wrong_dtype:
242+ FAIL (" Error wrong dtype." );
243+ break ;
244+ default :
245+ FAIL (" Found unprocessed error type" );
246+ break ;
247+ }
248+
249+ mini_jit::Brgemm::kernel_t kernel = gemm.get_kernel ();
250+ REQUIRE (kernel != nullptr );
251+ }
0 commit comments