Commit 19f6438

Added function to determine bounds of dynamic dims in GEMMs using IntegerRangeAnalysis in TAF

Signed-off-by: Yash Deshpande <[email protected]>
1 parent 2ffd825 commit 19f6438

3 files changed (+210, −7 lines)

compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ static llvm::cl::opt<bool> clEnableBlockedMatmuls(
     "iree-codegen-block-dynamic-dimensions-of-contractions",
     llvm::cl::desc("developer flag to gaurd blocking dynamic dimensions of "
                    "contraction-like ops"),
-    llvm::cl::Hidden, llvm::cl::init(true));
+    llvm::cl::Hidden, llvm::cl::init(false));

 namespace mlir::iree_compiler {

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 142 additions & 6 deletions

@@ -31,6 +31,13 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h"
+#include "llvm/ADT/DenseSet.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"

 #define DEBUG_TYPE "iree-gpu-config-utils"

@@ -653,7 +660,7 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
     ArrayRef<int64_t> bounds, ArrayRef<AffineMap> maps,
     ArrayRef<Value> operands, IREE::GPU::TargetAttr target, bool useDirectLoad,
     bool isGemm, bool scaled, int64_t splitReductionTripCnt,
-    bool cPromoteIfPadding, bool hasExistingAccumulator = false,
+    bool cPromoteIfPadding, bool boundsUsingAnalysis, bool hasExistingAccumulator = false,
     std::optional<ConvToIgemmInfo> convToIgemmInfo = std::nullopt) {
   if (target.getWgp().getMma().empty()) {
     return failure();

@@ -969,7 +976,7 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
                            : ArrayRef<Attribute>{};
   GPU::appendPromotedOperandsList(context, attrs, promotionList,
                                   promotionTypes);
-  if (!mustBeAligned || couldNeedPadding) {
+  if (!mustBeAligned || couldNeedPadding || boundsUsingAnalysis) {
     SmallVector<int64_t> paddingTileSizes = workgroupTileSizes;

     // Initialize inner and outer padding sizes from reductionTileSizes.

@@ -1085,7 +1092,7 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
           igemmLoopBounds, igemmContractionMaps, igemmOperands, target,
           useDirectLoad, /*isGemm=*/false,
           /*scaled=*/false, splitReductionTripCnt,
-          /*cPromoteIfPadding=*/cPromoteIfPadding, hasExistingAccumulator,
+          /*cPromoteIfPadding=*/cPromoteIfPadding, /*boundsUsingAnalysis=*/false, hasExistingAccumulator,
           convToIgemmInfo);
   if (failed(configAndWgSize)) {
     return failure();

@@ -1112,6 +1119,122 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
       workgroupSize, targetSubgroupSize, pipelineConfig);
 }

+
+static FailureOr<SmallVector<int64_t>>
+getLoopBoundsWithRangeAnalysis(linalg::LinalgOp linalgOp,
+                               mlir::FunctionOpInterface entryPoint) {
+  // Initialize a DataFlowSolver for integer range analysis.
+  DataFlowSolver solver;
+  solver.load<dataflow::DeadCodeAnalysis>();
+  solver.load<dataflow::SparseConstantPropagation>();
+  solver.load<dataflow::IntegerRangeAnalysis>();
+
+  if (failed(solver.initializeAndRun(entryPoint))) {
+    return linalgOp.getStaticLoopRanges();
+  }
+
+  SmallVector<int64_t> bounds = linalgOp.getStaticLoopRanges();
+  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+
+  // Sentinel value (2^53 - 1) used by IntegerRangeAnalysis when bounds are unknown.
+  constexpr int64_t unboundedSentinel = 9007199254740991;
+
+  // Helper to recursively collect index values feeding an operation.
+  // Uses a visited set instead of a hardcoded depth limit.
+  std::function<void(Value, SmallVectorImpl<Value> &, DenseSet<Value> &)>
+      collectIndexValues = [&](Value value, SmallVectorImpl<Value> &indexValues,
+                               DenseSet<Value> &visited) -> void {
+        // Use the visited set to prevent infinite recursion.
+        if (!visited.insert(value).second)
+          return;
+
+        if (value.getType().isIndex()) {
+          indexValues.push_back(value);
+        }
+
+        Operation *defOp = value.getDefiningOp();
+        if (!defOp)
+          return;
+
+        // Recursively traverse all operands.
+        for (Value operand : defOp->getOperands()) {
+          if (operand.getType().isIndex()) {
+            indexValues.push_back(operand);
+          }
+          // Keep traversing shaped types to find their dimension operands.
+          if (isa<ShapedType>(operand.getType())) {
+            Operation *operandDef = operand.getDefiningOp();
+            if (operandDef) {
+              for (Value v : operandDef->getOperands()) {
+                if (v.getType().isIndex()) {
+                  collectIndexValues(v, indexValues, visited);
+                }
+              }
+            }
+          }
+        }
+      };
+
+  for (auto [loopIdx, bound] : llvm::enumerate(bounds)) {
+    if (!ShapedType::isDynamic(bound)) {
+      continue;
+    }
+
+    bool boundRefined = false;
+
+    // Find the operand and dimension that correspond to this loop.
+    for (auto [operandIdx, operand] :
+         llvm::enumerate(linalgOp->getOperands())) {
+      auto shapedType = dyn_cast<ShapedType>(operand.getType());
+      if (!shapedType)
+        continue;
+
+      AffineMap map = indexingMaps[operandIdx];
+      for (auto [dimIdx, expr] : llvm::enumerate(map.getResults())) {
+        auto dimExpr = dyn_cast<AffineDimExpr>(expr);
+        if (!dimExpr || dimExpr.getPosition() != loopIdx)
+          continue;
+        if (!ShapedType::isDynamic(shapedType.getDimSize(dimIdx)))
+          continue;
+
+        // Collect all index values related to this operand by walking its use-def chain.
+        SmallVector<Value> indexValues;
+        DenseSet<Value> visited;
+        collectIndexValues(operand, indexValues, visited);
+
+        // Try each index value with getDynamicUpperBound.
+        for (Value indexValue : indexValues) {
+          FailureOr<int64_t> ub = getDynamicUpperBound(indexValue, solver);
+          if (succeeded(ub) && *ub > 0) {
+            // Filter out the unbounded sentinel.
+            if (*ub >= unboundedSentinel) {
+              continue;
+            }
+
+            bounds[loopIdx] = *ub;
+            boundRefined = true;
+            break;
+          }
+        }
+
+        if (boundRefined)
+          break;
+      }
+
+      if (boundRefined) {
+        break;
+      }
+    }
+
+    // TODO: If the bound couldn't be refined, find a better fallback than a large power of two.
+    if (!boundRefined && ShapedType::isDynamic(bounds[loopIdx])) {
+      bounds[loopIdx] = 1 << 20;
+    }
+  }
+
+  return bounds;
+}
+
 LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
                                       mlir::FunctionOpInterface entryPoint,
                                       Operation *op, bool useDirectLoad) {
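
Note on the helper above: getDynamicUpperBound is a pre-existing IREE utility whose implementation is not part of this diff. As a rough illustration of what such a query involves, here is a minimal sketch that reads an unsigned upper bound straight from MLIR's IntegerValueRangeLattice; queryUpperBound is a hypothetical name for illustration, not the actual IREE helper.

// Sketch only: read an unsigned upper bound for `value` from a solver on
// which IntegerRangeAnalysis has already run. Hypothetical helper; the real
// getDynamicUpperBound may differ.
static FailureOr<int64_t> queryUpperBound(Value value, DataFlowSolver &solver) {
  auto *lattice = solver.lookupState<dataflow::IntegerValueRangeLattice>(value);
  if (!lattice || lattice->getValue().isUninitialized())
    return failure();
  const ConstantIntRanges &range = lattice->getValue().getValue();
  // Reject ranges whose unsigned max does not fit in int64_t.
  if (range.umax().getActiveBits() >= 64)
    return failure();
  return static_cast<int64_t>(range.umax().getZExtValue());
}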

@@ -1122,7 +1245,19 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
     return failure();
   }

-  SmallVector<int64_t> bounds = linalgOp.getStaticLoopRanges();
+  // Use IntegerRangeAnalysis to get better bounds for dynamic shapes.
+  bool boundsUsingAnalysis = false;
+  FailureOr<SmallVector<int64_t>> maybeBounds =
+      getLoopBoundsWithRangeAnalysis(linalgOp, entryPoint);
+  SmallVector<int64_t> bounds;
+  if (succeeded(maybeBounds)) {
+    boundsUsingAnalysis = true;
+    bounds = std::move(*maybeBounds);
+  } else {
+    // Fall back to static loop ranges if the analysis fails completely.
+    bounds = linalgOp.getStaticLoopRanges();
+    LDBG() << "Range analysis failed; falling back to static loop ranges";
+  }
   SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
   SmallVector<Value> operands(linalgOp->getOperands());

@@ -1143,7 +1279,7 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
       getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
           bounds, maps, operands, target, useDirectLoad, /*isGemm=*/true,
-          /*scaled=*/false, splitReductionTripCnt, cPromoteIfPadding,
+          /*scaled=*/false, splitReductionTripCnt, cPromoteIfPadding, boundsUsingAnalysis,
           hasExistingAccumulator);

   // TODO (muzasyed) : add generalization for scaled and nonscaled versions of

@@ -1154,7 +1290,7 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
     useDirectLoad = true;
     configAndWgSize = getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
         bounds, maps, operands, target, useDirectLoad, /*isGemm=*/true,
-        /*scaled=*/true, splitReductionTripCnt, cPromoteIfPadding,
+        /*scaled=*/true, splitReductionTripCnt, cPromoteIfPadding, boundsUsingAnalysis,
        hasExistingAccumulator);
   }

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir

Lines changed: 67 additions & 0 deletions

@@ -1451,3 +1451,70 @@ hal.executable public @multi_result_index_generic_with_scatter_fusion {
 // CHECK: vector.transfer_write
 // CHECK: vector.transfer_write
 // CHECK: iree_linalg_ext.scatter
+
+// -----
+
+// Test a dynamic matmul with util.assume.int providing bounds for range analysis.
+// The getLoopBoundsWithRangeAnalysis function uses IntegerRangeAnalysis to infer
+// the upper bound from util.assume.int and select appropriate tile sizes.
+
+#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+#config = #iree_gpu.lowering_config<{
+  workgroup = [128, 128, 0],
+  reduction = [0, 0, 4],
+  subgroup = [4, 4],
+  mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>,
+  promote_operands = [0, 1],
+  padding = [128, 128, 16]
+}>
+hal.executable public @main {
+  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export public @matmul_dynamic_m_with_assume ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) ->
+        (index, index, index) {
+      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @matmul_dynamic_m_with_assume()
+          attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1]
+          subgroup_size = 64>} {
+        %cst = arith.constant 0.000000e+00 : f32
+        %c0 = arith.constant 0 : index
+        %dim = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+        %m = util.assume.int %dim<umin = 0, umax = 1024, udiv = 16> : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x2048xf32>>{%m}
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2048x4096xf32>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x4096xf32>>{%m}
+        %3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%m, 2048], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x2048xf32>>{%m} -> tensor<?x2048xf32>
+        %4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2048, 4096], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2048x4096xf32>> -> tensor<2048x4096xf32>
+        %5 = tensor.empty(%m) : tensor<?x4096xf32>
+        %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<?x4096xf32>) -> tensor<?x4096xf32>
+        %7 = linalg.matmul {lowering_config = #config}
+            ins(%3, %4 : tensor<?x2048xf32>, tensor<2048x4096xf32>)
+            outs(%6 : tensor<?x4096xf32>) -> tensor<?x4096xf32>
+        iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [%m, 4096], strides = [1, 1] : tensor<?x4096xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x4096xf32>>{%m}
+        return
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: func @matmul_dynamic_m_with_assume
+// CHECK-DAG: %[[B0:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0)
+// CHECK-DAG: %[[B1:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1)
+// CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
+// CHECK-DAG: memref.alloc() : memref<16x130xf32, #gpu.address_space<workgroup>>
+// CHECK-DAG: memref.alloc() : memref<128x18xf32, #gpu.address_space<workgroup>>
+// CHECK: scf.forall ({{.*}}) in (%{{.+}}, 32) {
+// CHECK: scf.for {{.*}} = %c0 to %c512 step %c4 {{.*}} -> (vector<4x4x4x1xf32>)
+// CHECK: gpu.barrier
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+// CHECK: gpu.barrier
+// CHECK-COUNT-64: amdgpu.mfma 16x16x4
+// CHECK: scf.yield
+// CHECK: } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
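
The test hinges on util.assume.int carrying a [umin, umax] range that IntegerRangeAnalysis can propagate to the dynamic M dimension. Ops opt into the analysis by implementing MLIR's InferIntRangeInterface; below is a minimal sketch of that hook for a hypothetical MyAssumeIntOp with getUmin()/getUmax() accessors (IREE's actual util.assume.int implementation differs in its details).

// Sketch only: how an assume-like op can report its [umin, umax] range to
// IntegerRangeAnalysis. MyAssumeIntOp and its accessors are hypothetical.
void MyAssumeIntOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  // Treat the index result as 64 bits wide for range propagation.
  APInt umin(/*numBits=*/64, getUmin());
  APInt umax(/*numBits=*/64, getUmax());
  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
}

With umax = 1024 here, the analysis can bound the dynamic M dimension, allowing the padded 128x128 workgroup tile from #config to be selected instead of a conservative fallback.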
