3131#include " mlir/IR/TypeUtilities.h"
3232#include " mlir/Interfaces/FunctionInterfaces.h"
3333#include " mlir/Support/LogicalResult.h"
34+ #include " mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
35+ #include " mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
36+ #include " mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
37+ #include " mlir/Analysis/DataFlowFramework.h"
38+ #include " iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h"
39+ #include " llvm/ADT/DenseSet.h"
40+ #include " iree/compiler/Dialect/Util/IR/UtilOps.h"
3441
3542#define DEBUG_TYPE " iree-gpu-config-utils"
3643
@@ -653,7 +660,7 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
653660 ArrayRef<int64_t > bounds, ArrayRef<AffineMap> maps,
654661 ArrayRef<Value> operands, IREE::GPU::TargetAttr target, bool useDirectLoad,
655662 bool isGemm, bool scaled, int64_t splitReductionTripCnt,
656- bool cPromoteIfPadding, bool hasExistingAccumulator = false ,
663+ bool cPromoteIfPadding, bool boundsUsingAnalysis, bool hasExistingAccumulator = false ,
657664 std::optional<ConvToIgemmInfo> convToIgemmInfo = std::nullopt ) {
658665 if (target.getWgp ().getMma ().empty ()) {
659666 return failure ();
@@ -969,7 +976,7 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
969976 : ArrayRef<Attribute>{};
970977 GPU::appendPromotedOperandsList (context, attrs, promotionList,
971978 promotionTypes);
972- if (!mustBeAligned || couldNeedPadding) {
979+ if (!mustBeAligned || couldNeedPadding || boundsUsingAnalysis ) {
973980 SmallVector<int64_t > paddingTileSizes = workgroupTileSizes;
974981
975982 // Initialize inner and outer padding sizes from reductionTileSizes.
@@ -1085,7 +1092,7 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
10851092 igemmLoopBounds, igemmContractionMaps, igemmOperands, target,
10861093 useDirectLoad, /* isGemm=*/ false ,
10871094 /* scaled=*/ false , splitReductionTripCnt,
1088- /* cPromoteIfPadding=*/ cPromoteIfPadding, hasExistingAccumulator,
1095+ /* cPromoteIfPadding=*/ cPromoteIfPadding, /* boundsUsingAnalysis= */ false , hasExistingAccumulator,
10891096 convToIgemmInfo);
10901097 if (failed (configAndWgSize)) {
10911098 return failure ();
@@ -1112,6 +1119,122 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
11121119 workgroupSize, targetSubgroupSize, pipelineConfig);
11131120}
11141121
1122+
1123+ static FailureOr<SmallVector<int64_t >>
1124+ getLoopBoundsWithRangeAnalysis (linalg::LinalgOp linalgOp,
1125+ mlir::FunctionOpInterface entryPoint) {
1126+ // Initialize DataFlowSolver for integer range analysis.
1127+ DataFlowSolver solver;
1128+ solver.load <dataflow::DeadCodeAnalysis>();
1129+ solver.load <dataflow::SparseConstantPropagation>();
1130+ solver.load <dataflow::IntegerRangeAnalysis>();
1131+
1132+ if (failed (solver.initializeAndRun (entryPoint))) {
1133+ return linalgOp.getStaticLoopRanges ();
1134+ }
1135+
1136+ SmallVector<int64_t > bounds = linalgOp.getStaticLoopRanges ();
1137+ SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray ();
1138+
1139+ // Sentinel value used by IntegerRangeAnalysis when bounds are unknown.
1140+ constexpr int64_t unboundedSentinel = 9007199254740991 ;
1141+
1142+ // Helper to recursively collect index values from an operation.
1143+ // Uses a visited set instead of hardcoded depth limit.
1144+ std::function<void (Value, SmallVectorImpl<Value> &, DenseSet<Value> &)>
1145+ collectIndexValues = [&](Value value, SmallVectorImpl<Value> &indexValues,
1146+ DenseSet<Value> &visited) -> void {
1147+ // Use visited set to prevent infinite recursion.
1148+ if (!visited.insert (value).second )
1149+ return ;
1150+
1151+ if (value.getType ().isIndex ()) {
1152+ indexValues.push_back (value);
1153+ }
1154+
1155+ Operation *defOp = value.getDefiningOp ();
1156+ if (!defOp)
1157+ return ;
1158+
1159+ // Recursively traverse all operands.
1160+ for (Value operand : defOp->getOperands ()) {
1161+ if (operand.getType ().isIndex ()) {
1162+ indexValues.push_back (operand);
1163+ }
1164+ // Continue traversing for shaped types to find their dimension operands.
1165+ if (isa<ShapedType>(operand.getType ())) {
1166+ Operation *operandDef = operand.getDefiningOp ();
1167+ if (operandDef) {
1168+ for (Value v : operandDef->getOperands ()) {
1169+ if (v.getType ().isIndex ()) {
1170+ collectIndexValues (v, indexValues, visited);
1171+ }
1172+ }
1173+ }
1174+ }
1175+ }
1176+ };
1177+
1178+ for (auto [loopIdx, bound] : llvm::enumerate (bounds)) {
1179+ if (!ShapedType::isDynamic (bound)) {
1180+ continue ;
1181+ }
1182+
1183+ bool boundRefined = false ;
1184+
1185+ // Find operand and dimension that corresponds to this loop.
1186+ for (auto [operandIdx, operand] :
1187+ llvm::enumerate (linalgOp->getOperands ())) {
1188+ auto shapedType = dyn_cast<ShapedType>(operand.getType ());
1189+ if (!shapedType)
1190+ continue ;
1191+
1192+ AffineMap map = indexingMaps[operandIdx];
1193+ for (auto [dimIdx, expr] : llvm::enumerate (map.getResults ())) {
1194+ auto dimExpr = dyn_cast<AffineDimExpr>(expr);
1195+ if (!dimExpr || dimExpr.getPosition () != loopIdx)
1196+ continue ;
1197+ if (!ShapedType::isDynamic (shapedType.getDimSize (dimIdx)))
1198+ continue ;
1199+
1200+ // Collect all index values related to this operand by traversing use-def chain.
1201+ SmallVector<Value> indexValues;
1202+ DenseSet<Value> visited;
1203+ collectIndexValues (operand, indexValues, visited);
1204+
1205+ // Try each index value with getDynamicUpperBound.
1206+ for (Value indexValue : indexValues) {
1207+ FailureOr<int64_t > ub = getDynamicUpperBound (indexValue, solver);
1208+ if (succeeded (ub) && *ub > 0 ) {
1209+ // Filter out the unbounded sentinel.
1210+ if (*ub >= unboundedSentinel) {
1211+ continue ;
1212+ }
1213+
1214+ bounds[loopIdx] = *ub;
1215+ boundRefined = true ;
1216+ break ;
1217+ }
1218+ }
1219+
1220+ if (boundRefined)
1221+ break ;
1222+ }
1223+
1224+ if (boundRefined) {
1225+ break ;
1226+ }
1227+ }
1228+
1229+ // TODO: If we couldn't refine the bound, set it to the largest power of 2.
1230+ if (!boundRefined && ShapedType::isDynamic (bounds[loopIdx])) {
1231+ bounds[loopIdx] = 1 << 20 ;
1232+ }
1233+ }
1234+
1235+ return bounds;
1236+ }
1237+
11151238LogicalResult setMatmulLoweringConfig (IREE::GPU::TargetAttr target,
11161239 mlir::FunctionOpInterface entryPoint,
11171240 Operation *op, bool useDirectLoad) {
@@ -1122,7 +1245,20 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
11221245 return failure ();
11231246 }
11241247
1125- SmallVector<int64_t > bounds = linalgOp.getStaticLoopRanges ();
1248+ // SmallVector<int64_t> bounds = linalgOp.getStaticLoopRanges();
1249+ // Use IntegerRangeAnalysis to get better bounds for dynamic shapes
1250+ bool boundsUsingAnalysis = false ;
1251+ FailureOr<SmallVector<int64_t >> maybeBounds =
1252+ getLoopBoundsWithRangeAnalysis (linalgOp, entryPoint);
1253+ SmallVector<int64_t > bounds;
1254+ if (succeeded (maybeBounds)) {
1255+ boundsUsingAnalysis = true ;
1256+ bounds = std::move (*maybeBounds);
1257+ } else {
1258+ // Fallback to static loop ranges if analysis fails completely.
1259+ bounds = linalgOp.getStaticLoopRanges ();
1260+ LDBG () << " Fallback to static loop ranges: [" ;
1261+ }
11261262 SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray ();
11271263 SmallVector<Value> operands (linalgOp->getOperands ());
11281264
@@ -1143,7 +1279,7 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
11431279 FailureOr<std::pair<LoweringConfigAttr, int64_t >> configAndWgSize =
11441280 getMatmulOrIGEMMLoweringConfigAndWorkgroupSize (
11451281 bounds, maps, operands, target, useDirectLoad, /* isGemm=*/ true ,
1146- /* scaled=*/ false , splitReductionTripCnt, cPromoteIfPadding,
1282+ /* scaled=*/ false , splitReductionTripCnt, cPromoteIfPadding, boundsUsingAnalysis,
11471283 hasExistingAccumulator);
11481284
11491285 // TODO (muzasyed) : add generalization for scaled and nonscaled versions of
@@ -1154,7 +1290,7 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
11541290 useDirectLoad = true ;
11551291 configAndWgSize = getMatmulOrIGEMMLoweringConfigAndWorkgroupSize (
11561292 bounds, maps, operands, target, useDirectLoad, /* isGemm=*/ true ,
1157- /* scaled=*/ true , splitReductionTripCnt, cPromoteIfPadding,
1293+ /* scaled=*/ true , splitReductionTripCnt, cPromoteIfPadding, boundsUsingAnalysis,
11581294 hasExistingAccumulator);
11591295 }
11601296
0 commit comments