Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[submodule "externals/llvm-project"]
path = externals/llvm-project
url = https://github.com/llvm/llvm-project.git
url = https://github.com/iree-org/llvm-project.git
[submodule "externals/stablehlo"]
path = externals/stablehlo
url = https://github.com/openxla/stablehlo.git
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 23991 files
2 changes: 1 addition & 1 deletion externals/stablehlo
Submodule stablehlo updated 83 files
+32 −7 .github/actions/setup-build/action.yml
+1 −1 .github/workflows/buildAndTestCMake.yml
+1 −0 .gitignore
+208 −1 BUILD.bazel
+1 −6 MODULE.bazel
+42 −216 MODULE.bazel.lock
+7 −5 README.md
+9 −2 WORKSPACE.bazel
+3 −0 build_tools/github_actions/ci_build_docs.sh
+1 −1 build_tools/github_actions/lint_version.sh
+1 −1 build_tools/llvm_version.txt
+8 −0 docs/_toc.yaml
+444 −0 docs/generated/ChloBuilder.md
+44 −0 docs/generated/FuncBuilder.md
+1,047 −0 docs/generated/StablehloBuilder.md
+3 −2 docs/generated/stablehlo_linalg_passes.md
+3 −3 docs/generated/stablehlo_optimization_passes.md
+1 −0 docs/images/spec/broadcast_in_dim.svg
+93 −24 docs/spec.md
+1 −1 docs/tutorials/jax-export.ipynb
+31 −0 stablehlo/conversions/linalg/tests/pointwise.mlir
+4 −5 stablehlo/conversions/linalg/transforms/LegalizeToLinalgUtils.cpp
+44 −47 stablehlo/conversions/linalg/transforms/MapStablehloToScalarOp.h
+5 −1 stablehlo/conversions/linalg/transforms/Passes.td
+19 −17 stablehlo/conversions/linalg/transforms/Rewriters.h
+6 −3 stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
+10 −9 stablehlo/conversions/linalg/transforms/StablehloToArith.cpp
+39 −24 stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp
+24 −24 stablehlo/conversions/tosa/tests/legalize_quant_ops_to_tosa_rescale.mlir
+1 −1 stablehlo/conversions/tosa/tests/legalize_tosa_rescale_to_stablehlo.mlir
+8 −0 stablehlo/conversions/tosa/tests/unary.mlir
+35 −0 stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp
+50 −48 stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp
+1 −1 stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp
+30 −0 stablehlo/dialect/Base.cpp
+5 −0 stablehlo/dialect/Base.h
+24 −18 stablehlo/dialect/Base.td
+1 −1 stablehlo/dialect/StablehloAttrs.td
+145 −6 stablehlo/dialect/StablehloOps.cpp
+13 −0 stablehlo/dialect/StablehloOps.h
+6 −5 stablehlo/dialect/StablehloOps.td
+2 −2 stablehlo/dialect/Version.cpp
+1 −1 stablehlo/dialect/Version.h
+158 −126 stablehlo/dialect/VhloBytecode.cpp
+1 −0 stablehlo/dialect/VhloDialect.td
+11 −0 stablehlo/dialect/VhloTypes.cpp
+17 −0 stablehlo/dialect/VhloTypes.td
+1 −0 stablehlo/integrations/CMakeLists.txt
+15 −0 stablehlo/integrations/cpp/CMakeLists.txt
+177 −0 stablehlo/integrations/cpp/builder/AttrTypeBuilderUtil.cpp
+295 −0 stablehlo/integrations/cpp/builder/AttrTypeBuilderUtil.h
+282 −0 stablehlo/integrations/cpp/builder/AttrTypeBuilderUtilTest.cpp
+146 −0 stablehlo/integrations/cpp/builder/CMakeLists.txt
+35 −0 stablehlo/integrations/cpp/builder/ChloBuilder.cpp
+37 −0 stablehlo/integrations/cpp/builder/ChloBuilder.h
+94 −0 stablehlo/integrations/cpp/builder/FuncBuilder.cpp
+85 −0 stablehlo/integrations/cpp/builder/FuncBuilder.h
+57 −0 stablehlo/integrations/cpp/builder/MlirBuilder.cpp
+271 −0 stablehlo/integrations/cpp/builder/MlirBuilder.h
+45 −0 stablehlo/integrations/cpp/builder/MlirBuilderTblgen.cmake
+763 −0 stablehlo/integrations/cpp/builder/MlirBuilderTblgen.cpp
+308 −0 stablehlo/integrations/cpp/builder/MlirBuilderTest.cpp
+245 −0 stablehlo/integrations/cpp/builder/README.md
+159 −0 stablehlo/integrations/cpp/builder/StablehloBuilder.cpp
+82 −0 stablehlo/integrations/cpp/builder/StablehloBuilder.h
+1,596 −0 stablehlo/integrations/cpp/builder/StablehloBuilderTest.cpp
+17 −3 stablehlo/integrations/python/tests/testdata_generator_test.py
+24 −0 stablehlo/tests/chlo/chlo_legalize_to_stablehlo_broadcast.mlir
+1 −1 stablehlo/tests/interpret/probe.mlir
+181 −0 stablehlo/tests/ops_stablehlo.mlir
+743 −46 stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
+39 −31 stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
+31 −0 stablehlo/tests/transforms/stablehlo_target_independent_optimization.mlir
+3,005 −0 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_13_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_13_0.mlir.bc
+27 −0 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir
+7 −0 stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_12_0.mlir
+156 −77 stablehlo/transforms/ChloLegalizeToStablehlo.cpp
+4 −0 stablehlo/transforms/VhloToVersion.cpp
+3 −3 stablehlo/transforms/optimization/Passes.td
+560 −145 stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
+96 −82 stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
+7 −4 stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
5 changes: 3 additions & 2 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H

