Skip to content

Commit fa72e68

Browse files
authored
TosaToTensor: Support reshape on unsigned (#179)
1 parent d7c3454 commit fa72e68

File tree

4 files changed

+30
-6
lines changed

4 files changed

+30
-6
lines changed

mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Pass/Pass.h"
1717

1818
namespace mlir {
19+
class TypeConverter;
1920

2021
#define GEN_PASS_DECL_TOSATOTENSOR
2122
#include "mlir/Conversion/Passes.h.inc"
@@ -24,7 +25,8 @@ namespace tosa {
2425

2526
std::unique_ptr<Pass> createTosaToTensor();
2627

27-
void populateTosaToTensorConversionPatterns(RewritePatternSet *patterns);
28+
void populateTosaToTensorConversionPatterns(TypeConverter &converter,
29+
RewritePatternSet *patterns);
2830

2931
} // namespace tosa
3032
} // namespace mlir

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,11 @@ class ReshapeConverterCollapseExpand
207207
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
208208
ConversionPatternRewriter &rewriter) const final {
209209
ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
210-
ShapedType resultTy = cast<ShapedType>(reshape.getType());
210+
ShapedType resultTy = cast_if_present<ShapedType>(getTypeConverter()->convertType(reshape.getType()));
211+
if (!resultTy) {
212+
return rewriter.notifyMatchFailure(
213+
reshape.getLoc(), "could not convert result type");
214+
}
211215
bool isDynamic = !operandTy.hasStaticShape();
212216

213217
SmallVector<int64_t> intermediateShape;
@@ -218,7 +222,7 @@ class ReshapeConverterCollapseExpand
218222
"the given two shapes");
219223
}
220224
auto intermediateTy = RankedTensorType::get(
221-
intermediateShape, reshape.getType().getElementType());
225+
intermediateShape, resultTy.getElementType());
222226

223227
Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
224228
adaptor.getInput1());
@@ -415,9 +419,9 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
415419
} // namespace
416420

417421
void mlir::tosa::populateTosaToTensorConversionPatterns(
418-
RewritePatternSet *patterns) {
422+
TypeConverter &converter, RewritePatternSet *patterns) {
419423
patterns->add<SliceConverter, PadConverter, ConcatConverter>(
420424
patterns->getContext());
421425

422-
patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
426+
patterns->add<ReshapeConverterCollapseExpand>(converter, patterns->getContext());
423427
}

mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
1414

15+
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
1516
#include "mlir/Dialect/Arith/IR/Arith.h"
1617
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1718
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -21,6 +22,7 @@
2122
#include "mlir/Transforms/DialectConversion.h"
2223
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2324

25+
2426
namespace mlir {
2527
#define GEN_PASS_DEF_TOSATOTENSOR
2628
#include "mlir/Conversion/Passes.h.inc"
@@ -42,7 +44,10 @@ struct TosaToTensor : public impl::TosaToTensorBase<TosaToTensor> {
4244
target.addLegalDialect<arith::ArithDialect>();
4345
target.addLegalDialect<tensor::TensorDialect>();
4446

45-
mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
47+
TypeConverter converter;
48+
mlir::tosa::populateTosaToLinalgTypeConversion(converter);
49+
50+
mlir::tosa::populateTosaToTensorConversionPatterns(converter, &patterns);
4651

4752
if (failed(applyPartialConversion(getOperation(), target,
4853
std::move(patterns))))

mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,19 @@ func.func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
5656

5757
// -----
5858

59+
// CHECK-LABEL: @test_reshape_samerank_unsigned
60+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xui8>)
61+
func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3xui8> {
62+
// CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x2xui8> to tensor<3x2xi8>
63+
// CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8>
64+
// CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] : tensor<6xi8> into tensor<2x3xi8>
65+
// CHECK-NEXT: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8>
66+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xui8>) -> tensor<2x3xui8>
67+
// CHECK-NEXT: return %[[CAST2]]
68+
return %0 : tensor<2x3xui8>
69+
}
70+
// -----
71+
5972
// CHECK-LABEL: @test_reshape_samerank_dyn
6073
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
6174
func.func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {

0 commit comments

Comments
 (0)