@@ -37,7 +37,7 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
3737 return ss;
3838}
3939
40- bool validateConfig (const MatmulConfig &cfg) {
40+ bool validateConfig (const MatmulConfig &cfg, ArrayRef< uint32_t > shape ) {
4141 if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
4242 cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
4343 cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
@@ -47,6 +47,12 @@ bool validateConfig(const MatmulConfig &cfg) {
4747 cfg.NBlock % cfg.innerMostNBlock != 0 ||
4848 cfg.KBlock % cfg.innerMostKBlock != 0 )
4949 return false ;
50+ if (!shape.empty ()) {
51+ // KThreads will not shrink automatically
52+ // K is shape[2]
53+ if (llvm::divideCeil (shape[2 ], cfg.KBlock ) < cfg.KThreads )
54+ return false ;
55+ }
5056 return true ;
5157}
5258
@@ -179,7 +185,7 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp,
179185 ArrayRef<uint32_t > shape,
180186 const MatmulConfig &config,
181187 CPUTargetDescriptionAnalysis &sysDesc) {
182- assert (validateConfig (config) && " config is invalid" );
188+ assert (validateConfig (config, shape ) && " config is invalid" );
183189 assert (shape.size () >= 3 && " shape.size() should >= 3" );
184190 uint32_t M = shape[0 ], N = shape[1 ];
185191 double cost = 0 ;
@@ -361,7 +367,8 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
361367}
362368
363369// read the config from the attributes for tuning
364- bool readConfigFromAttrs (MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
370+ bool readConfigFromAttrs (MatmulConfig &config, ArrayRef<NamedAttribute> attrs,
371+ ArrayRef<uint32_t > shape) {
365372 size_t cfgItemCnt = 0 ;
366373 for (const auto &attr : attrs) {
367374 if (attr.getName () == " KBlock" ) {
@@ -393,10 +400,15 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
393400 cfgItemCnt++;
394401 }
395402 }
396- if (validateConfig (config)) {
397- return cfgItemCnt == 9 ;
398- } else {
399- LLVM_DEBUG (llvm::dbgs () << " The predefined config is invalid\n " );
403+ if (cfgItemCnt != 9 ) {
404+ LLVM_DEBUG (llvm::dbgs () << " The predefined matmul config is incomplete. "
405+ " Default matmul config will be set.\n " );
406+ return false ;
407+ }
408+ if (validateConfig (config, shape))
409+ return true ;
410+ else {
411+ assert (0 && " config is invalid" );
400412 return false ;
401413 }
402414}
@@ -483,7 +495,8 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
483495
484496 // try to read the config from the attributes
485497 SmallVector<NamedAttribute> attrs (linalgOp->getAttrs ());
486- bool hasPredefinedConfig = readConfigFromAttrs (config, attrs);
498+ bool hasPredefinedConfig =
499+ readConfigFromAttrs (config, attrs, SmallVector<uint32_t >{M, N, K});
487500
488501 // if there is a given config, skip the cost model
489502 if (!hasPredefinedConfig) {
@@ -511,6 +524,8 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
511524 }
512525 if (!configCandidates.empty ())
513526 config = configCandidates[0 ];
527+
528+ assert (validateConfig (config, shape) && " config is invalid" );
514529 }
515530
516531 LLVM_DEBUG (llvm::dbgs ()
@@ -520,7 +535,6 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
520535 hasConfig = true ;
521536 }
522537
523- assert (validateConfig (config) && " config is invalid" );
524538 return config;
525539}
526540} // namespace gc
0 commit comments