diff --git a/include/gc/Analysis/MatmulConfigAnalysis.h b/include/gc/Analysis/MatmulConfigAnalysis.h index 2b275f246..0b5d390dd 100644 --- a/include/gc/Analysis/MatmulConfigAnalysis.h +++ b/include/gc/Analysis/MatmulConfigAnalysis.h @@ -33,6 +33,9 @@ struct MatmulConfig { uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock; }; +bool validateConfig(const MatmulConfig &cfg, ArrayRef shape, + bool allowIndivisibleInnerblock, bool isVNNIMM2D); + enum DimType { Batch, M, N, K }; // Extract the index of the given DimType in the DimType list diff --git a/lib/gc/Analysis/MatmulConfigAnalysis.cpp b/lib/gc/Analysis/MatmulConfigAnalysis.cpp index e67cc9fe2..65758ee2e 100644 --- a/lib/gc/Analysis/MatmulConfigAnalysis.cpp +++ b/lib/gc/Analysis/MatmulConfigAnalysis.cpp @@ -37,15 +37,29 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, return ss; } -bool validateConfig(const MatmulConfig &cfg) { +bool validateConfig(const MatmulConfig &cfg, ArrayRef shape, + bool allowIndivisibleInnerblock, bool isVNNIMM2D) { if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 || cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 || cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 || cfg.innerMostKBlock <= 0) return false; if (cfg.MBlock % cfg.innerMostMBlock != 0 || - cfg.NBlock % cfg.innerMostNBlock != 0 || - cfg.KBlock % cfg.innerMostKBlock != 0) + (shape[0] % cfg.innerMostMBlock != 0 && !allowIndivisibleInnerblock)) + return false; + if (cfg.NBlock % cfg.innerMostNBlock != 0 || + ((shape[1] % cfg.innerMostNBlock != 0) && !allowIndivisibleInnerblock) || + (shape[1] % cfg.NThreads != 0 && isVNNIMM2D && + cfg.NBlock != cfg.innerMostNBlock)) + return false; + // Require K % KBlock == 0 as brgemm dynamic bs is not supported now + if (cfg.KBlock % cfg.innerMostKBlock != 0 || + ((shape[2] / cfg.KThreads % cfg.KBlock != 0 || + shape[2] / cfg.KThreads % cfg.innerMostKBlock != 0) && + !allowIndivisibleInnerblock)) + return false; + // KThreads will not shrink automatically + if (llvm::divideCeil(shape[2], cfg.KBlock) < cfg.KThreads) return false; return true; } @@ -179,7 +193,6 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp, ArrayRef shape, const MatmulConfig &config, CPUTargetDescriptionAnalysis &sysDesc) { - assert(validateConfig(config) && "config is invalid"); assert(shape.size() >= 3 && "shape.size() should >= 3"); uint32_t M = shape[0], N = shape[1]; double cost = 0; @@ -187,6 +200,7 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp, llvm::divideCeil(M / config.innerMostMBlock, config.MThreads); uint32_t MNumInnerBlockPerBlock = llvm::divideCeil(config.MBlock, config.innerMostMBlock); + assert(MNumInnerBlockPerBlock > 0 && "Invalid MNumInnerBlockPerBlock."); uint32_t MCost = MNumBlockPerThread % MNumInnerBlockPerBlock != 0 || (M / config.innerMostNBlock % config.MThreads != 0 && config.MBlock != config.innerMostMBlock); @@ -194,6 +208,7 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp, llvm::divideCeil(N / config.innerMostNBlock, config.NThreads); uint32_t NNumInnerBlockPerBlock = llvm::divideCeil(config.NBlock, config.innerMostNBlock); + assert(NNumInnerBlockPerBlock > 0 && "Invalid NNumInnerBlockPerBlock."); uint32_t NCost = NNumBlockPerThread % NNumInnerBlockPerBlock != 0 || (N / config.innerMostNBlock % config.NThreads != 0 && config.NBlock != config.innerMostNBlock); @@ -312,8 +327,10 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc, KBlockCandidates = innerMostKBlockCandidates; } - // TODO: improve via multi threading or add more constraints to restrict the - // candidate size + bool isVNNIMM2D = + linalgx::isGenericPackedMatmulOp(root, linalgx::PackingType::VNNI_MM2D); + // TODO: improve via multi threading or add more constraints to restrict + // the candidate size for (uint32_t MThreads : MThreadsCandidates) { for (uint32_t NThreads : NThreadsCandidates) { for (uint32_t KThreads : KThreadsCandidates) { @@ -321,30 +338,17 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc, continue; for (uint32_t MBlock : MBlockCandidates) { for (uint32_t innerMostMBlock : innerMostMBlockCandidates) { - if (MBlock % innerMostMBlock != 0 || - (shape[0] % innerMostMBlock != 0 && - !allowIndivisibleInnerblock)) - continue; for (uint32_t NBlock : NBlockCandidates) { for (uint32_t innerMostNBlock : innerMostNBlockCandidates) { - if (NBlock % innerMostNBlock != 0 || - (shape[1] % innerMostNBlock != 0 && - !allowIndivisibleInnerblock)) - continue; for (uint32_t KBlock : KBlockCandidates) { for (uint32_t innerMostKBlock : innerMostKBlockCandidates) { - // Require K % KBlock == 0 as dynamic bs is not supported - // now - if (KBlock % innerMostKBlock != 0 || - ((shape[2] / KThreads % KBlock != 0 || - shape[2] / KThreads % innerMostKBlock != 0) && - !allowIndivisibleInnerblock)) - continue; MatmulConfig config{ MThreads, NThreads, KThreads, MBlock, NBlock, KBlock, innerMostMBlock, innerMostNBlock, innerMostKBlock}; - configs.push_back(config); + if (validateConfig(config, shape, + allowIndivisibleInnerblock, isVNNIMM2D)) + configs.push_back(config); } } } @@ -393,12 +397,28 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef attrs) { cfgItemCnt++; } } - if (validateConfig(config)) { - return cfgItemCnt == 9; - } else { - LLVM_DEBUG(llvm::dbgs() << "The predefined config is invalid\n"); + return cfgItemCnt == 9; +} + +bool readAndValidateConfig(MatmulConfig &config, + const linalg::LinalgOp &linalgOp, + ArrayRef shape, + bool allowIndivisibleInnerBlock) { + SmallVector attrs(linalgOp->getAttrs()); + bool fullConfig = readConfigFromAttrs(config, attrs); + if (!fullConfig) { + LLVM_DEBUG(llvm::dbgs() << "Missing fields in predefined config.\n"); return false; } + bool validConfig = + validateConfig(config, shape, allowIndivisibleInnerBlock, + linalgx::isGenericPackedMatmulOp( + linalgOp, linalgx::PackingType::VNNI_MM2D)); + if (!validConfig) { + LLVM_DEBUG(llvm::dbgs() << "Invalid predefined config.\n"); + return false; + } + return true; } // Analyze the workload and system description to generate the default config @@ -482,12 +502,15 @@ MatmulConfig MatmulConfigAnalysis::getConfig() { << "M: " << M << ", N: " << N << ", K: " << K << "\n"); // try to read the config from the attributes - SmallVector attrs(linalgOp->getAttrs()); - bool hasPredefinedConfig = readConfigFromAttrs(config, attrs); + bool hasValidPredefinedConfig = readAndValidateConfig( + config, linalgOp, SmallVector{M, N, K}, + allowIndivisibleInnerBlock); // if there is a given config, skip the cost model - if (!hasPredefinedConfig) { - LLVM_DEBUG(llvm::dbgs() << "No predefined config\n"); + if (!hasValidPredefinedConfig) { + LLVM_DEBUG( + llvm::dbgs() + << "No valid predefined config. Setting with default config.\n"); // TODO: Could add a weight or priority for cost model SmallVector> costModelList = { @@ -511,6 +534,11 @@ MatmulConfig MatmulConfigAnalysis::getConfig() { } if (!configCandidates.empty()) config = configCandidates[0]; + + assert(validateConfig(config, shape, allowIndivisibleInnerBlock, + linalgx::isGenericPackedMatmulOp( + root, linalgx::PackingType::VNNI_MM2D)) && + "config is invalid"); } LLVM_DEBUG(llvm::dbgs() @@ -520,7 +548,6 @@ MatmulConfig MatmulConfigAnalysis::getConfig() { hasConfig = true; } - assert(validateConfig(config) && "config is invalid"); return config; } } // namespace gc