@@ -37,22 +37,29 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
3737 return ss;
3838}
3939
40- bool validateConfig (const MatmulConfig &cfg, ArrayRef<uint32_t > shape) {
40+ bool validateConfig (const MatmulConfig &cfg, ArrayRef<uint32_t > shape,
41+ bool allowIndivisibleInnerblock, bool isVNNIMM2D) {
4142 if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
4243 cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
4344 cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
4445 cfg.innerMostKBlock <= 0 )
4546 return false ;
4647 if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
47- cfg.NBlock % cfg.innerMostNBlock != 0 ||
48- cfg.KBlock % cfg.innerMostKBlock != 0 )
48+ (shape[0 ] % cfg.innerMostMBlock != 0 && !allowIndivisibleInnerblock))
49+ return false ;
50+ if (cfg.NBlock % cfg.innerMostNBlock != 0 ||
51+ ((shape[1 ] % cfg.innerMostNBlock != 0 ) && !allowIndivisibleInnerblock) ||
52+ (shape[1 ] % cfg.NThreads != 0 && isVNNIMM2D &&
53+ cfg.NBlock != cfg.innerMostNBlock ))
54+ return false ;
55+ if (cfg.KBlock % cfg.innerMostKBlock != 0 ||
56+ ((shape[2 ] / cfg.KThreads % cfg.KBlock != 0 ||
57+ shape[2 ] / cfg.KThreads % cfg.innerMostKBlock != 0 ) &&
58+ !allowIndivisibleInnerblock))
59+ return false ;
60+ // KThreads will not shrink automatically
61+ if (llvm::divideCeil (shape[2 ], cfg.KBlock ) < cfg.KThreads )
4962 return false ;
50- if (!shape.empty ()) {
51- // KThreads will not shrink automatically
52- // K is shape[2]
53- if (llvm::divideCeil (shape[2 ], cfg.KBlock ) < cfg.KThreads )
54- return false ;
55- }
5663 return true ;
5764}
5865
@@ -185,7 +192,6 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp,
185192 ArrayRef<uint32_t > shape,
186193 const MatmulConfig &config,
187194 CPUTargetDescriptionAnalysis &sysDesc) {
188- assert (validateConfig (config, shape) && " config is invalid" );
189195 assert (shape.size () >= 3 && " shape.size() should >= 3" );
190196 uint32_t M = shape[0 ], N = shape[1 ];
191197 double cost = 0 ;
@@ -367,8 +373,7 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
367373}
368374
369375// read the config from the attributes for tuning
370- bool readConfigFromAttrs (MatmulConfig &config, ArrayRef<NamedAttribute> attrs,
371- ArrayRef<uint32_t > shape) {
376+ bool readConfigFromAttrs (MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
372377 size_t cfgItemCnt = 0 ;
373378 for (const auto &attr : attrs) {
374379 if (attr.getName () == " KBlock" ) {
@@ -400,17 +405,28 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs,
400405 cfgItemCnt++;
401406 }
402407 }
403- if (cfgItemCnt != 9 ) {
404- LLVM_DEBUG (llvm::dbgs () << " The predefined matmul config is incomplete. "
405- " Default matmul config will be set.\n " );
408+ return cfgItemCnt == 9 ;
409+ }
410+
411+ bool readAndValidateConfig (MatmulConfig &config,
412+ const linalg::LinalgOp &linalgOp,
413+ ArrayRef<uint32_t > shape,
414+ bool allowIndivisibleInnerBlock) {
415+ SmallVector<NamedAttribute> attrs (linalgOp->getAttrs ());
416+ bool fullConfig = readConfigFromAttrs (config, attrs);
417+ if (!fullConfig) {
418+ LLVM_DEBUG (llvm::dbgs () << " Missing fields in predefined config.\n " );
406419 return false ;
407420 }
408- if (validateConfig (config, shape))
409- return true ;
410- else {
411- assert (0 && " config is invalid" );
421+ bool validConfig =
422+ validateConfig (config, shape, allowIndivisibleInnerBlock,
423+ linalgx::isGenericPackedMatmulOp (
424+ linalgOp, linalgx::PackingType::VNNI_MM2D));
425+ if (!validConfig) {
426+ LLVM_DEBUG (llvm::dbgs () << " Invalid predefined config.\n " );
412427 return false ;
413428 }
429+ return true ;
414430}
415431
416432// Analyze the workload and system description to generate the default config
@@ -494,13 +510,15 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
494510 << " M: " << M << " , N: " << N << " , K: " << K << " \n " );
495511
496512 // try to read the config from the attributes
497- SmallVector<NamedAttribute> attrs (linalgOp-> getAttrs ());
498- bool hasPredefinedConfig =
499- readConfigFromAttrs (config, attrs, SmallVector< uint32_t >{M, N, K} );
513+ bool hasValidPredefinedConfig = readAndValidateConfig (
514+ config, linalgOp, SmallVector< uint32_t >{M, N, K},
515+ allowIndivisibleInnerBlock );
500516
501517 // if there is a given config, skip the cost model
502- if (!hasPredefinedConfig) {
503- LLVM_DEBUG (llvm::dbgs () << " No predefined config\n " );
518+ if (!hasValidPredefinedConfig) {
519+ LLVM_DEBUG (
520+ llvm::dbgs ()
521+ << " No valid predefined config. Setting with default config.\n " );
504522 // TODO: Could add a weight or priority for cost model
505523 SmallVector<std::tuple<CostModelFn, std::string, double >>
506524 costModelList = {
@@ -525,7 +543,10 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
525543 if (!configCandidates.empty ())
526544 config = configCandidates[0 ];
527545
528- assert (validateConfig (config, shape) && " config is invalid" );
546+ assert (validateConfig (config, shape, allowIndivisibleInnerBlock,
547+ linalgx::isGenericPackedMatmulOp (
548+ root, linalgx::PackingType::VNNI_MM2D)) &&
549+ " config is invalid" );
529550 }
530551
531552 LLVM_DEBUG (llvm::dbgs ()
0 commit comments