Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType,
IntegerAttr quantBits, int filterQuantDim,
bool isSigned, BoolAttr narrowRange);

/// Returns the storage type of `quantizedType`. When the quantized type is not
/// signed, the result is an explicitly unsigned integer type with the same bit
/// width as the storage type, because the storage type by itself does not
/// carry the sign information.
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType);

} // namespace tosa
} // namespace mlir

Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -539,7 +540,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
inputEType = quantType.getStorageType();
inputEType = getStorageElementTypeFromQuantized(quantType);
}

Attribute newMinValAttr, newMaxValAttr;
Expand Down
20 changes: 10 additions & 10 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
static Type getStorageElementTypeOrSelf(Type type) {
auto srcType = getElementTypeOrSelf(type);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
srcType = quantType.getStorageType();
srcType = getStorageElementTypeFromQuantized(quantType);
return srcType;
}

Expand Down Expand Up @@ -631,16 +631,16 @@ static LogicalResult verifyConvOp(T op) {
bool resultIsFloat = llvm::isa<FloatType>(resultEType);

if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
inputEType = quantType.getStorageType();
inputEType = getStorageElementTypeFromQuantized(quantType);

if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
weightEType = quantType.getStorageType();
weightEType = getStorageElementTypeFromQuantized(quantType);

if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
biasEType = quantType.getStorageType();
biasEType = getStorageElementTypeFromQuantized(quantType);

if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
resultEType = quantType.getStorageType();
resultEType = getStorageElementTypeFromQuantized(quantType);

if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
// for now, only enforce bias element type == result element type for
Expand Down Expand Up @@ -709,7 +709,7 @@ LogicalResult tosa::ConstOp::verify() {

if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
outputType.getElementType())) {
if (result.getStorageType() == attrType.getElementType())
if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
return success();
}

Expand All @@ -727,7 +727,7 @@ static LogicalResult verifyConvOpModes(T op) {
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
inputEType = quantType.getStorageType();
inputEType = getStorageElementTypeFromQuantized(quantType);

auto accType = op.getAccType();
if (inputEType.isInteger(8) && !accType.isInteger(32))
Expand All @@ -752,7 +752,7 @@ static LogicalResult verifyConvOpModes(T op) {
llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
resultEType = quantType.getStorageType();
resultEType = getStorageElementTypeFromQuantized(quantType);

return success();
}
Expand Down Expand Up @@ -1179,13 +1179,13 @@ LogicalResult tosa::ClampOp::verify() {
llvm::cast<ShapedType>(getInput().getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
inputETy = quantType.getStorageType();
inputETy = getStorageElementTypeFromQuantized(quantType);
}
mlir::Type outputETy =
llvm::cast<ShapedType>(getOutput().getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
outputETy = quantType.getStorageType();
outputETy = getStorageElementTypeFromQuantized(quantType);
}
if (inputETy != outputETy)
return emitOpError("input/output element types are incompatible.");
Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,16 @@ mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
maxAttr, quantBits, filterQuantDim,
isSigned, narrowRange));
}

/// Maps a quantized type to the element type used for its storage.
///
/// The storage type reported by the quantized type does not encode
/// signedness, so for a non-signed quantized type an explicitly unsigned
/// integer type of the same bit width is built instead.
///
/// \param quantType the quantized type to inspect.
/// \returns the storage type as-is when `quantType` is signed; otherwise an
///          unsigned IntegerType with the storage type's bit width.
Type mlir::tosa::getStorageElementTypeFromQuantized(
    quant::QuantizedType quantType) {
  Type storageTy = quantType.getStorageType();

  // Signed case: the storage type can be used unchanged.
  if (quantType.isSigned())
    return storageTy;

  // Unsigned case: rebuild the integer type with explicit unsigned semantics.
  return IntegerType::get(storageTy.getContext(),
                          storageTy.getIntOrFloatBitWidth(),
                          IntegerType::Unsigned);
}
30 changes: 30 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,36 @@ func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tens
return %1 : tensor<4xi8>
}

// -----