#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
Expand All @@ -26,8 +27,8 @@ namespace tosa {
// rounding mode
Value buildRescale(PatternRewriter &rewriter, Operation *op,
ShapedType output_type, Value input_val, double scale,
int64_t input_zp, int64_t output_zp, StringRef rounding_mode,
bool scale32);
int64_t input_zp, int64_t output_zp,
tosa::RoundingMode rounding_mode, bool scale32);

// Creates TOSA rescale op with int32 output
Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
Expand Down
69 changes: 48 additions & 21 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ class ConvertAtenBinaryOp : public OpConversionPattern<AtenOpT> {
// tosa.minimum
binaryOp = rewriter.create<TosaOpT>(
op->getLoc(), outTy, lhs, rhs,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
} else {
binaryOp =
tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
Expand Down Expand Up @@ -907,7 +909,9 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
// Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, outTy, self, minFloatAttr, maxFloatAttr,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(rewriter.getContext(),
tosa::NanPropagationMode::PROPAGATE));
return success();
}

Expand Down Expand Up @@ -1237,7 +1241,9 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
.create<tosa::ArgMaxOp>(
op->getLoc(), getTypeConverter()->convertType(outputReduceTy),
input, reduceDimAttr,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"))
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE))
.getResult();
};

Expand Down Expand Up @@ -3925,7 +3931,9 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern<AtenOpT> {
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(reducedShape),
selfElemType),
self, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
self, dimAttr, /*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
} else {
reduceOp = rewriter.create<TosaOpT>(
op->getLoc(),
Expand All @@ -3946,14 +3954,18 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern<AtenOpT> {
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
indicesElemType),
negateOp, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
negateOp, dimAttr, /*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
} else {
// Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
argMaxOp = rewriter.create<tosa::ArgMaxOp>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
indicesElemType),
self, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
self, dimAttr, /*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
}

if (argMaxOp.getType() != indicesType) {
Expand Down Expand Up @@ -5202,7 +5214,9 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(

rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, outType, adaptor.getSelf(), minIntAttr, maxIntAttr,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(rewriter.getContext(),
tosa::NanPropagationMode::PROPAGATE));
} else {
FloatAttr minFloatAttr, maxFloatAttr;
if (outElemTy.isF16()) {
Expand Down Expand Up @@ -5231,7 +5245,9 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(

rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, outType, adaptor.getSelf(), minFloatAttr, maxFloatAttr,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(rewriter.getContext(),
tosa::NanPropagationMode::PROPAGATE));
}

return success();
Expand Down Expand Up @@ -5340,13 +5356,17 @@ LogicalResult ConvertAtenOp<AtenClampTensorOp>::matchAndRewrite(
// Use default NaN Propagation mode "PROPAGATE" for tosa.maximum
auto minThresholdCheck = rewriter.create<tosa::MaximumOp>(
op->getLoc(), resultType, self, min,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(rewriter.getContext(),
tosa::NanPropagationMode::PROPAGATE));

// yi = min(max(xi, min_valuei), max_valuei)
// Use default NaN Propagation mode "PROPAGATE" for tosa.minimum
auto result = rewriter.create<tosa::MinimumOp>(
op->getLoc(), resultType, minThresholdCheck, max,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(rewriter.getContext(),
tosa::NanPropagationMode::PROPAGATE));

rewriter.replaceOp(op, result);
return success();
Expand Down Expand Up @@ -5934,7 +5954,10 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
pooledOutput = rewriter
.create<TosaOpT>(
op->getLoc(), outputTy, input, kernel, stride, pad,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"))
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(),
tosa::NanPropagationMode::PROPAGATE))
.getResult();
} else if constexpr (std::is_same<TosaOpT, tosa::AvgPool2dOp>::value) {
TypeAttr accType;
Expand Down Expand Up @@ -6830,11 +6853,11 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "Only nearest and bilinear interpolation modes supported");

std::string mode;
tosa::ResizeMode mode;
if (pyMode == "bilinear") {
mode = "BILINEAR";
mode = tosa::ResizeMode::BILINEAR;
} else {
mode = "NEAREST_NEIGHBOR";
mode = tosa::ResizeMode::NEAREST_NEIGHBOR;
}

bool alignCorners;
Expand Down Expand Up @@ -6896,7 +6919,7 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
offset = 0;

// If nearest neighbours we need to guarantee we round up.
if (mode == "NEAREST_NEIGHBOR" && alignCorners) {
if (mode == tosa::ResizeMode::NEAREST_NEIGHBOR && alignCorners) {
offset += n / 2;
}

Expand All @@ -6916,7 +6939,8 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
tosa::getTosaConstShape(rewriter, op->getLoc(), {offset_y, offset_x});
auto border =
tosa::getTosaConstShape(rewriter, op->getLoc(), {border_y, border_x});
StringAttr modeAttr = rewriter.getStringAttr(mode);

auto modeAttr = tosa::ResizeModeAttr::get(rewriter.getContext(), mode);

auto resizeOpResult =
rewriter
Expand Down Expand Up @@ -8610,11 +8634,14 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
// Clamp input to [eps, 1 - eps] when eps is not None
// Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
if (!isEpsNone) {
zi = rewriter
.create<tosa::ClampOp>(
op->getLoc(), resultType, self, minFloatAttr, maxFloatAttr,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"))
.getResult();
zi =
rewriter
.create<tosa::ClampOp>(
op->getLoc(), resultType, self, minFloatAttr, maxFloatAttr,
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE))
.getResult();
}

auto one =
Expand Down
7 changes: 5 additions & 2 deletions lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
Expand Down Expand Up @@ -764,7 +765,9 @@ std::optional<Value> convertReduceOpCommon(
// and tosa.reduce_max
reduce_op = CreateOpAndInfer<T>(
rewriter, op->getLoc(), reduce_type, val, axis_attr,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
} else {
reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
val, axis_attr);
Expand All @@ -777,7 +780,7 @@ std::optional<Value> convertReduceOpCommon(
RankedTensorType output_rescale_type =
RankedTensorType::get(shape_vec, output_type.getElementType());
val = buildRescale(rewriter, op, output_rescale_type, val, output_scale,
0, output_zp, "SINGLE_ROUND", true);
0, output_zp, tosa::RoundingMode::SINGLE_ROUND, true);
}

// Optionally squeeze out the reduced axes.
Expand Down
17 changes: 11 additions & 6 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ Value buildRescaleMultiplier(bool scale32, PatternRewriter &rewriter,
// rounding mode
Value buildRescale(PatternRewriter &rewriter, Operation *op,
ShapedType output_type, Value input_val, double scale,
int64_t input_zp, int64_t output_zp, StringRef rounding_mode,
bool scale32) {
int64_t input_zp, int64_t output_zp,
tosa::RoundingMode rounding_mode, bool scale32) {
int32_t multiplier;
int32_t shift;

Expand Down Expand Up @@ -70,7 +70,8 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op,
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
rewriter, op->getLoc(), output_type, input_val, multiplier_val, shift_val,
input_zp_val.value(), output_zp_val.value(),
rewriter.getBoolAttr(scale32), rewriter.getStringAttr(rounding_mode),
rewriter.getBoolAttr(scale32),
tosa::RoundingModeAttr::get(rewriter.getContext(), rounding_mode),
rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned),
rewriter.getBoolAttr(output_unsigned));

Expand All @@ -87,7 +88,7 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
auto output_type = input_type.clone(rewriter.getI32Type());

return buildRescale(rewriter, op, output_type, input_val, input_scale,
input_zp, 0, "SINGLE_ROUND", true);
input_zp, 0, tosa::RoundingMode::SINGLE_ROUND, true);
}

// Creates a TOSA rescale op based on conv2d parameters.
Expand Down Expand Up @@ -146,7 +147,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
rewriter, op->getLoc(), output_type, conv_val, multiplier_val,
shift_val, input_zp_val.value(), output_zp_val.value(),
rewriter.getBoolAttr(scale32), rewriter.getStringAttr("DOUBLE_ROUND"),
rewriter.getBoolAttr(scale32),
tosa::RoundingModeAttr::get(rewriter.getContext(),
tosa::RoundingMode::DOUBLE_ROUND),
rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned),
rewriter.getBoolAttr(output_unsigned));

Expand Down Expand Up @@ -188,7 +191,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
rewriter, op->getLoc(), output_type, conv_val, multiplier_val,
shift_val, input_zp_val.value(), output_zp_val.value(),
rewriter.getBoolAttr(scale32), rewriter.getStringAttr("DOUBLE_ROUND"),
rewriter.getBoolAttr(scale32),
tosa::RoundingModeAttr::get(rewriter.getContext(),
tosa::RoundingMode::DOUBLE_ROUND),
rewriter.getBoolAttr(true), rewriter.getBoolAttr(input_unsigned),
rewriter.getBoolAttr(output_unsigned));

Expand Down
29 changes: 15 additions & 14 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,21 +373,21 @@ LogicalResult ClassTypeOp::verify() {
// PrimLoopOp
//===----------------------------------------------------------------------===//

OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert(point == getRegion());
OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionSuccessor successor) {
assert(successor.getSuccessor() == &getRegion());
return getIterArgsInit();
}

void PrimLoopOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
Region &region = getRegion();
if (!point.getRegionOrNull()) {
if (!point.getTerminatorPredecessorOrNull()) {
regions.emplace_back(&region, region.getArguments().slice(1));
return;
}
assert(point == region);
assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == &region);
regions.emplace_back(&region, region.getArguments().slice(1));
regions.emplace_back(getResults());
regions.emplace_back(getOperation(), getResults());
}

bool PrimLoopOp::isForLike() {
Expand All @@ -400,7 +400,7 @@ bool PrimLoopOp::isForLike() {
//===----------------------------------------------------------------------===//

MutableOperandRange
PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
PrimLoopConditionOp::getMutableSuccessorOperands(RegionSuccessor successor) {
// Pass all operands except the condition to the successor which is the
// parent loop op.
return getIterArgsMutable();
Expand Down Expand Up @@ -452,8 +452,8 @@ void PrimIfOp::print(OpAsmPrinter &p) {
void PrimIfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (point.getRegionOrNull()) {
regions.push_back(RegionSuccessor(getResults()));
if (point.getTerminatorPredecessorOrNull()) {
regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}

Expand Down Expand Up @@ -5321,17 +5321,18 @@ template <typename CalculateOp>
static void
getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
if (!point.getRegionOrNull()) {
if (!point.getTerminatorPredecessorOrNull()) {
// First thing the op does is branch into the calculation.
regions.emplace_back(&op.getCalculation());
return;
}
if (point == op.getBody()) {
Region *region = point.getTerminatorPredecessorOrNull()->getParentRegion();
if (region == &op.getBody()) {
// Body returns control to the outer op, passing through results.
regions.emplace_back(op.getResults());
regions.emplace_back(op.getOperation(), op.getResults());
return;
}
assert(point == op.getCalculation());
assert(region == &op.getCalculation());
// Calculation branches to the body.
regions.emplace_back(&op.getBody());
}
Expand All @@ -5355,7 +5356,7 @@ void DtypeCalculateOp::getSuccessorRegions(
//===----------------------------------------------------------------------===//

MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
RegionBranchPoint point) {
RegionSuccessor successor) {
// The shape operands don't get forwarded to the body.
// MutableOperandRange always has an owning operation, even if empty, so
// create a 0-length range.
Expand Down Expand Up @@ -5846,7 +5847,7 @@ LogicalResult AtenKthvalueOp::verify() {
//===----------------------------------------------------------------------===//

MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
RegionBranchPoint point) {
RegionSuccessor successor) {
// The dtype operands don't get forwarded to the body.
// MutableOperandRange always has an owning operation, even if empty, so
// create a 0-length range.
Expand Down
3 changes: 2 additions & 1 deletion lib/RefBackend/RefBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,8 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
auto func = getOperation();
auto *context = &getContext();
RewritePatternSet patterns(context);
populateExpandTanhPattern(patterns);
math::populateExpansionPatterns(patterns,
{math::TanhOp::getOperationName()});
patterns.add<math::ErfPolynomialApproximation>(patterns.getContext());
ConversionTarget target(*context);
target.addLegalDialect<func::FuncDialect>();
Expand Down
Loading
Loading