Skip to content

Commit 3beeb26

Browse files
authored
[Codegen] Change swizzle hint offset logic to use arith (iree-org#21237)
After the addition of PropagateConstantOffsets, a constant offset is expected to arrive in the form `arith.addi ..., %c`, so the swizzle hint offset logic is changed to match that form. The logic is more robust overall now that a preceding pass gives some guarantee about the form of the input.
1 parent a7a248e commit 3beeb26

File tree

3 files changed

+49
-53
lines changed

3 files changed

+49
-53
lines changed

compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,26 @@ struct ResolveSwizzleHintsPass final
2626
};
2727
} // namespace
2828

29+
/// Builds a value representing `v + offset` for a statically known `offset`.
///
/// - If `offset` is zero, `v` is returned unchanged and no ops are created.
/// - If `v` is defined by an `arith.addi` whose right-hand side matches a
///   constant integer, a new `arith.addi` of the original left-hand side and
///   the combined constant (`offset` + that constant) is created at the
///   original add's location, reusing the original add's overflow flags.
/// - Otherwise a new index constant holding `offset` is created and added to
///   `v`, both at `v`'s location.
static Value createOrFoldNewStaticAdd(RewriterBase &rewriter, Value v,
                                      int64_t offset) {
  // Common case: nothing to add, so hand the input straight back.
  if (offset == 0)
    return v;

  auto definingAdd = v.getDefiningOp<arith::AddIOp>();
  llvm::APInt rhsConstant;
  if (definingAdd &&
      matchPattern(definingAdd.getRhs(), m_ConstantInt(&rhsConstant))) {
    // Fold the existing constant operand and the new static offset into a
    // single constant, then re-add it to the original left-hand side.
    int64_t mergedOffset = offset + rhsConstant.getSExtValue();
    Value mergedConstant = rewriter.create<arith::ConstantIndexOp>(
        definingAdd.getLoc(), mergedOffset);
    return rewriter.create<arith::AddIOp>(
        definingAdd.getLoc(), definingAdd.getLhs(), mergedConstant,
        definingAdd.getOverflowFlags());
  }

  // Generic case: materialize the offset and add it to `v` directly.
  Value staticOffset =
      rewriter.create<arith::ConstantIndexOp>(v.getLoc(), offset);
  return rewriter.create<arith::AddIOp>(v.getLoc(), v, staticOffset);
}
48+
2949
/// Swizzles vector.load(iree_codegen.swizzle_hint, offset). The
3050
/// SwizzleInterfaceAttr exposes two methods:
3151
/// 1. getAccessElementCount -> int64_t
@@ -53,10 +73,6 @@ static void swizzleLoad(RewriterBase &rewriter, vector::LoadOp load,
5373
VectorType swizzledLoadType =
5474
VectorType::get({accessWidth}, type.getElementType());
5575

56-
AffineExpr s0, s1;
57-
bindSymbols(rewriter.getContext(), s0, s1);
58-
AffineMap sum = AffineMap::get(0, 2, s0 + s1);
59-
6076
// ~ vector.undef, overwritten by unrolling.
6177
Value replacement = rewriter.create<arith::ConstantOp>(
6278
hintLoc, type, rewriter.getZeroAttr(type));
@@ -65,10 +81,7 @@ static void swizzleLoad(RewriterBase &rewriter, vector::LoadOp load,
6581
// i = 0 -> C += k is the offset into the vector of a contiguous group of
6682
// swizzled elements.
6783
for (int64_t i = 0; i < loadWidth; i += accessWidth) {
68-
auto vecOffset = rewriter.getIndexAttr(i);
69-
auto newBaseOffset = affine::makeComposedFoldedAffineApply(
70-
rewriter, hintLoc, sum, {memrefOffset, vecOffset});
71-
84+
Value newBaseOffset = createOrFoldNewStaticAdd(rewriter, memrefOffset, i);
7285
Value newOffset = getValueOrCreateConstantIndexOp(
7386
rewriter, hintLoc,
7487
hintOp.getSwizzle().swizzleOffset(rewriter, hintOp.getLoc(),
@@ -103,20 +116,14 @@ static void swizzleStore(RewriterBase &rewriter, vector::StoreOp store,
103116
int64_t storeWidth = type.getShape()[0];
104117
Value memrefOffset = store.getIndices()[0];
105118

106-
AffineExpr s0, s1;
107-
bindSymbols(rewriter.getContext(), s0, s1);
108-
AffineMap sum = AffineMap::get(0, 2, s0 + s1);
109-
110119
// Store type = vector<C>, k = accessWidth
111120
// i = 0 -> C += k is the offset into the vector of a contiguous group of
112121
// swizzled elements.
113122
for (int64_t i = 0; i < storeWidth; i += accessWidth) {
114123
Value subVec = rewriter.create<vector::ExtractStridedSliceOp>(
115124
store.getLoc(), store.getValueToStore(), ArrayRef<int64_t>{i},
116125
ArrayRef<int64_t>{accessWidth}, ArrayRef<int64_t>{1});
117-
auto vecOffset = rewriter.getIndexAttr(i);
118-
auto newBaseOffset = affine::makeComposedFoldedAffineApply(
119-
rewriter, hintLoc, sum, {memrefOffset, vecOffset});
126+
Value newBaseOffset = createOrFoldNewStaticAdd(rewriter, memrefOffset, i);
120127

121128
Value newOffset = getValueOrCreateConstantIndexOp(
122129
rewriter, hintLoc,

compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,23 +156,26 @@ func.func @swizzle_dynamic(%src: memref<?xf32>, %vec: vector<4xf32>, %offset: in
156156

157157
// -----
158158

159-
func.func @swizzle_adjust_affine_offset(%src: memref<?xf32>, %vec: vector<4xf32>, %offset_base: index) -> vector<4xf32> {
159+
func.func @swizzle_adjust_add_offset(%src: memref<?xf32>, %vec: vector<4xf32>, %offset_base: index) -> vector<4xf32> {
160160
%0 = iree_codegen.swizzle_hint %src[#iree_codegen.rotate_rows<64, 4>] : memref<?xf32>
161-
%load_offset = affine.apply affine_map<()[s0] -> (16 + s0)>()[%offset_base]
161+
%c16 = arith.constant 16 : index
162+
%c1040 = arith.constant 1040 : index
163+
%load_offset = arith.addi %offset_base, %c16 overflow<nsw> : index
162164
%1 = vector.load %0[%load_offset] : memref<?xf32>, vector<4xf32>
163-
%store_offset = affine.apply affine_map<()[s0] -> (1040 + s0)>()[%offset_base]
165+
%store_offset = arith.addi %offset_base, %c1040 overflow<nsw> : index
164166
vector.store %vec, %0[%store_offset] : memref<?xf32>, vector<4xf32>
165167
return %1: vector<4xf32>
166168
}
167169

168-
// CHECK-LABEL: func @swizzle_adjust_affine_offset
170+
// CHECK-LABEL: func @swizzle_adjust_add_offset
169171
// CHECK-SAME: %[[SRC:[A-Za-z0-9]+]]: memref<?xf32>
170172
// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<4xf32>
171173
// CHECK-SAME: %[[OFFSET:[A-Za-z0-9]+]]: index
172174
// CHECK-DAG: %[[ROW_WIDTH:.+]] = arith.constant 64 : index
173175
// CHECK-DAG: %[[GROUP_COUNT:.+]] = arith.constant 16 : index
174176
// CHECK-DAG: %[[GROUP_WIDTH:.+]] = arith.constant 4 : index
175-
// CHECK: %[[APPLY_BASE:.+]] = affine.apply affine_map<()[s0] -> (s0 + 16)>()[%[[OFFSET]]]
177+
// CHECK-DAG: %[[C1040:.+]] = arith.constant 1040 : index
178+
// CHECK: %[[APPLY_BASE:.+]] = arith.addi %[[OFFSET]], %[[GROUP_COUNT]] overflow<nsw> : index
176179
// CHECK: %[[I:.+]] = arith.divui %[[APPLY_BASE]], %[[ROW_WIDTH]] : index
177180
// CHECK: %[[JELEM:.+]] = arith.remui %[[APPLY_BASE]], %[[ROW_WIDTH]] : index
178181
// CHECK: %[[J:.+]] = arith.divui %[[JELEM]], %[[GROUP_WIDTH]] : index
@@ -184,7 +187,7 @@ func.func @swizzle_adjust_affine_offset(%src: memref<?xf32>, %vec: vector<4xf32>
184187

185188
// CHECK: %[[VECTOR:.+]] = vector.load %[[SRC]][%[[SWOFF]]]
186189

187-
// CHECK: %[[STORE_BASE:.+]] = affine.apply affine_map<()[s0] -> (s0 + 1040)>()[%[[OFFSET]]]
190+
// CHECK: %[[STORE_BASE:.+]] = arith.addi %[[OFFSET]], %[[C1040]] overflow<nsw> : index
188191
// CHECK: %[[OFFSET_DIFF:.+]] = arith.subi %[[SWOFF]], %[[APPLY_BASE]] : index
189192
// CHECK: %[[STORE_SWOFF:.+]] = arith.addi %[[STORE_BASE]], %[[OFFSET_DIFF]] : index
190193
// CHECK: vector.store %[[VEC]], %[[SRC]][%[[STORE_SWOFF]]]

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -467,55 +467,41 @@ int64_t WorkgroupMappingAttr::getRelativeIndex() const {
467467
// iree_codegen.rotate_rows
468468
//===---------------------------------------------------------------------===//
469469

470-
/// Given `r = |rotationInvariant|`, simplify affine maps of the following form:
470+
/// Given `r = |rotationInvariant|`, simplify additions of the following form:
471471
///
472-
/// ```
473-
/// %offset = affine.apply (d0, ..., dn) -> (f(d0, ..., dn) + c)
474-
/// ```
472+
/// %offset = arith.addi %0, c
475473
///
476474
/// Where `c` is a constant, to:
477475
///
478-
/// ```
479-
/// %offset = affine.apply (d0, ..., dn) -> (f(d0, ..., dn) + c % r)
480-
/// ```
481-
static OpFoldResult getMinimumConstantOffsetMap(OpBuilder &b, Location loc,
482-
OpFoldResult offset,
483-
int64_t rotationInvariant) {
476+
/// %offset = arith.addi %0, c % r
477+
static OpFoldResult getMinimumConstantOffsetValue(OpBuilder &b, Location loc,
478+
OpFoldResult offset,
479+
int64_t rotationInvariant) {
484480
auto value = dyn_cast_if_present<Value>(offset);
485481
if (!value)
486482
return offset;
487483

488-
auto apply = value.getDefiningOp<affine::AffineApplyOp>();
489-
if (!apply)
484+
auto add = value.getDefiningOp<arith::AddIOp>();
485+
if (!add)
490486
return offset;
491487

492-
AffineMap map = apply.getMap();
493-
// Simplify the map to move `+ c` terms to the right most (first) expression
494-
// in the tree.
495-
map = simplifyAffineMap(map);
496-
AffineExpr resultExpr = map.getResult(0);
497-
auto addExpr = llvm::dyn_cast<AffineBinaryOpExpr>(resultExpr);
498-
499-
// After simplification, the add should be the first expression if present.
500-
if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
501-
return offset;
502-
503-
// If RHS is not constant, nothing to do.
504-
auto constantRhs = llvm::dyn_cast<AffineConstantExpr>(addExpr.getRHS());
505-
if (!constantRhs)
488+
llvm::APInt constant;
489+
if (!matchPattern(add.getRhs(), m_ConstantInt(&constant)))
506490
return offset;
507491

508-
int64_t constantOffset = constantRhs.getValue();
492+
int64_t constantOffset = constant.getSExtValue();
509493
int64_t baseMod = constantOffset % rotationInvariant;
510494

511495
// Skip constructing the new add if it's not needed (c % rotationInvariant == c).
512496
if (baseMod == constantOffset)
513497
return offset;
514498

515-
AffineExpr newExpr =
516-
addExpr.getLHS() + getAffineConstantExpr(baseMod, b.getContext());
517-
map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExpr);
518-
return b.create<affine::AffineApplyOp>(loc, map, apply.getOperands())
499+
Value modOffset = b.create<arith::ConstantIndexOp>(loc, baseMod);
500+
// If the original add is nsw/nuw, then the new add must be as well, since
501+
// we're adding a smaller value.
502+
return b
503+
.create<arith::AddIOp>(loc, add.getLhs(), modOffset,
504+
add.getOverflowFlags())
519505
.getResult();
520506
}
521507

@@ -531,7 +517,7 @@ OpFoldResult RotateRowsAttr::swizzleOffset(OpBuilder &b, Location loc,
531517
int64_t rotationInvariant =
532518
getRowWidth() * (getRowWidth() / getAccessWidth());
533519
OpFoldResult id =
534-
getMinimumConstantOffsetMap(b, loc, offset, rotationInvariant);
520+
getMinimumConstantOffsetValue(b, loc, offset, rotationInvariant);
535521

536522
// Number of elements per row.
537523
Value rowAlignmentVal = b.create<arith::ConstantIndexOp>(loc, getRowWidth());

0 commit comments

Comments
 (0)