@@ -52,6 +52,7 @@ bool validateConfig(const MatmulConfig &cfg, ArrayRef<uint32_t> shape,
5252 (shape[1 ] % cfg.NThreads != 0 && isVNNIMM2D &&
5353 cfg.NBlock != cfg.innerMostNBlock ))
5454 return false ;
55+ // Require K % KBlock == 0 as brgemm dynamic bs is not supported now
5556 if (cfg.KBlock % cfg.innerMostKBlock != 0 ||
5657 ((shape[2 ] / cfg.KThreads % cfg.KBlock != 0 ||
5758 shape[2 ] / cfg.KThreads % cfg.innerMostKBlock != 0 ) &&
@@ -199,13 +200,15 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp,
199200 llvm::divideCeil (M / config.innerMostMBlock , config.MThreads );
200201 uint32_t MNumInnerBlockPerBlock =
201202 llvm::divideCeil (config.MBlock , config.innerMostMBlock );
203+ assert (MNumInnerBlockPerBlock > 0 && " Invalid MNumInnerBlockPerBlock." );
202204 uint32_t MCost = MNumBlockPerThread % MNumInnerBlockPerBlock != 0 ||
203205 (M / config.innerMostNBlock % config.MThreads != 0 &&
204206 config.MBlock != config.innerMostMBlock );
205207 uint32_t NNumBlockPerThread =
206208 llvm::divideCeil (N / config.innerMostNBlock , config.NThreads );
207209 uint32_t NNumInnerBlockPerBlock =
208210 llvm::divideCeil (config.NBlock , config.innerMostNBlock );
211+ assert (NNumInnerBlockPerBlock > 0 && " Invalid NNumInnerBlockPerBlock." );
209212 uint32_t NCost = NNumBlockPerThread % NNumInnerBlockPerBlock != 0 ||
210213 (N / config.innerMostNBlock % config.NThreads != 0 &&
211214 config.NBlock != config.innerMostNBlock );
@@ -324,39 +327,28 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
324327 KBlockCandidates = innerMostKBlockCandidates;
325328 }
326329
327- // TODO: improve via multi threading or add more constraints to restrict the
328- // candidate size
330+ bool isVNNIMM2D =
331+ linalgx::isGenericPackedMatmulOp (root, linalgx::PackingType::VNNI_MM2D);
332+ // TODO: improve via multi threading or add more constraints to restrict
333+ // the candidate size
329334 for (uint32_t MThreads : MThreadsCandidates) {
330335 for (uint32_t NThreads : NThreadsCandidates) {
331336 for (uint32_t KThreads : KThreadsCandidates) {
332337 if (!validateThreads ({MThreads, NThreads, KThreads}, sysDesc))
333338 continue ;
334339 for (uint32_t MBlock : MBlockCandidates) {
335340 for (uint32_t innerMostMBlock : innerMostMBlockCandidates) {
336- if (MBlock % innerMostMBlock != 0 ||
337- (shape[0 ] % innerMostMBlock != 0 &&
338- !allowIndivisibleInnerblock))
339- continue ;
340341 for (uint32_t NBlock : NBlockCandidates) {
341342 for (uint32_t innerMostNBlock : innerMostNBlockCandidates) {
342- if (NBlock % innerMostNBlock != 0 ||
343- (shape[1 ] % innerMostNBlock != 0 &&
344- !allowIndivisibleInnerblock))
345- continue ;
346343 for (uint32_t KBlock : KBlockCandidates) {
347344 for (uint32_t innerMostKBlock : innerMostKBlockCandidates) {
348- // Require K % KBlock == 0 as dynamic bs is not supported
349- // now
350- if (KBlock % innerMostKBlock != 0 ||
351- ((shape[2 ] / KThreads % KBlock != 0 ||
352- shape[2 ] / KThreads % innerMostKBlock != 0 ) &&
353- !allowIndivisibleInnerblock))
354- continue ;
355345 MatmulConfig config{
356346 MThreads, NThreads, KThreads,
357347 MBlock, NBlock, KBlock,
358348 innerMostMBlock, innerMostNBlock, innerMostKBlock};
359- configs.push_back (config);
349+ if (validateConfig (config, shape,
350+ allowIndivisibleInnerblock, isVNNIMM2D))
351+ configs.push_back (config);
360352 }
361353 }
362354 }
0 commit comments