76 changes: 45 additions & 31 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -398,36 +398,43 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     Value broadcastBias =
         linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
 
+    bool localBound = false;
+    if (auto localBoundAttr =
+            op->template getAttrOfType<BoolAttr>("local_bound"))
+      localBound = localBoundAttr.getValue();
+
     if (hasZp) {
       auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
       auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
 
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
 
-      Value conv =
-          rewriter
-              .create<LinalgConvQOp>(
-                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
-                  ValueRange{broadcastBias}, strideAttr, dilationAttr)
-              ->getResult(0);
+      auto conv = rewriter.create<LinalgConvQOp>(
+          loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
+          ValueRange{broadcastBias}, strideAttr, dilationAttr);
+
+      if (localBound)
+        conv->setAttr("tosa.local_bound", rewriter.getBoolAttr(true));
 
-      rewriter.replaceOp(op, conv);
+      rewriter.replaceOp(op, conv->getResult(0));
       return success();
     }
 
-    Value conv = rewriter
-                     .create<LinalgConvOp>(
-                         loc, accTy, ValueRange{input, weight},
-                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
-                     ->getResult(0);
+    auto conv = rewriter.create<LinalgConvOp>(
+        loc, accTy, ValueRange{input, weight}, ValueRange{broadcastBias},
+        strideAttr, dilationAttr);
+    Value convVal = conv.getResult(0);
+
+    if (localBound)
+      conv->setAttr("tosa.local_bound", rewriter.getBoolAttr(true));
 
     // We may need to truncate back to the result type if the accumulator was
     // wider than the result.
     if (resultTy != accTy)
-      conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);
+      convVal = rewriter.create<tosa::CastOp>(loc, resultTy, convVal);
 
-    rewriter.replaceOp(op, conv);
+    rewriter.replaceOp(op, convVal);
     return success();
   }
 };
@@ -551,26 +558,32 @@ class DepthwiseConvConverter
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
+    bool localBound = false;
+    if (auto localBoundAttr = op->getAttrOfType<BoolAttr>("local_bound"))
+      localBound = localBoundAttr.getValue();
+
     if (hasNullZps) {
-      Value conv = rewriter
-                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
-                           loc, linalgConvTy, ValueRange{input, weight},
-                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
-                       .getResult(0);
+      auto conv = rewriter.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
+          loc, linalgConvTy, ValueRange{input, weight}, ValueRange{zeroTensor},
+          strideAttr, dilationAttr);
+      Value convVal = conv.getResult(0);
+
+      if (localBound)
+        conv->setAttr("tosa.local_bound", rewriter.getBoolAttr(true));
 
       // We may need to truncate back to the result type if the accumulator was
       // wider than the result.
       if (accETy != resultETy)
-        conv = rewriter.create<tosa::CastOp>(
+        convVal = rewriter.create<tosa::CastOp>(
             loc,
-            RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
-                                  resultETy),
-            conv);
+            RankedTensorType::get(
+                cast<ShapedType>(convVal.getType()).getShape(), resultETy),
+            convVal);
 
       SmallVector<ReassociationExprs, 4> reassociationMap;
       createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
       Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
-          loc, resultTy, conv, reassociationMap);
+          loc, resultTy, convVal, reassociationMap);
 
       Value result =
           rewriter
@@ -596,16 +609,17 @@ class DepthwiseConvConverter
     IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
     auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
     auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
-    Value conv =
-        rewriter
-            .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
-                loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
-                ValueRange{zeroTensor}, strideAttr, dilationAttr)
-            .getResult(0);
+    auto conv = rewriter.create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
+        loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
+        ValueRange{zeroTensor}, strideAttr, dilationAttr);
+
+    if (localBound)
+      conv->setAttr("tosa.local_bound", rewriter.getBoolAttr(true));
+
     SmallVector<ReassociationExprs, 4> reassociationMap;
     createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
     Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
-        loc, resultTy, conv, reassociationMap);
+        loc, resultTy, conv.getResult(0), reassociationMap);
     Value result = linalgIntBroadcastExtSIAdd(
         rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
     rewriter.replaceOp(op, result);
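Every converter in this file now repeats the same small pattern: read the optional `local_bound` attribute off the source TOSA op and, when it is true, tag the freshly created linalg op with `tosa.local_bound` so the hint survives the lowering. A minimal standalone sketch of that recurring logic — the helper name `forwardLocalBound` is illustrative only and not part of this patch:

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Hypothetical helper (not in the patch): copy a `local_bound = true` hint
// from the source TOSA op onto the op produced by the lowering. A missing
// attribute is treated the same as `local_bound = false`.
static void forwardLocalBound(OpBuilder &builder, Operation *tosaOp,
                              Operation *loweredOp) {
  if (auto localBoundAttr = tosaOp->getAttrOfType<BoolAttr>("local_bound"))
    if (localBoundAttr.getValue())
      loweredOp->setAttr("tosa.local_bound", builder.getBoolAttr(true));
}
```

The renaming is deliberate: on the linalg ops the attribute can only live as a discardable attribute, and discardable attributes conventionally carry the namespace of the dialect that gives them meaning — hence `tosa.local_bound` rather than the bare `local_bound` used on the TOSA ops.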
10 changes: 8 additions & 2 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -69,12 +69,15 @@ class TransposeConvNonStridedConverter
     auto reverse2 = rewriter.create<tosa::ReverseOp>(
         loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));
 
+    bool localBound = false;
+    if (auto localBoundAttr = op->getAttrOfType<BoolAttr>("local_bound"))
+      localBound = localBoundAttr.getValue();
     Value conv2d = rewriter.create<tosa::Conv2DOp>(
         loc, resultTy, input, reverse2, bias, op.getInputZp(), op.getWeightZp(),
         rewriter.getDenseI64ArrayAttr(convPad),
         rewriter.getDenseI64ArrayAttr(stride),
         rewriter.getDenseI64ArrayAttr({1, 1}),
-        /* acc_type = */ op.getAccType());
+        /* acc_type = */ op.getAccType(), localBound);
 
     rewriter.replaceOp(op, conv2d);
     return success();
@@ -238,13 +241,16 @@ class TransposeConvStridedConverter
     }
 
     // Perform the convolution using the zero bias.
+    bool localBound = false;
+    if (auto localBoundAttr = op->getAttrOfType<BoolAttr>("local_bound"))
+      localBound = localBoundAttr.getValue();
     Value conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
                        rewriter, loc, UnrankedTensorType::get(resultETy), input,
                        weight, zeroBias, inputZp.value(), weightZp.value(),
                        /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                        /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                        /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
-                       /* acc_type = */ op.getAccType())
+                       /* acc_type = */ op.getAccType(), localBound)
                        .getResult();
 
     // Factor the resulting width / height.
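Note that the decomposition path forwards the hint through op construction rather than post hoc: `localBound` is passed as an extra trailing argument to the `tosa::Conv2DOp` builder instead of being attached with `setAttr`, as the builder appears to accept it directly. Taken together, the two halves of this patch compose into a pipeline; a hand-sketched view of the non-strided case under that assumption, with tensor types elided for brevity:

```mlir
// Input: a transpose convolution carrying the hint.
%0 = tosa.transpose_conv2d %input, %weight, %bias, %izp, %wzp
       {acc_type = f32, local_bound = true, out_pad = array<i64: 0, 0, 0, 0>,
        stride = array<i64: 1, 1>} : (...)

// After the decomposition: the kernel is reversed along H and W, and the
// hint reappears on the equivalent forward convolution.
%1 = tosa.reverse %weight {axis = 1 : i32} : (...)
%2 = tosa.reverse %1 {axis = 2 : i32} : (...)
%3 = tosa.conv2d %input, %2, %bias, %izp, %wzp
       {acc_type = f32, dilation = array<i64: 1, 1>, local_bound = true,
        pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (...)

// After tosa-to-linalg-named: the hint survives as a discardable attribute.
%4 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
       strides = dense<1> : tensor<2xi64>, tosa.local_bound = true}
       ins(...) outs(...)
```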
66 changes: 66 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1045,3 +1045,69 @@ func.func @max_pool2d_nan_ignore(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32
   %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
   return %0: tensor<1x4x32x62xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @conv2d_local_bound_true
+func.func @conv2d_local_bound_true(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<1xf32>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>, tosa.local_bound = true}
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = true, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_local_bound_false
+func.func @conv2d_local_bound_false(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<1xf32>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = false, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_local_bound_true
+func.func @conv3d_local_bound_true(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<1xf32>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // CHECK: linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>, tosa.local_bound = true}
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = true, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_local_bound_false
+func.func @conv3d_local_bound_false(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<1xf32>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // CHECK: linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = false, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv_local_bound_true
+func.func @depthwise_conv_local_bound_true(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<1xf32>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // CHECK: linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>, tosa.local_bound = true}
+  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, local_bound = true, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv_local_bound_false
+func.func @depthwise_conv_local_bound_false(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<1xf32>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // CHECK: linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, local_bound = false, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32>
+  return
+}
22 changes: 22 additions & 0 deletions mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -181,3 +181,25 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
     (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32>
   "func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_with_local_bound
+func.func @transpose_conv2d_with_local_bound(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x18x19x5xf32> {
+  // CHECK: tosa.conv2d [[ANY:.*]] local_bound = true
+  %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, local_bound = true, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x18x19x5xf32>
+  return %0 : tensor<2x18x19x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_strided_with_local_bound
+
+func.func @transpose_conv2d_strided_with_local_bound(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+  // CHECK: tosa.conv2d [[ANY:.*]] local_bound = true
+  %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, local_bound = true, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x35x47x5xf32>
+  %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32>
+  return %1 : tensor<2x?x?x5xf32>
+}
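A note on exercising the new cases: they are ordinary lit/FileCheck tests and run through the RUN lines already at the top of each file. Assuming those RUN lines take their usual form for these two files, the standalone invocations look roughly like the following (the pipelines and paths shown are the assumed parts):

```
mlir-opt --split-input-file \
    -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" \
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir | FileCheck <same file>

mlir-opt --split-input-file --tosa-optional-decompositions \
    mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir | FileCheck <same file>
```

The `local_bound = false` cases verify behavior by omission: their CHECK lines close the linalg op's attribute dictionary right after `strides`, so the match fails if a `tosa.local_bound` entry leaks into the printed dict.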