25 changes: 25 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
@@ -111,6 +111,31 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);

// Creates IntegerAttrs for clamping, using provided min/max values or the
// numeric limits of the element type if the values are not provided.
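// For example, for an i8 element type with neither bound provided, the attrs
// default to -128 and 127.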
LogicalResult getIntegerClampAttrs(ConversionPatternRewriter &rewriter,
Operation *op, Type elemTy,
std::optional<int64_t> minInt,
std::optional<int64_t> maxInt,
IntegerAttr &minAttr, IntegerAttr &maxAttr);

// Creates FloatAttrs for clamping, using provided min/max values or the numeric
// limits of the element type if the values are not provided.
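// For example, for an f32 element type with neither bound provided, the attrs
// default to std::numeric_limits<float>::lowest() and
// std::numeric_limits<float>::max().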
LogicalResult getFloatClampAttrs(ConversionPatternRewriter &rewriter,
Operation *op, Type elemTy,
std::optional<double> minFloat,
std::optional<double> maxFloat,
FloatAttr &minAttr, FloatAttr &maxAttr);

// Implements "round half to even" logic for aten.round using TOSA ops.
// if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
// res = floor(input)
// else:
// res = ceil(input)
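// e.g. 0.5 -> 0.0, 1.5 -> 2.0, 2.5 -> 2.0 (ties round to the even neighbor).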
std::optional<Value> createRoundHalfToEven(ConversionPatternRewriter &rewriter,
Operation *op, Value input,
RankedTensorType resultTy);

} // namespace tosa
} // namespace mlir

264 changes: 119 additions & 145 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -5304,69 +5304,45 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
auto outElemTy = outType.getElementType();

int64_t minInt, maxInt;
double minFloat, maxFloat;
bool isMinNotNone = false;
bool isMaxNotNone = false;

auto isMinInt = matchPattern(op.getMin(), m_TorchConstantInt(&minInt));
auto isMinFloat = matchPattern(op.getMin(), m_TorchConstantFloat(&minFloat));
if (isMinInt) {
minFloat = static_cast<float>(minInt);
isMinNotNone = true;
} else if (isMinFloat) {
minInt = static_cast<int64_t>(minFloat);
isMinNotNone = true;
} else {
if (succeeded(checkNotNone(rewriter, op, op.getMin())))
std::optional<int64_t> minInt;
std::optional<double> minFloat;
{
int64_t minIntVal;
double minFloatVal;
if (matchPattern(op.getMin(), m_TorchConstantInt(&minIntVal))) {
minInt = minIntVal;
minFloat = static_cast<double>(minIntVal);
} else if (matchPattern(op.getMin(), m_TorchConstantFloat(&minFloatVal))) {
minFloat = minFloatVal;
minInt = static_cast<int64_t>(minFloatVal);
} else if (succeeded(checkNotNone(rewriter, op, op.getMin()))) {
return rewriter.notifyMatchFailure(op,
"min attr should be a torch constant");
}
}

auto isMaxInt = matchPattern(op.getMax(), m_TorchConstantInt(&maxInt));
auto isMaxFloat = matchPattern(op.getMax(), m_TorchConstantFloat(&maxFloat));
if (isMaxInt) {
maxFloat = static_cast<float>(maxInt);
isMaxNotNone = true;
} else if (isMaxFloat) {
maxInt = static_cast<int64_t>(maxFloat);
isMaxNotNone = true;
} else {
if (succeeded(checkNotNone(rewriter, op, op.getMax())))
std::optional<int64_t> maxInt;
std::optional<double> maxFloat;
{
int64_t maxIntVal;
double maxFloatVal;
if (matchPattern(op.getMax(), m_TorchConstantInt(&maxIntVal))) {
maxInt = maxIntVal;
maxFloat = static_cast<double>(maxIntVal);
} else if (matchPattern(op.getMax(), m_TorchConstantFloat(&maxFloatVal))) {
maxFloat = maxFloatVal;
maxInt = static_cast<int64_t>(maxFloatVal);
} else if (succeeded(checkNotNone(rewriter, op, op.getMax()))) {
return rewriter.notifyMatchFailure(op,
"max attr should be a torch constant");
}
}

if (!isa<mlir::FloatType>(outElemTy)) {
IntegerAttr minIntAttr, maxIntAttr;
if (outElemTy.isInteger(8)) {
minIntAttr = rewriter.getIntegerAttr(
outElemTy,
isMinNotNone ? minInt : std::numeric_limits<int8_t>::min());
maxIntAttr = rewriter.getIntegerAttr(
outElemTy,
isMaxNotNone ? maxInt : std::numeric_limits<int8_t>::max());
} else if (outElemTy.isInteger(16)) {
minIntAttr = rewriter.getIntegerAttr(
outElemTy,
isMinNotNone ? minInt : std::numeric_limits<int16_t>::min());
maxIntAttr = rewriter.getIntegerAttr(
outElemTy,
isMaxNotNone ? maxInt : std::numeric_limits<int16_t>::max());
} else if (outElemTy.isInteger(32)) {
minIntAttr = rewriter.getIntegerAttr(
outElemTy,
isMinNotNone ? minInt : std::numeric_limits<int32_t>::min());
maxIntAttr = rewriter.getIntegerAttr(
outElemTy,
isMaxNotNone ? maxInt : std::numeric_limits<int32_t>::max());
} else if (outElemTy.isInteger(64)) {
minIntAttr = rewriter.getI64IntegerAttr(
isMinNotNone ? minInt : std::numeric_limits<int64_t>::min());
maxIntAttr = rewriter.getI64IntegerAttr(
isMaxNotNone ? maxInt : std::numeric_limits<int64_t>::max());
} else {
return rewriter.notifyMatchFailure(op, "Unsupported integer type");
if (failed(tosa::getIntegerClampAttrs(rewriter, op, outElemTy, minInt,
maxInt, minIntAttr, maxIntAttr))) {
return failure();
}

rewriter.replaceOpWithNewOp<tosa::ClampOp>(
@@ -5376,28 +5352,10 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
tosa::NanPropagationMode::PROPAGATE));
} else {
FloatAttr minFloatAttr, maxFloatAttr;
if (outElemTy.isF16()) {
minFloatAttr =
rewriter.getF16FloatAttr(isMinNotNone ? minFloat : Float16Lowest);
maxFloatAttr =
rewriter.getF16FloatAttr(isMaxNotNone ? maxFloat : Float16Max);
} else if (outElemTy.isBF16()) {
minFloatAttr = rewriter.getFloatAttr(
rewriter.getBF16Type(), isMinNotNone ? minFloat : BFloat16Lowest);
maxFloatAttr = rewriter.getFloatAttr(
rewriter.getBF16Type(), isMaxNotNone ? maxFloat : BFloat16Max);
} else if (outElemTy.isF32()) {
minFloatAttr = rewriter.getF32FloatAttr(
isMinNotNone ? minFloat : std::numeric_limits<float>::lowest());
maxFloatAttr = rewriter.getF32FloatAttr(
isMaxNotNone ? maxFloat : std::numeric_limits<float>::max());
} else if (outElemTy.isF64()) {
minFloatAttr = rewriter.getF64FloatAttr(
isMinNotNone ? minFloat : std::numeric_limits<double>::lowest());
maxFloatAttr = rewriter.getF64FloatAttr(
isMaxNotNone ? maxFloat : std::numeric_limits<double>::max());
} else {
return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
if (failed(tosa::getFloatClampAttrs(rewriter, op, outElemTy, minFloat,
maxFloat, minFloatAttr,
maxFloatAttr))) {
return failure();
}

rewriter.replaceOpWithNewOp<tosa::ClampOp>(
@@ -7308,17 +7266,6 @@ template <>
LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
AtenRoundOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// To round to the nearest integer, we will consider the fractional part of
// the input element (= input element - integer part of element). If the
// fractional part is smaller than 0.5, round the number down. If the
// fractional part is 0.5, apply "round half to even" rule. If the fractional
// part is greater than 0.5, round up.
//
// if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
// res = floor(input)
// else:
// res = ceil(input)

auto self = adaptor.getSelf();

auto selfTy = dyn_cast<TensorType>(self.getType());
@@ -7328,67 +7275,13 @@ LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
auto resultTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));

auto boolTy =
RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1));

auto resultElemTy = resultTy.getElementType();

auto oneHalf =
tosa::getConstTensor<float>(rewriter, op, 0.5, {}, resultElemTy).value();

auto two =
tosa::getConstTensor<float>(rewriter, op, 2, {}, resultElemTy).value();

if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, oneHalf)
.failed() ||
mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, two).failed())
auto result = tosa::createRoundHalfToEven(rewriter, op, self, resultTy);
if (!result) {
return rewriter.notifyMatchFailure(
op, "Failed to equalize ranks among operands and result");

auto floorInput =
tosa::FloorOp::create(rewriter, op->getLoc(), resultTy, self);

// input - floor(input)
auto fractionalPart = tosa::SubOp::create(rewriter, op->getLoc(), resultTy,
self, floorInput.getResult());

auto ceilInput = tosa::CeilOp::create(rewriter, op->getLoc(), resultTy, self);

auto floorInputDivByTwo = tosa::createMulOpAndCast(
rewriter, op, resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);

auto floorDivResult = tosa::FloorOp::create(rewriter, op->getLoc(), resultTy,
floorInputDivByTwo.getResult());

// (floor(input) // 2) * 2
auto evenComparison = tosa::createMulOpAndCast(
rewriter, op, resultTy, floorDivResult.getResult(), two, /*shift=*/0);

// (floor(input) // 2) * 2 == floor(input) <=> floor(input) % 2 == 0
auto floorInputEven =
tosa::EqualOp::create(rewriter, op->getLoc(), boolTy,
floorInput.getResult(), evenComparison.getResult());

auto fracEqualOneHalf = tosa::EqualOp::create(
rewriter, op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf);

auto fracLtOneHalf = tosa::GreaterOp::create(
rewriter, op->getLoc(), boolTy, oneHalf, fractionalPart.getResult());

// (frac == 0.5) && (floor(input) % 2 == 0)
auto fracEqualOneHalfCond = tosa::LogicalAndOp::create(
rewriter, op->getLoc(), boolTy, fracEqualOneHalf.getResult(),
floorInputEven.getResult());

// (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
auto floorResultCond = tosa::LogicalOrOp::create(
rewriter, op->getLoc(), boolTy, fracLtOneHalf.getResult(),
fracEqualOneHalfCond.getResult());

rewriter.replaceOpWithNewOp<tosa::SelectOp>(
op, resultTy, floorResultCond.getResult(), floorInput.getResult(),
ceilInput.getResult());
op, "failed to implement round-half-to-even with TOSA ops");
}

rewriter.replaceOp(op, *result);
return success();
}

@@ -9339,6 +9232,86 @@ LogicalResult ConvertAtenOp<AtenDequantizeTensorOp>::matchAndRewrite(
return success();
}

// Legalization for aten.quantize_per_tensor
// Implements
// Q = clamp(round(X / scale) + zero_point)
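// e.g. with scale = 0.05, zero_point = 10, and an i8 result:
// X = 1.337 -> round(1.337 / 0.05) = 27; 27 + 10 = 37; clamp to [-128, 127] -> 37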
template <>
LogicalResult ConvertAtenOp<AtenQuantizePerTensorOp>::matchAndRewrite(
AtenQuantizePerTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto loc = op->getLoc();

// Get scale and zero_point as constants.
double scaleConst;
if (!matchPattern(op.getScale(), m_TorchConstantFloat(&scaleConst)))
return rewriter.notifyMatchFailure(op, "scale must be a Scalar constant");

int64_t zpConst;
if (!matchPattern(op.getZeroPoint(), m_TorchConstantInt(&zpConst)))
return rewriter.notifyMatchFailure(op,
"zero point must be a Scalar constant");

// Get input and result types.
auto inputTy = cast<RankedTensorType>(input.getType());
auto inputElemTy = inputTy.getElementType();
auto resultTy = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
auto resultElemTy = resultTy.getElementType();

// Rescale the input: input * (1.0 / scale)
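// scale is a compile-time constant here, so the reciprocal folds on the host
// and a single tosa.mul suffices (no tosa.reciprocal op is needed).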
auto scaleReciprocal = 1.0 / scaleConst;
auto scaleConstTensor = tosa::getConstTensor<float>(
rewriter, op, scaleReciprocal, {}, inputElemTy)
.value();
if (mlir::tosa::EqualizeRanks(rewriter, loc, input, scaleConstTensor)
.failed())
return rewriter.notifyMatchFailure(
op, "Failed to equalize ranks among operands");
Value rescaledInput = tosa::createMulOpAndCast(
rewriter, op, inputTy, input, scaleConstTensor, /*shift=*/0);

// Round
auto rounded =
tosa::createRoundHalfToEven(rewriter, op, rescaledInput, inputTy);
if (!rounded) {
return rewriter.notifyMatchFailure(
op, "failed to implement round-half-to-even with TOSA ops");
}

// Cast to the destination integer type.
auto intermediateIntTy = resultTy.clone(resultElemTy);
Value castToInt =
tosa::CastOp::create(rewriter, loc, intermediateIntTy, *rounded);

// Add the zero point.
Value zpTensor =
tosa::createZeroPointTensor(rewriter, loc, intermediateIntTy, zpConst)
.value();
if (mlir::tosa::EqualizeRanks(rewriter, loc, castToInt, zpTensor).failed())
return failure();
Value withZp = tosa::AddOp::create(rewriter, loc, intermediateIntTy,
castToInt, zpTensor);

// Clamp the result to the valid range of the quantized type.
// Leave the bounds unset so the helper clamps to the numeric limits of the
// result element type.
std::optional<int64_t> minInt, maxInt;
IntegerAttr minIntAttr, maxIntAttr;
if (failed(tosa::getIntegerClampAttrs(rewriter, op, resultElemTy, minInt,
maxInt, minIntAttr, maxIntAttr))) {
return failure();
}
Value clamped = tosa::ClampOp::create(
rewriter, loc, resultTy, withZp, minIntAttr, maxIntAttr,
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(rewriter.getContext(),
tosa::NanPropagationMode::PROPAGATE));

rewriter.replaceOp(op, clamped);
return success();
}

} // namespace

// -----------------------------------------------------------------------------
@@ -9713,6 +9686,7 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
INSERT_ATENOP_PATTERN(AtenTanOp);
INSERT_ATENOP_PATTERN(AtenUnfoldOp);
INSERT_ATENOP_PATTERN(AtenDequantizeTensorOp);
INSERT_ATENOP_PATTERN(AtenQuantizePerTensorOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \