Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,10 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
/// there is no outer bound specified, the leading entry of this result will be
/// nullptr.
SmallVector<OpFoldResult> getPaddedBasis();

/// Returns true if the result of this operation can be used as dimension id
/// within 'region', i.e., for all its uses with `region`.
bool isValidDim(Region *region);
}];

let hasVerifier = 1;
Expand Down Expand Up @@ -1253,6 +1257,10 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
/// If there is no outer bound specified, the leading entry of this basis will be
/// nullptr.
SmallVector<OpFoldResult> getPaddedBasis();

/// Returns true if the result of this operation can be used as dimension id
/// within 'region', i.e., for all its uses with `region`.
bool isValidDim(Region *region);
}];

let hasVerifier = 1;
Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ bool mlir::affine::isValidDim(Value value) {
// *) It is valid as a symbol.
// *) It is an induction variable.
// *) It is the result of an affine apply operation with dimension id operands.
// *) It is the result of a more specialized index transformation (ex.
// delinearize_index or linearize_index) with dimension id operands.
bool mlir::affine::isValidDim(Value value, Region *region) {
// The value must be an index type.
if (!value.getType().isIndex())
Expand All @@ -326,6 +328,10 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
// Affine apply operation is ok if all of its operands are ok.
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
return applyOp.isValidDim(region);
if (auto delinearizeOp = dyn_cast<AffineDelinearizeIndexOp>(op))
return delinearizeOp.isValidDim(region);
if (auto linearizeOp = dyn_cast<AffineLinearizeIndexOp>(op))
return linearizeOp.isValidDim(region);
// The dim op is okay if its operand memref/tensor is defined at the top
// level.
if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
Expand Down Expand Up @@ -4636,6 +4642,14 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
hasOuterBound);
}

// The result of the affine apply operation can be used as a dimension id if all
// its operands are valid dimension ids with the parent operation of `region`
// defining the polyhedral scope for symbols.
bool AffineDelinearizeIndexOp::isValidDim(Region *region) {
return llvm::all_of(getOperands(),
[&](Value op) { return ::isValidDim(op, region); });
}

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
Value linearIndex, ArrayRef<int64_t> basis,
Expand Down Expand Up @@ -5023,6 +5037,14 @@ LogicalResult AffineLinearizeIndexOp::verify() {
return success();
}

// The result of the affine apply operation can be used as a dimension id if all
// its operands are valid dimension ids with the parent operation of `region`
// defining the polyhedral scope for symbols.
bool AffineLinearizeIndexOp::isValidDim(Region *region) {
return llvm::all_of(getOperands(),
[&](Value op) { return ::isValidDim(op, region); });
}

OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
std::optional<SmallVector<int64_t>> maybeStaticBasis =
foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
Expand Down
28 changes: 28 additions & 0 deletions mlir/test/Dialect/Affine/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,34 @@ func.func @dynamic_dimension_index() {

// -----

func.func @dynamic_linearized_index() {
"unknown.region"() ({
%idx = "unknown.test"() : () -> (index)
%memref = "unknown.test"() : () -> memref<?xf32>
%pos = affine.linearize_index [%idx, %idx] by (8) : index
// expected-error@below {{op operand cannot be used as a dimension id}}
affine.load %memref[%pos] : memref<?xf32>
"unknown.terminator"() : () -> ()
}) : () -> ()
return
}

// -----

func.func @dynamic_delinearized_index() {
"unknown.region"() ({
%idx = "unknown.test"() : () -> (index)
%memref = "unknown.test"() : () -> memref<?x?xf32>
%pos0, %pos1 = affine.delinearize_index %idx into (8) : index, index
// expected-error@below {{op operand cannot be used as a dimension id}}
affine.load %memref[%pos0, %pos1] : memref<?x?xf32>
"unknown.terminator"() : () -> ()
}) : () -> ()
return
}

// -----

#map = affine_map<() -> ()>
#map1 = affine_map<() -> (1)>
func.func @no_lower_bound() {
Expand Down
19 changes: 18 additions & 1 deletion mlir/test/Dialect/Affine/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,23 @@ func.func @valid_symbol_affine_scope(%n : index, %A : memref<?xf32>) {

// -----

// Test dimension constraints for linearize_index and delinearize_index

// CHECK-LABEL: func @valid_dim_linearize_delinearize
func.func @valid_dim_linearize_delinearize(%m : index, %n : index, %A : memref<?xf32>, %B: memref<?x32x?xf32>) {
affine.for %0 = 0 to %m {
affine.for %1 = 0 to %n {
%load_idx = affine.linearize_index disjoint [%0, %1] by (%m, %n) : index
%store_idx0, %store_idx1 = affine.delinearize_index %n into (32) : index, index
%v = affine.load %A[%load_idx] : memref<?xf32>
affine.store %v, %B[%0, %store_idx1, %store_idx0] : memref<?x32x?xf32>
}
}
return
}

// -----

// Test the fact that module op always provides an affine scope.

%idx = "test.foo"() : () -> (index)
Expand Down Expand Up @@ -309,7 +326,7 @@ func.func @linearize_mixed(%index0: index, %index1: index, %index2: index, %basi
module {
func.func @gpu_launch_affine() {
%c1 = arith.constant 1 : index
gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1)
gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1)
threads(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) {
%thread_id_x = gpu.thread_id x
%c128 = arith.constant 128 : index
Expand Down
Loading