Skip to content

Commit 22cb18f

Browse files
committed
[mlir][affine] Make [de]linearize_index a valid source of dims
There's a sense in which affine.linearize_index and affine.delinearize_index are special-cases of affine.apply (which get their own ops to enable better code generation and more accurate canonicalization). Therefore, allow these operations to be dimension operands for operations like affine.load just like affine.apply can be.
1 parent c7c1283 commit 22cb18f

File tree

4 files changed

+76
-1
lines changed

4 files changed

+76
-1
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,6 +1153,10 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
11531153
/// there is no outer bound specified, the leading entry of this result will be
11541154
/// nullptr.
11551155
SmallVector<OpFoldResult> getPaddedBasis();
1156+
1157+
/// Returns true if the result of this operation can be used as dimension id
1158+
/// within 'region', i.e., for all its uses with `region`.
1159+
bool isValidDim(Region *region);
11561160
}];
11571161

11581162
let hasVerifier = 1;
@@ -1253,6 +1257,10 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
12531257
/// If there is no outer bound specified, the leading entry of this basis will be
12541258
/// nullptr.
12551259
SmallVector<OpFoldResult> getPaddedBasis();
1260+
1261+
/// Returns true if the result of this operation can be used as dimension id
1262+
/// within 'region', i.e., for all its uses with `region`.
1263+
bool isValidDim(Region *region);
12561264
}];
12571265

12581266
let hasVerifier = 1;

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,8 @@ bool mlir::affine::isValidDim(Value value) {
307307
// *) It is valid as a symbol.
308308
// *) It is an induction variable.
309309
// *) It is the result of an affine apply operation with dimension id operands.
310+
// *) It is the result of a more specialized index transformation (ex.
311+
// delinearize_index or linearize_index) with dimension id operands.
310312
bool mlir::affine::isValidDim(Value value, Region *region) {
311313
// The value must be an index type.
312314
if (!value.getType().isIndex())
@@ -326,6 +328,10 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
326328
// Affine apply operation is ok if all of its operands are ok.
327329
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
328330
return applyOp.isValidDim(region);
331+
if (auto delinearizeOp = dyn_cast<AffineDelinearizeIndexOp>(op))
332+
return delinearizeOp.isValidDim(region);
333+
if (auto linearizeOp = dyn_cast<AffineLinearizeIndexOp>(op))
334+
return linearizeOp.isValidDim(region);
329335
// The dim op is okay if its operand memref/tensor is defined at the top
330336
// level.
331337
if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
@@ -4636,6 +4642,14 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
46364642
hasOuterBound);
46374643
}
46384644

4645+
// The result of the affine apply operation can be used as a dimension id if all
4646+
// its operands are valid dimension ids with the parent operation of `region`
4647+
// defining the polyhedral scope for symbols.
4648+
bool AffineDelinearizeIndexOp::isValidDim(Region *region) {
4649+
return llvm::all_of(getOperands(),
4650+
[&](Value op) { return ::isValidDim(op, region); });
4651+
}
4652+
46394653
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
46404654
OperationState &odsState,
46414655
Value linearIndex, ArrayRef<int64_t> basis,
@@ -5023,6 +5037,14 @@ LogicalResult AffineLinearizeIndexOp::verify() {
50235037
return success();
50245038
}
50255039

5040+
// The result of the affine apply operation can be used as a dimension id if all
5041+
// its operands are valid dimension ids with the parent operation of `region`
5042+
// defining the polyhedral scope for symbols.
5043+
bool AffineLinearizeIndexOp::isValidDim(Region *region) {
5044+
return llvm::all_of(getOperands(),
5045+
[&](Value op) { return ::isValidDim(op, region); });
5046+
}
5047+
50265048
OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
50275049
std::optional<SmallVector<int64_t>> maybeStaticBasis =
50285050
foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),

mlir/test/Dialect/Affine/invalid.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,34 @@ func.func @dynamic_dimension_index() {
544544

545545
// -----
546546

547+
func.func @dynamic_linearized_index() {
548+
"unknown.region"() ({
549+
%idx = "unknown.test"() : () -> (index)
550+
%memref = "unknown.test"() : () -> memref<?xf32>
551+
%pos = affine.linearize_index [%idx, %idx] by (8) : index
552+
// expected-error@below {{op operand cannot be used as a dimension id}}
553+
affine.load %memref[%pos] : memref<?xf32>
554+
"unknown.terminator"() : () -> ()
555+
}) : () -> ()
556+
return
557+
}
558+
559+
// -----
560+
561+
func.func @dynamic_delinearized_index() {
562+
"unknown.region"() ({
563+
%idx = "unknown.test"() : () -> (index)
564+
%memref = "unknown.test"() : () -> memref<?x?xf32>
565+
%pos0, %pos1 = affine.delinearize_index %idx into (8) : index, index
566+
// expected-error@below {{op operand cannot be used as a dimension id}}
567+
affine.load %memref[%pos0, %pos1] : memref<?x?xf32>
568+
"unknown.terminator"() : () -> ()
569+
}) : () -> ()
570+
return
571+
}
572+
573+
// -----
574+
547575
#map = affine_map<() -> ()>
548576
#map1 = affine_map<() -> (1)>
549577
func.func @no_lower_bound() {

mlir/test/Dialect/Affine/ops.mlir

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,23 @@ func.func @valid_symbol_affine_scope(%n : index, %A : memref<?xf32>) {
148148

149149
// -----
150150

151+
// Test dimension constraints for linearize_index and delinearize_index
152+
153+
// CHECK-LABEL: func @valid_dim_linearize_delinearize
154+
func.func @valid_dim_linearize_delinearize(%m : index, %n : index, %A : memref<?xf32>, %B: memref<?x32x?xf32>) {
155+
affine.for %0 = 0 to %m {
156+
affine.for %1 = 0 to %n {
157+
%load_idx = affine.linearize_index disjoint [%0, %1] by (%m, %n) : index
158+
%store_idx0, %store_idx1 = affine.delinearize_index %n into (32) : index, index
159+
%v = affine.load %A[%load_idx] : memref<?xf32>
160+
affine.store %v, %B[%0, %store_idx1, %store_idx0] : memref<?x32x?xf32>
161+
}
162+
}
163+
return
164+
}
165+
166+
// -----
167+
151168
// Test the fact that module op always provides an affine scope.
152169

153170
%idx = "test.foo"() : () -> (index)
@@ -309,7 +326,7 @@ func.func @linearize_mixed(%index0: index, %index1: index, %index2: index, %basi
309326
module {
310327
func.func @gpu_launch_affine() {
311328
%c1 = arith.constant 1 : index
312-
gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1)
329+
gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1)
313330
threads(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) {
314331
%thread_id_x = gpu.thread_id x
315332
%c128 = arith.constant 128 : index

0 commit comments

Comments
 (0)