@@ -37,15 +37,29 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
3737 return ss;
3838}
3939
40- bool validateConfig (const MatmulConfig &cfg) {
40+ bool validateConfig (const MatmulConfig &cfg, ArrayRef<uint32_t > shape,
41+ bool allowIndivisibleInnerblock, bool isVNNIMM2D) {
4142 if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
4243 cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
4344 cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
4445 cfg.innerMostKBlock <= 0 )
4546 return false ;
4647 if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
47- cfg.NBlock % cfg.innerMostNBlock != 0 ||
48- cfg.KBlock % cfg.innerMostKBlock != 0 )
48+ (shape[0 ] % cfg.innerMostMBlock != 0 && !allowIndivisibleInnerblock))
49+ return false ;
50+ if (cfg.NBlock % cfg.innerMostNBlock != 0 ||
51+ ((shape[1 ] % cfg.innerMostNBlock != 0 ) && !allowIndivisibleInnerblock) ||
52+ (shape[1 ] % cfg.NThreads != 0 && isVNNIMM2D &&
53+ cfg.NBlock != cfg.innerMostNBlock ))
54+ return false ;
55+ // Require K % KBlock == 0 as brgemm dynamic bs is not supported now
56+ if (cfg.KBlock % cfg.innerMostKBlock != 0 ||
57+ ((shape[2 ] / cfg.KThreads % cfg.KBlock != 0 ||
58+ shape[2 ] / cfg.KThreads % cfg.innerMostKBlock != 0 ) &&
59+ !allowIndivisibleInnerblock))
60+ return false ;
61+ // KThreads will not shrink automatically
62+ if (llvm::divideCeil (shape[2 ], cfg.KBlock ) < cfg.KThreads )
4963 return false ;
5064 return true ;
5165}
@@ -179,21 +193,22 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp,
179193 ArrayRef<uint32_t > shape,
180194 const MatmulConfig &config,
181195 CPUTargetDescriptionAnalysis &sysDesc) {
182- assert (validateConfig (config) && " config is invalid" );
183196 assert (shape.size () >= 3 && " shape.size() should >= 3" );
184197 uint32_t M = shape[0 ], N = shape[1 ];
185198 double cost = 0 ;
186199 uint32_t MNumBlockPerThread =
187200 llvm::divideCeil (M / config.innerMostMBlock , config.MThreads );
188201 uint32_t MNumInnerBlockPerBlock =
189202 llvm::divideCeil (config.MBlock , config.innerMostMBlock );
203+ assert (MNumInnerBlockPerBlock > 0 && " Invalid MNumInnerBlockPerBlock." );
190204 uint32_t MCost = MNumBlockPerThread % MNumInnerBlockPerBlock != 0 ||
191205 (M / config.innerMostNBlock % config.MThreads != 0 &&
192206 config.MBlock != config.innerMostMBlock );
193207 uint32_t NNumBlockPerThread =
194208 llvm::divideCeil (N / config.innerMostNBlock , config.NThreads );
195209 uint32_t NNumInnerBlockPerBlock =
196210 llvm::divideCeil (config.NBlock , config.innerMostNBlock );
211+ assert (NNumInnerBlockPerBlock > 0 && " Invalid NNumInnerBlockPerBlock." );
197212 uint32_t NCost = NNumBlockPerThread % NNumInnerBlockPerBlock != 0 ||
198213 (N / config.innerMostNBlock % config.NThreads != 0 &&
199214 config.NBlock != config.innerMostNBlock );
@@ -312,39 +327,28 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
312327 KBlockCandidates = innerMostKBlockCandidates;
313328 }
314329
315- // TODO: improve via multi threading or add more constraints to restrict the
316- // candidate size
330+ bool isVNNIMM2D =
331+ linalgx::isGenericPackedMatmulOp (root, linalgx::PackingType::VNNI_MM2D);
332+ // TODO: improve via multi threading or add more constraints to restrict
333+ // the candidate size
317334 for (uint32_t MThreads : MThreadsCandidates) {
318335 for (uint32_t NThreads : NThreadsCandidates) {
319336 for (uint32_t KThreads : KThreadsCandidates) {
320337 if (!validateThreads ({MThreads, NThreads, KThreads}, sysDesc))
321338 continue ;
322339 for (uint32_t MBlock : MBlockCandidates) {
323340 for (uint32_t innerMostMBlock : innerMostMBlockCandidates) {
324- if (MBlock % innerMostMBlock != 0 ||
325- (shape[0 ] % innerMostMBlock != 0 &&
326- !allowIndivisibleInnerblock))
327- continue ;
328341 for (uint32_t NBlock : NBlockCandidates) {
329342 for (uint32_t innerMostNBlock : innerMostNBlockCandidates) {
330- if (NBlock % innerMostNBlock != 0 ||
331- (shape[1 ] % innerMostNBlock != 0 &&
332- !allowIndivisibleInnerblock))
333- continue ;
334343 for (uint32_t KBlock : KBlockCandidates) {
335344 for (uint32_t innerMostKBlock : innerMostKBlockCandidates) {
336- // Require K % KBlock == 0 as dynamic bs is not supported
337- // now
338- if (KBlock % innerMostKBlock != 0 ||
339- ((shape[2 ] / KThreads % KBlock != 0 ||
340- shape[2 ] / KThreads % innerMostKBlock != 0 ) &&
341- !allowIndivisibleInnerblock))
342- continue ;
343345 MatmulConfig config{
344346 MThreads, NThreads, KThreads,
345347 MBlock, NBlock, KBlock,
346348 innerMostMBlock, innerMostNBlock, innerMostKBlock};
347- configs.push_back (config);
349+ if (validateConfig (config, shape,
350+ allowIndivisibleInnerblock, isVNNIMM2D))
351+ configs.push_back (config);
348352 }
349353 }
350354 }
@@ -393,12 +397,28 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
393397 cfgItemCnt++;
394398 }
395399 }
396- if (validateConfig (config)) {
397- return cfgItemCnt == 9 ;
398- } else {
399- LLVM_DEBUG (llvm::dbgs () << " The predefined config is invalid\n " );
400+ return cfgItemCnt == 9 ;
401+ }
402+
403+ bool readAndValidateConfig (MatmulConfig &config,
404+ const linalg::LinalgOp &linalgOp,
405+ ArrayRef<uint32_t > shape,
406+ bool allowIndivisibleInnerBlock) {
407+ SmallVector<NamedAttribute> attrs (linalgOp->getAttrs ());
408+ bool fullConfig = readConfigFromAttrs (config, attrs);
409+ if (!fullConfig) {
410+ LLVM_DEBUG (llvm::dbgs () << " Missing fields in predefined config.\n " );
400411 return false ;
401412 }
413+ bool validConfig =
414+ validateConfig (config, shape, allowIndivisibleInnerBlock,
415+ linalgx::isGenericPackedMatmulOp (
416+ linalgOp, linalgx::PackingType::VNNI_MM2D));
417+ if (!validConfig) {
418+ LLVM_DEBUG (llvm::dbgs () << " Invalid predefined config.\n " );
419+ return false ;
420+ }
421+ return true ;
402422}
403423
404424// Analyze the workload and system description to generate the default config
@@ -482,12 +502,15 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
482502 << " M: " << M << " , N: " << N << " , K: " << K << " \n " );
483503
484504 // try to read the config from the attributes
485- SmallVector<NamedAttribute> attrs (linalgOp->getAttrs ());
486- bool hasPredefinedConfig = readConfigFromAttrs (config, attrs);
505+ bool hasValidPredefinedConfig = readAndValidateConfig (
506+ config, linalgOp, SmallVector<uint32_t >{M, N, K},
507+ allowIndivisibleInnerBlock);
487508
488509 // if there is a given config, skip the cost model
489- if (!hasPredefinedConfig) {
490- LLVM_DEBUG (llvm::dbgs () << " No predefined config\n " );
510+ if (!hasValidPredefinedConfig) {
511+ LLVM_DEBUG (
512+ llvm::dbgs ()
513+ << " No valid predefined config. Setting with default config.\n " );
491514 // TODO: Could add a weight or priority for cost model
492515 SmallVector<std::tuple<CostModelFn, std::string, double >>
493516 costModelList = {
@@ -511,6 +534,11 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
511534 }
512535 if (!configCandidates.empty ())
513536 config = configCandidates[0 ];
537+
538+ assert (validateConfig (config, shape, allowIndivisibleInnerBlock,
539+ linalgx::isGenericPackedMatmulOp (
540+ root, linalgx::PackingType::VNNI_MM2D)) &&
541+ " config is invalid" );
514542 }
515543
516544 LLVM_DEBUG (llvm::dbgs ()
@@ -520,7 +548,6 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
520548 hasConfig = true ;
521549 }
522550
523- assert (validateConfig (config) && " config is invalid" );
524551 return config;
525552}
526553} // namespace gc
0 commit comments