// Two chained clamps on a u8-storage quantized tensor, with ranges [10, 240]
// and [5, 230]. The expected output is a single clamp applied directly to
// %arg0 whose ui8 bounds are the intersection of both ranges: [10, 230].
// CHECK-LABEL: @clamp_twice_with_unsigned_quantized_is_single_clamp
// CHECK: tosa.clamp %arg0 {max_val = 230 : ui8, min_val = 10 : ui8}
func.func @clamp_twice_with_unsigned_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
%0 = tosa.clamp %arg0 {max_val = 240 : ui8, min_val = 10 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
%1 = tosa.clamp %0 {max_val = 230 : ui8, min_val = 5 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
return %1 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
}

// -----

// Two chained clamps on an i8-storage quantized tensor, with ranges [-10, 110]
// and [-5, 120]. The expected output is a single clamp applied directly to
// %arg0 whose i8 bounds are the intersection of both ranges: [-5, 110].
// CHECK-LABEL: @clamp_twice_with_signed_quantized_is_single_clamp
// CHECK: tosa.clamp %arg0 {max_val = 110 : i8, min_val = -5 : i8}
func.func @clamp_twice_with_signed_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) {
%0 = tosa.clamp %arg0 {max_val = 110 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
%1 = tosa.clamp %0 {max_val = 120 : i8, min_val = -5 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
}

// Two chained clamps on an i8-storage quantized tensor whose ranges do not
// overlap ([-10, 50] followed by [60, 120]). The expected output keeps both
// clamps, unchanged and back to back, with the second consuming the first.
// CHECK-LABEL: @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp
// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8}
// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_val = 120 : i8, min_val = 60 : i8}
func.func @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) {
%0 = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
%1 = tosa.clamp %0 {max_val = 120 : i8, min_val = 60 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
}


// -----

// CHECK-LABEL: @concat_fold
Expand Down
7 changes: 7 additions & 0 deletions mlir/test/Dialect/Tosa/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,13 @@ func.func @test_clamp_quantized(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 1.0
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>
}

// -----
// A tosa.clamp on a u8-storage quantized tensor whose min/max attributes are
// typed ui8 (0 and 255). Only the function label is checked, so this test
// presumably just requires the op to be accepted by the tool under test.
// CHECK-LABEL: clamp_quantized_unsigned
func.func @clamp_quantized_unsigned(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
%0 = tosa.clamp %arg0 {max_val = 255 : ui8, min_val = 0 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
}

// -----
// CHECK-LABEL: sigmoid
func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
Expand Down
14 changes: 11 additions & 3 deletions mlir/test/Dialect/Tosa/quant-test.mlir
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
// RUN: mlir-opt --tosa-test-quant-utils %s | FileCheck %s

// -----
// CHECK-LABEL: test_build_qtype
func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
// CHECK-LABEL: test_build_qtype_unsigned
func.func @test_build_qtype_unsigned(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xui8>, %arg2: tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
// CHECK: tosa.negate
%0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
%0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xui8>, tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
return %0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
}

// -----
// tosa.negate on an i8-storage quantized tensor (storage range 1:127) with
// two tensor<1xi8> extra operands. The test expects a tosa.negate op to be
// present in the output.
// CHECK-LABEL: test_build_qtype_signed
func.func @test_build_qtype_signed(%arg0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>> {
// CHECK: tosa.negate
%0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>
return %0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>
}

// -----
// CHECK-LABEL: test_build_mult_and_shift
func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x34x36x16x!quant.uniform<i32:f32, 0.078431375324726104>> {
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/Tosa/verifier.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1222,3 +1222,11 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
}

// -----

// Negative test: a tosa.clamp on a u8-storage quantized tensor whose min/max
// attributes are typed as signed i8. The op is expected to be rejected with
// the diagnostic named in the expected-error annotation below.
func.func @test_clamp_quantized(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
// expected-error@+1 {{'tosa.clamp' op min/max attributes types are incompatible with input/output element types.}}
%0 = tosa.clamp %arg0 {max_val = 127 : i8, min_val = -128 : i8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
}