Commit d1b18ea

Added function to determine bounds of dynamic dims in GEMMs using IntegerRangeAnalysis in TAF
Signed-off-by: Yash Deshpande <[email protected]>
1 parent: 2ffd825

3 files changed: +160 -6 lines changed

compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ static llvm::cl::opt<bool> clEnableBlockedMatmuls(
     "iree-codegen-block-dynamic-dimensions-of-contractions",
     llvm::cl::desc("developer flag to gaurd blocking dynamic dimensions of "
                    "contraction-like ops"),
-    llvm::cl::Hidden, llvm::cl::init(true));
+    llvm::cl::Hidden, llvm::cl::init(false));
 
 namespace mlir::iree_compiler {

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 92 additions & 5 deletions
@@ -7,6 +7,7 @@
 #include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h"
 
 #include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
+#include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h"
 #include "iree/compiler/Codegen/Common/TileInferenceUtils.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
@@ -20,10 +21,16 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/InterleavedRange.h"
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Attributes.h"
@@ -45,6 +52,14 @@ namespace mlir::iree_compiler::IREE::GPU {
 constexpr int64_t kCacheLineSizeBits = 128 * 8;
 constexpr int64_t kPreferredCopyNumBits = 128;
 
+// Sentinel value used by IntegerRangeAnalysis when bounds are unknown.
+static constexpr uint64_t MAX_DIM_VALUE = (static_cast<uint64_t>(1) << 53) - 1;
+
+// Fallback bound when IntegerRangeAnalysis cannot determine the actual value.
+// Kept small (2^14) to avoid int64_t overflow when dimensions are multiplied
+// together in heuristic calculations.
+static constexpr uint64_t MAX_BOUND_VALUE = static_cast<uint64_t>(1) << 14;
+
 //===----------------------------------------------------------------------===//
 // Lowering Config Selection
 //===----------------------------------------------------------------------===//
@@ -653,7 +668,8 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
     ArrayRef<int64_t> bounds, ArrayRef<AffineMap> maps,
     ArrayRef<Value> operands, IREE::GPU::TargetAttr target, bool useDirectLoad,
     bool isGemm, bool scaled, int64_t splitReductionTripCnt,
-    bool cPromoteIfPadding, bool hasExistingAccumulator = false,
+    bool cPromoteIfPadding, bool boundsUsingAnalysis,
+    bool hasExistingAccumulator = false,
     std::optional<ConvToIgemmInfo> convToIgemmInfo = std::nullopt) {
   if (target.getWgp().getMma().empty()) {
     return failure();
@@ -969,7 +985,7 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
                  : ArrayRef<Attribute>{};
   GPU::appendPromotedOperandsList(context, attrs, promotionList,
                                   promotionTypes);
-  if (!mustBeAligned || couldNeedPadding) {
+  if (!mustBeAligned || couldNeedPadding || boundsUsingAnalysis) {
     SmallVector<int64_t> paddingTileSizes = workgroupTileSizes;
 
     // Initialize inner and outer padding sizes from reductionTileSizes.
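
Adding boundsUsingAnalysis to this condition routes analysis-refined GEMMs through the padding path: a refined bound is only an upper bound, so the runtime extent may fall anywhere below it and need not divide the workgroup tile. A sketch of the resulting remainder arithmetic (paddedTileCount is a hypothetical helper, not from the patch):

#include <cstdint>

// Ceil-div: tiles needed to cover an extent that may not be tile-aligned;
// the final partial tile is padded rather than masked.
int64_t paddedTileCount(int64_t runtimeExtent, int64_t tileSize) {
  return (runtimeExtent + tileSize - 1) / tileSize;
}
// e.g. extent 1000 with tile 128 -> 8 tiles; the last covers 104 rows of
// real data plus 24 rows of padding.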
@@ -1085,7 +1101,8 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
           igemmLoopBounds, igemmContractionMaps, igemmOperands, target,
           useDirectLoad, /*isGemm=*/false,
           /*scaled=*/false, splitReductionTripCnt,
-          /*cPromoteIfPadding=*/cPromoteIfPadding, hasExistingAccumulator,
+          /*cPromoteIfPadding=*/cPromoteIfPadding,
+          /*boundsUsingAnalysis=*/false, hasExistingAccumulator,
           convToIgemmInfo);
   if (failed(configAndWgSize)) {
     return failure();
@@ -1112,6 +1129,68 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
                                       workgroupSize, targetSubgroupSize, pipelineConfig);
 }
 
+static FailureOr<SmallVector<int64_t>>
+getLoopBoundsWithRangeAnalysis(linalg::LinalgOp linalgOp,
+                               mlir::FunctionOpInterface entryPoint) {
+  // Use TensorDynamicDimAnalysis for cleaner range queries.
+  TensorDynamicDimAnalysis dynamicDimAnalysis(entryPoint);
+  if (failed(dynamicDimAnalysis.run())) {
+    return linalgOp.getStaticLoopRanges();
+  }
+
+  SmallVector<int64_t> bounds = linalgOp.getStaticLoopRanges();
+  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+
+  for (auto [loopIdx, bound] : llvm::enumerate(bounds)) {
+    if (!ShapedType::isDynamic(bound)) {
+      continue;
+    }
+
+    bool boundRefined = false;
+
+    // Find operand and dimension that corresponds to this loop.
+    for (auto [operandIdx, operand] :
+         llvm::enumerate(linalgOp->getOperands())) {
+      auto shapedType = dyn_cast<ShapedType>(operand.getType());
+      if (!shapedType) {
+        continue;
+      }
+
+      AffineMap map = indexingMaps[operandIdx];
+      for (auto [dimIdx, expr] : llvm::enumerate(map.getResults())) {
+        auto dimExpr = dyn_cast<AffineDimExpr>(expr);
+        if (!dimExpr || dimExpr.getPosition() != loopIdx) {
+          continue;
+        }
+        if (!ShapedType::isDynamic(shapedType.getDimSize(dimIdx))) {
+          continue;
+        }
+
+        // Use TensorDynamicDimAnalysis to get range info directly.
+        if (auto range = dynamicDimAnalysis.getRangeInfo(operand, dimIdx)) {
+          int64_t ub = range->smax().getSExtValue();
+          if (ub > 0 && ub < MAX_DIM_VALUE) {
+            bounds[loopIdx] = ub;
+            boundRefined = true;
+            break;
+          }
+        }
+      }
+
+      if (boundRefined) {
+        break;
+      }
+    }
+
+    // If we couldn't refine the bound, set it to a large value.
+    if (!boundRefined && ShapedType::isDynamic(bounds[loopIdx])) {
+      bounds[loopIdx] = MAX_BOUND_VALUE;
+    }
+  }
+
+  return bounds;
+}
+
 LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
                                       mlir::FunctionOpInterface entryPoint,
                                       Operation *op, bool useDirectLoad) {
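
Distilled, the per-dimension decision the new helper makes is the following (a standalone sketch under our naming; the constants mirror MAX_DIM_VALUE and MAX_BOUND_VALUE added above, and pickBound is not a function from the patch):

#include <cstdint>
#include <optional>

static constexpr uint64_t kUnknownSentinel = (uint64_t{1} << 53) - 1; // MAX_DIM_VALUE
static constexpr uint64_t kFallbackBound = uint64_t{1} << 14;         // MAX_BOUND_VALUE

// Given the signed max the analysis reports for one dynamic dimension (or
// nullopt when no range info exists), pick the bound the GEMM heuristics see.
int64_t pickBound(std::optional<int64_t> smax) {
  if (smax && *smax > 0 && static_cast<uint64_t>(*smax) < kUnknownSentinel)
    return *smax;                              // usable analysis-proven bound
  return static_cast<int64_t>(kFallbackBound); // unknown: conservative fallback
}

In the test added below, util.assume.int gives M an umax of 1024, so the loop bounds {?, 4096, 2048} refine to {1024, 4096, 2048} and boundsUsingAnalysis becomes true.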
@@ -1122,7 +1201,15 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
     return failure();
   }
 
+  // Use IntegerRangeAnalysis to get better bounds for dynamic shapes.
+  bool boundsUsingAnalysis = false;
+  FailureOr<SmallVector<int64_t>> maybeBounds =
+      getLoopBoundsWithRangeAnalysis(linalgOp, entryPoint);
   SmallVector<int64_t> bounds = linalgOp.getStaticLoopRanges();
+  if (succeeded(maybeBounds) && (maybeBounds != bounds)) {
+    boundsUsingAnalysis = true;
+    bounds = std::move(*maybeBounds);
+  }
   SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
   SmallVector<Value> operands(linalgOp->getOperands());
 
@@ -1144,7 +1231,7 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
       getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
           bounds, maps, operands, target, useDirectLoad, /*isGemm=*/true,
           /*scaled=*/false, splitReductionTripCnt, cPromoteIfPadding,
-          hasExistingAccumulator);
+          boundsUsingAnalysis, hasExistingAccumulator);
 
   // TODO (muzasyed) : add generalization for scaled and nonscaled versions of
   // matmul lowering.
@@ -1155,7 +1242,7 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
     configAndWgSize = getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
         bounds, maps, operands, target, useDirectLoad, /*isGemm=*/true,
         /*scaled=*/true, splitReductionTripCnt, cPromoteIfPadding,
-        hasExistingAccumulator);
+        boundsUsingAnalysis, hasExistingAccumulator);
   }
 
   if (failed(configAndWgSize)) {

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir

Lines changed: 67 additions & 0 deletions
@@ -1451,3 +1451,70 @@ hal.executable public @multi_result_index_generic_with_scatter_fusion {
 // CHECK: vector.transfer_write
 // CHECK: vector.transfer_write
 // CHECK: iree_linalg_ext.scatter
+
+// -----
+
+// Test dynamic matmul with util.assume.int providing bounds for range analysis.
+// The getLoopBoundsWithRangeAnalysis function uses IntegerRangeAnalysis to infer
+// the upper bound from util.assume.int and select appropriate tile sizes.
+
+#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+#config = #iree_gpu.lowering_config<{
+  workgroup = [128, 128, 0],
+  reduction = [0, 0, 4],
+  subgroup = [4, 4],
+  mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>,
+  promote_operands = [0, 1],
+  padding = [128, 128, 16]
+}>
+hal.executable public @main {
+  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export public @matmul_dynamic_m_with_assume ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) ->
+        (index, index, index) {
+      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @matmul_dynamic_m_with_assume()
+          attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1]
+              subgroup_size = 64>} {
+        %cst = arith.constant 0.000000e+00 : f32
+        %c0 = arith.constant 0 : index
+        %dim = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+        %m = util.assume.int %dim<umin = 0, umax = 1024, udiv = 16> : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x2048xf32>>{%m}
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2048x4096xf32>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x4096xf32>>{%m}
+        %3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%m, 2048], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x2048xf32>>{%m} -> tensor<?x2048xf32>
+        %4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2048, 4096], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2048x4096xf32>> -> tensor<2048x4096xf32>
+        %5 = tensor.empty(%m) : tensor<?x4096xf32>
+        %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<?x4096xf32>) -> tensor<?x4096xf32>
+        %7 = linalg.matmul {lowering_config = #config}
+            ins(%3, %4 : tensor<?x2048xf32>, tensor<2048x4096xf32>)
+            outs(%6 : tensor<?x4096xf32>) -> tensor<?x4096xf32>
+        iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [%m, 4096], strides = [1, 1] : tensor<?x4096xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x4096xf32>>{%m}
+        return
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: func @matmul_dynamic_m_with_assume
+// CHECK-DAG: %[[B0:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0)
+// CHECK-DAG: %[[B1:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1)
+// CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
+// CHECK-DAG: memref.alloc() : memref<16x130xf32, #gpu.address_space<workgroup>>
+// CHECK-DAG: memref.alloc() : memref<128x18xf32, #gpu.address_space<workgroup>>
+// CHECK: scf.forall ({{.*}}) in (%{{.+}}, 32) {
+// CHECK: scf.for {{.*}} = %c0 to %c512 step %c4 {{.*}} -> (vector<4x4x4x1xf32>)
+// CHECK: gpu.barrier
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+// CHECK: gpu.barrier
+// CHECK-COUNT-64: amdgpu.mfma 16x16x4
+// CHECK: scf.yield
+// CHECK: } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
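
A quick cross-check of these CHECK lines against #config (our arithmetic, not part of the commit; in particular, reading the reduction loop as counting in units of the MFMA intrinsic's K = 4 is our assumption):

  4096 / 128 = 32 workgroups along N, the "32" in the scf.forall; the M count stays a runtime value since M is only bounded, not known.
  2048 / 4 = 512 intrinsic K steps, the scf.for upper bound %c512, consumed four at a time (step %c4).
  4 (M subgroup tiles) x 4 (N subgroup tiles) x 4 (K steps) = 64 MFMAs per iteration, matching CHECK-COUNT-64.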
