@@ -37,7 +37,7 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
3737 return ss;
3838}
3939
40- bool validateConfig (const MatmulConfig &cfg) {
40+ bool validateConfig (const MatmulConfig &cfg, ArrayRef< uint32_t > shape = {} ) {
4141 if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
4242 cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
4343 cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
@@ -47,6 +47,12 @@ bool validateConfig(const MatmulConfig &cfg) {
4747 cfg.NBlock % cfg.innerMostNBlock != 0 ||
4848 cfg.KBlock % cfg.innerMostKBlock != 0 )
4949 return false ;
50+ if (!shape.empty ()) {
51+ // KThreads will not shrink automatically
52+ // K is shape[2]
53+ if (llvm::divideCeil (shape[2 ], cfg.KBlock ) < cfg.KThreads )
54+ return false ;
55+ }
5056 return true ;
5157}
5258
@@ -179,7 +185,7 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp,
179185 ArrayRef<uint32_t > shape,
180186 const MatmulConfig &config,
181187 CPUTargetDescriptionAnalysis &sysDesc) {
182- assert (validateConfig (config) && " config is invalid" );
188+ assert (validateConfig (config, shape ) && " config is invalid" );
183189 assert (shape.size () >= 3 && " shape.size() should >= 3" );
184190 uint32_t M = shape[0 ], N = shape[1 ];
185191 double cost = 0 ;
@@ -361,7 +367,8 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
361367}
362368
363369// read the config from the attributes for tuning
364- bool readConfigFromAttrs (MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
370+ bool readConfigFromAttrs (MatmulConfig &config, ArrayRef<NamedAttribute> attrs,
371+ ArrayRef<uint32_t > shape) {
365372 size_t cfgItemCnt = 0 ;
366373 for (const auto &attr : attrs) {
367374 if (attr.getName () == " KBlock" ) {
@@ -393,7 +400,7 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
393400 cfgItemCnt++;
394401 }
395402 }
396- if (validateConfig (config)) {
403+ if (validateConfig (config, shape )) {
397404 return cfgItemCnt == 9 ;
398405 } else {
399406 LLVM_DEBUG (llvm::dbgs () << " The predefined config is invalid\n " );
@@ -483,7 +490,8 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
483490
484491 // try to read the config from the attributes
485492 SmallVector<NamedAttribute> attrs (linalgOp->getAttrs ());
486- bool hasPredefinedConfig = readConfigFromAttrs (config, attrs);
493+ bool hasPredefinedConfig =
494+ readConfigFromAttrs (config, attrs, SmallVector<uint32_t >{M, N, K});
487495
488496 // if there is a given config, skip the cost model
489497 if (!hasPredefinedConfig) {
@@ -520,7 +528,6 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
520528 hasConfig = true ;
521529 }
522530
523- assert (validateConfig (config) && " config is invalid" );
524531 return config;
525532}
526533} // namespace gc
0 commit comments