From c4094366980cfc17ba7180f97e43a3a608d415b7 Mon Sep 17 00:00:00 2001
From: Max Dawkins
Date: Mon, 2 Dec 2024 14:51:49 -0500
Subject: [PATCH] [mlir] Add ValueBoundsOpInterfaceImpl for scf.forall

Signed-off-by: Max Dawkins
---
 .../SCF/IR/ValueBoundsOpInterfaceImpl.cpp | 42 +++++++++++++++++++
 .../SCF/value-bounds-op-interface-impl.mlir | 36 ++++++++++++++++
 2 files changed, 78 insertions(+)

diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 17a1c016ea16d..fbd236b648cb8 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -95,6 +95,47 @@ struct ForOpInterface
   }
 };
 
+struct ForallOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<ForallOpInterface,
+                                                   ForallOp> {
+
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto forallOp = cast<ForallOp>(op);
+
+    // Index values should be induction variables, since the semantics of
+    // tensor::ParallelInsertSliceOp requires forall outputs to be ranked
+    // tensors.
+    auto blockArg = cast<BlockArgument>(value);
+    assert(blockArg.getArgNumber() < forallOp.getInductionVars().size() &&
+           "expected index value to be an induction var");
+    int64_t idx = blockArg.getArgNumber();
+    // TODO: Take into account step size.
+    AffineExpr lb = cstr.getExpr(forallOp.getMixedLowerBound()[idx]);
+    AffineExpr ub = cstr.getExpr(forallOp.getMixedUpperBound()[idx]);
+    cstr.bound(value) >= lb;
+    cstr.bound(value) < ub;
+  }
+
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto forallOp = cast<ForallOp>(op);
+
+    // `value` is an iter_arg or an OpResult.
+    int64_t iterArgIdx;
+    if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
+      iterArgIdx = iterArg.getArgNumber() - forallOp.getInductionVars().size();
+    } else {
+      iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
+    }
+
+    // The forall results and output arguments have the same sizes as the output
+    // operands.
+    Value outputOperand = forallOp.getOutputs()[iterArgIdx];
+    cstr.bound(value)[dim] == cstr.getExpr(outputOperand, dim);
+  }
+};
+
 struct IfOpInterface
     : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
 
@@ -161,6 +202,7 @@ void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
     scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
+    scf::ForallOp::attachInterface<scf::ForallOpInterface>(*ctx);
     scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
index 9ab03da1c9a94..65e1017e62c1a 100644
--- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
@@ -107,6 +107,42 @@ func.func @scf_for_swapping_yield(%t1: tensor<?xf32>, %t2: tensor<?xf32>, %a: in
 
 // -----
 
+// CHECK-LABEL: func @scf_forall(
+//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+//       CHECK:   "test.some_use"(%[[a]], %[[b]])
+func.func @scf_forall(%a: index, %b: index, %c: index) {
+  scf.forall (%iv) = (%a) to (%b) step (%c) {
+    %0 = "test.reify_bound"(%iv) {type = "LB"} : (index) -> (index)
+    %1 = "test.reify_bound"(%iv) {type = "UB"} : (index) -> (index)
+    "test.some_use"(%0, %1) : (index, index) -> ()
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_forall_tensor_result(
+//  CHECK-SAME:     %[[size:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+//       CHECK:   "test.some_use"(%[[size]])
+//       CHECK:   "test.some_use"(%[[size]])
+func.func @scf_forall_tensor_result(%size: index, %a: index, %b: index, %c: index) {
+  %cst = arith.constant 5.0 : f32
+  %empty = tensor.empty(%size) : tensor<?xf32>
+  %0 = scf.forall (%iv) = (%a) to (%b) step (%c) shared_outs(%arg = %empty) -> tensor<?xf32> {
+    %filled = linalg.fill ins(%cst : f32) outs(%arg : tensor<?xf32>) -> tensor<?xf32>
+    %1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+    "test.some_use"(%1) : (index) -> ()
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %filled into %arg[0][%size][1] : tensor<?xf32> into tensor<?xf32>
+    }
+  }
+  %2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+  "test.some_use"(%2) : (index) -> ()
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @scf_if_constant(
 func.func @scf_if_constant(%c : i1) {
   // CHECK: arith.constant 4 : index