diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 483057f8ccb9..06c867fe0562 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -3,6 +3,8 @@ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its +// affiliates // //===----------------------------------------------------------------------===// // @@ -44,6 +46,8 @@ std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass( const TosaLayerwiseConstantFoldPassOptions &options); std::unique_ptr<Pass> createTosaInferShapesPass(); std::unique_ptr<Pass> createTosaMakeBroadcastablePass(); +std::unique_ptr<Pass> createSinkInputOpsThroughConcatPass(llvm::raw_ostream &); +std::unique_ptr<Pass> createSinkInputOpsThroughConcatPass(); std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass(); std::unique_ptr<Pass> createTosaOptionalDecompositions(); diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index 52d8fb46fb7e..79e312cdcbcd 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -3,6 +3,7 @@ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its affiliates // //===----------------------------------------------------------------------===// // @@ -142,4 +143,24 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> { }]; } +def SinkInputOpsThroughConcat : Pass<"sink-input-ops-through-concat", "mlir::ModuleOp"> { + let summary = "Sinks the same operation through a Concat operation"; + let description = [{ + Pass that sinks the same operation through a concatenation, simplifying + later optimizations. + + To explain with a picture: + ``` + Replacing with + - Op -\ -\ + - Op ---> Concat ==> ---> Concat -> Op + - Op -/ -/ + ``` + + The pass works greedily (i.e., it sinks whole operator chains) and does not + do any cost-benefit assessment. It is restricted to an explicit allowlist of + TOSA operations feeding the concatenation that are known to be "sinkable". + }]; +} + #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 9c3345b617cc..73aec6dabf49 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,3 +1,5 @@ +# Modifications (c) Copyright 2025 Advanced Micro Devices, Inc.
or its affiliates + add_mlir_dialect_library(MLIRTosaTransforms TosaDecomposeTransposeConv.cpp TosaDecomposeDepthwise.cpp @@ -9,6 +11,7 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaReduceTransposes.cpp TosaTypeConverters.cpp TosaValidation.cpp + SinkInputOpsThroughConcat.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms diff --git a/mlir/lib/Dialect/Tosa/Transforms/SinkInputOpsThroughConcat.cpp b/mlir/lib/Dialect/Tosa/Transforms/SinkInputOpsThroughConcat.cpp new file mode 100644 index 000000000000..b0104bf4825f --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/SinkInputOpsThroughConcat.cpp @@ -0,0 +1,450 @@ +//===- SinkInputOpsThroughConcat.cpp ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// (c) Copyright 2025 Advanced Micro Devices, Inc. or its affiliates +// +//===----------------------------------------------------------------------===// + +// ----------- +// Motivation: +// ----------- + +// Models sometimes contain the same operation multiple times, with the results +// immediately concatenated afterwards. In some cases, it is possible to sink +// these operations through the concat to enable better optimizations +// afterwards. + +// -------------------- +// High-Level Overview: +// -------------------- +// +// Replacing with +// - Op -\ -\ +// - Op ---> Concat ==> ---> Concat -> Op +// - Op -/ -/ +// + +// --------------- +// Overall design: +// --------------- +// +// The pass uses an allowlist of operations that are known to be sinkable. It +// consists of three main pattern matchers: +// 1. A generic matcher (SinkGenericOp): It does nothing except output +// statistics. +// 2. A more specific matcher that can be specialized (SinkSpecificOp): It is +// implemented as a template and accepts extra operation-specific checks but +// always performs the same transformation. +// 3. An operation-specific matcher with an operation-specific transformation +// (currently just for reshape).
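+// +// As a concrete example (schematic IR, mirroring the tosa.abs test in the +// accompanying sink-input-ops-through-concat.mlir), the rewrite turns +// +// %0 = tosa.abs %a : (tensor<1x8x8xf32>) -> tensor<1x8x8xf32> +// %1 = tosa.abs %b : (tensor<1x8x8xf32>) -> tensor<1x8x8xf32> +// %2 = tosa.concat %0, %1 {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// +// into +// +// %0 = tosa.concat %a, %b {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// %1 = tosa.abs %0 : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32>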
+//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Tosa/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/raw_ostream.h" + +#include <memory> +namespace mlir { +namespace tosa { +#define GEN_PASS_DEF_SINKINPUTOPSTHROUGHCONCAT +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" +} // namespace tosa +} // namespace mlir + +using namespace mlir; +using namespace mlir::tosa; + +//===----------------------------------------------------------------------===// +// TOSA Sink Input Ops Through Concat Pass +//===----------------------------------------------------------------------===// + +namespace { +struct SinkGenericOp : public OpRewritePattern<tosa::ConcatOp> { + SinkGenericOp(MLIRContext *context, PatternBenefit benefit, + llvm::StringMap<unsigned> &operationFrequency, + llvm::raw_ostream &os) + : OpRewritePattern<tosa::ConcatOp>(context, benefit), + operationFrequency(operationFrequency), os(os) {} + + LogicalResult matchAndRewrite(tosa::ConcatOp concatOp, + PatternRewriter &rewriter) const override { + + Operation *sample = nullptr; + for (auto val : concatOp->getOperands()) { + if (!val.hasOneUse()) + return rewriter.notifyMatchFailure( + concatOp, "Operands must only be used by this concat."); + + auto *genericOp = val.getDefiningOp(); + if (!genericOp) + return rewriter.notifyMatchFailure( + concatOp, "Requires all operands to be produced by operations"); + + if (!sample) + sample = genericOp; + + if (sample->getName().getIdentifier() != + genericOp->getName().getIdentifier()) + return rewriter.notifyMatchFailure( + concatOp, "Requires all operands to be produced by the same kind of operation"); + } + auto opName = sample->getName().getStringRef(); + auto amount = ++operationFrequency[opName]; + if (amount == 1 || amount % 50 == 0) + os << "SinkInputOpsThroughConcat: Operation " << opName + << " -> Matched amount: " << amount << "\n"; + return rewriter.notifyMatchFailure(concatOp, + "Only for statistics printing"); + } + +private: + llvm::StringMap<unsigned> &operationFrequency; + llvm::raw_ostream &os; +}; + +template <typename OpT> +struct SinkSpecificOp : public OpRewritePattern<tosa::ConcatOp> { + using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ConcatOp concatOp, + PatternRewriter &rewriter) const override { + auto sampleOrError = doGenericChecks(concatOp, rewriter); + if (failed(sampleOrError)) + return sampleOrError; + OpT sample = *sampleOrError; + + const auto extraChecks = doExtraChecks(sample, concatOp, rewriter); + if (failed(extraChecks)) + return extraChecks; + + // Rewriting: first collect all inputs and rearrange them for the new + // concats. To visualize the process: + // + // A -. + // B -- (pOp1) --. + // C -* |- (concatOp) + // D -. |
+ // E -- (pOp2) --* + // F -* + // + // The for loops traverse it in this order: ((A, B, C), (D, E, F)) + // It is assigned in transposed order in the vectors: + // [[A, D], [B, E], [C, F]] + SmallVector<SmallVector<Value>> concatOperands(sample->getNumOperands()); + for (auto val : concatOp.getOperands()) { + auto *producerOp = val.getDefiningOp(); + assert(producerOp != nullptr && + "Producers were already checked to be non-null."); + for (unsigned j = 0; j < producerOp->getNumOperands(); ++j) { + concatOperands[j].emplace_back(producerOp->getOperand(j)); + } + } + + // then create the concats and the replacement op + SmallVector<Value> concatOps; + for (auto &ops : concatOperands) { + auto concatOpReplacement = rewriter.create<tosa::ConcatOp>( + concatOp.getLoc(), ops, concatOp.getAxis()); + concatOps.emplace_back(concatOpReplacement); + } + rewriter.replaceOpWithNewOp<OpT>(concatOp, concatOp->getResultTypes(), + concatOps, sample->getAttrs()); + return success(); + } + +protected: + llvm::FailureOr<OpT> doGenericChecks(tosa::ConcatOp concatOp, + PatternRewriter &rewriter) const { + OpT sample = nullptr; + + for (auto val : concatOp->getOperands()) { + if (!val.hasOneUse()) + return rewriter.notifyMatchFailure( + concatOp, "Operands must only be used by this concat."); + + auto op = val.getDefiningOp<OpT>(); + if (!op) + return rewriter.notifyMatchFailure( + concatOp, Twine("Operand is not a ") + OpT::getOperationName()); + + if (!sample) + sample = op; + + if (!llvm::equal(op->getOperandTypes(), sample->getOperandTypes())) + return rewriter.notifyMatchFailure( + concatOp, "Requires all operand types to be the same"); + + if (llvm::any_of(OpT::getAttributeNames(), [&](const auto &name) { + return sample->getAttr(name) != op->getAttr(name); + })) + return rewriter.notifyMatchFailure( + concatOp, "Requires all operand attributes to be the same"); + } + + if (sample->getNumOperands() == 0) { + return rewriter.notifyMatchFailure( + concatOp, "Requires the producers to have one or more inputs"); + } + + return sample; + } + + virtual LogicalResult doExtraChecks(OpT, tosa::ConcatOp, + PatternRewriter &) const { + return success(); + } +}; + +template <typename OpT> +struct SinkElementwiseBroadcastableOp : public SinkSpecificOp<OpT> { + using SinkSpecificOp<OpT>::SinkSpecificOp; + +protected: + // check that the broadcast happens on an axis other than the concat axis + LogicalResult doExtraChecks(OpT op, tosa::ConcatOp concat, + PatternRewriter &rewriter) const override { + SmallVector<ArrayRef<int64_t>> shapes; + for (auto ty : op->getOperandTypes()) { + // lifetime bound to the underlying ShapedType object + auto tenType = dyn_cast<ShapedType>(ty); + if (tenType && tenType.hasStaticShape()) { + shapes.emplace_back(tenType.getShape()); + } else { + return rewriter.notifyMatchFailure( + concat, "Cannot check for broadcasts on an unshaped or non-static type."); + } + } + assert(shapes.size() == op->getNumOperands() && + "Something went wrong in the loop above."); + + // Check that the extent of the concat axis is one for all operands or + // for none of them; otherwise some operand broadcasts along the concat + // axis and sinking would change semantics.
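+ // For illustration (cf. the broadcasting tests): with concat axis 1, + // operand shapes tensor<1x8x8xf32> and tensor<1x1x8xf32> broadcast along + // the concat axis, so the match is rejected; with concat axis 0, the + // broadcast dimension is untouched and the producers' operands can safely + // be concatenated pairwise.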
+ const auto axis = concat.getAxis(); + const size_t oneDimensions = + llvm::count_if(shapes, [&](const auto &s) { return s[axis] == 1; }); + if (oneDimensions != shapes.size() && oneDimensions != 0) { + return rewriter.notifyMatchFailure( + concat, "Operands broadcast on the same axis as the concat."); + } + return success(); + } +}; + +template <typename OpT> +struct SinkReduceOp : public SinkSpecificOp<OpT> { + using SinkSpecificOp<OpT>::SinkSpecificOp; + +protected: + // check that the concat happens on a different axis than the reduction + LogicalResult doExtraChecks(OpT op, tosa::ConcatOp concat, + PatternRewriter &rewriter) const override { + if (op.getAxis() == concat.getAxis()) + return rewriter.notifyMatchFailure( + concat, "Reduction must not be on the same axis as the concat."); + return success(); + } +}; + +struct SinkMatmulOp : public SinkSpecificOp<tosa::MatMulOp> { + using SinkSpecificOp<tosa::MatMulOp>::SinkSpecificOp; + +protected: + // currently only a concat along the batch axis (0) is supported + LogicalResult doExtraChecks(tosa::MatMulOp op, tosa::ConcatOp concat, + PatternRewriter &rewriter) const override { + if (concat.getAxis() != 0) + return rewriter.notifyMatchFailure( + concat, "MatMul concat on an axis other than the batch axis is not yet supported"); + return success(); + } +}; + +struct SinkReshapeOp : public SinkSpecificOp<tosa::ReshapeOp> { + using SinkSpecificOp<tosa::ReshapeOp>::SinkSpecificOp; + + LogicalResult matchAndRewrite(tosa::ConcatOp concatOp, + PatternRewriter &rewriter) const override { + auto reshapeOrError = doGenericChecks(concatOp, rewriter); + if (failed(reshapeOrError)) + return reshapeOrError; + auto reshape = *reshapeOrError; + + const auto tenType = reshape.getInput1().getType(); + if (!tenType || !tenType.hasStaticShape()) + return rewriter.notifyMatchFailure( + concatOp, "Dynamic shapes for reshapes are not supported."); + const ArrayRef<int64_t> shapeBeforeReshape = tenType.getShape(); + const ArrayRef<int64_t> shapeAfterReshape = reshape.getNewShape(); + if (shapeBeforeReshape.empty()) + return rewriter.notifyMatchFailure( + concatOp, + "Tensors of rank 0 cannot have an independent concat axis."); + + // Approach: Before the rewrite, we have a reshape followed by a concat. + // The concat concatenates on concatAxisAfterReshape. To swap the two, we + // need to calculate a new shape and a new concat axis + // (concatAxisBeforeReshape). For that, we check the product of the reshape + // dimensions up to and including concatAxisAfterReshape and match that + // with the product of the dimensions of the shapeBeforeReshape. + // + // Example: + // 6x1x6 --(reshape)--> 2x3x1x2x3 --(concat)--> 2x3x2x2x3 + // concatAxisAfterReshape is 2, after the rewrite it would be 1: + // 6x1x6 --(concat)--> 6x2x6 --(reshape)--> 2x3x2x2x3 + // + // The axis is not always unique: + // 1x1x6 --(reshape)--> 1x1x1x6 --(concat)--> 2x1x1x6 + // concatAxisAfterReshape is 0, after the rewrite it can be 0 or 1: + // 1x1x6 --(concat)--> 2x1x6 --(reshape)--> 2x1x1x6 + // 1x1x6 --(concat)--> 1x2x6 --(reshape)--> 2x1x1x6 + // + // We also need to take the concat dimension into account: + // 1x4x1 --(reshape)--> 1x4 --(concat)--> 1x8 + // concatAxisAfterReshape is 1, after the rewrite it would be 1 as well: + // 1x4x1 --(concat)--> 1x8x1 --(reshape)--> 1x8
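+ // Worked trace for the first example (6x1x6 --(reshape)--> 2x3x1x2x3, + // concat axis 2): prefixProductAfterReshape = 2*3*1 = 6 and + // sizeOfConcatDim = 1. Scanning 6x1x6 yields the prefix products 6, 6, 36; + // i = 0 matches the product but has extent 6 != 1, while i = 1 matches + // both conditions, so concatAxisBeforeReshape becomes 1.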
+ const uint32_t concatAxisAfterReshape = concatOp.getAxis(); + int64_t prefixProductAfterReshape = 1; + // also count the concat dimension itself + for (size_t i = 0; i <= concatAxisAfterReshape; ++i) { + prefixProductAfterReshape *= shapeAfterReshape[i]; + } + + int64_t prefixProductBeforeReshape = 1; + std::optional<size_t> concatAxisBeforeReshape = std::nullopt; + int64_t sizeOfConcatDim = shapeAfterReshape[concatAxisAfterReshape]; + for (size_t i = 0; i < shapeBeforeReshape.size(); ++i) { + prefixProductBeforeReshape *= shapeBeforeReshape[i]; + if (prefixProductBeforeReshape == prefixProductAfterReshape && + shapeBeforeReshape[i] == sizeOfConcatDim) { + concatAxisBeforeReshape = i; + break; + } + } + + if (!concatAxisBeforeReshape) + return rewriter.notifyMatchFailure( + concatOp, "Sinking reshape not possible. No compatible dimension for " + "concat axis found."); + + SmallVector<Value> concatOperands; + for (auto val : concatOp.getOperands()) { + auto *producerOp = val.getDefiningOp(); + assert(producerOp != nullptr && + "Producers were already checked to be non-null."); + for (auto operand : producerOp->getOperands()) { + concatOperands.emplace_back(operand); + } + } + auto concatNew = rewriter.create<tosa::ConcatOp>( + concatOp.getLoc(), concatOperands, *concatAxisBeforeReshape); + // calculate the new shape for the reshape by taking the old reshape shape + // and widening the concat axis to the concatenated extent + Type concatType = concatNew.getType(); + auto concatShapeT = cast<ShapedType>(concatType); + assert(concatShapeT.hasStaticShape() && + "the inputs are static, so the concat must be static"); + auto concatShape = concatShapeT.getShape(); + + SmallVector<int64_t> reshapeShape(shapeAfterReshape.begin(), + shapeAfterReshape.end()); + reshapeShape[concatAxisAfterReshape] = + concatShape[*concatAxisBeforeReshape]; + + auto reshapeNew = rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), + concatNew, reshapeShape); + rewriter.replaceOp(concatOp, reshapeNew); + return success(); + } +}; + +struct SinkInputOpsThroughConcat + : public tosa::impl::SinkInputOpsThroughConcatBase< + SinkInputOpsThroughConcat> { + + SinkInputOpsThroughConcat() = default; + SinkInputOpsThroughConcat(llvm::raw_ostream &os) : os(os) {} + + void runOnOperation() override { + auto moduleOp = getOperation(); + RewritePatternSet patterns(moduleOp.getContext()); + MLIRContext *ctx = moduleOp.getContext(); + + populateSinkInputOpsThroughConcatPattern(patterns, ctx); + + if (applyPatternsGreedily(moduleOp, std::move(patterns)).failed()) + signalPassFailure(); + } + +private: + void populateSinkInputOpsThroughConcatPattern(RewritePatternSet &patterns, + MLIRContext *ctx) { + // elementwise unary + patterns + .add<SinkSpecificOp<tosa::AbsOp>, SinkSpecificOp<tosa::BitwiseNotOp>, + SinkSpecificOp<tosa::CeilOp>, SinkSpecificOp<tosa::ClzOp>, + SinkSpecificOp<tosa::CosOp>, SinkSpecificOp<tosa::ErfOp>, + SinkSpecificOp<tosa::ExpOp>, SinkSpecificOp<tosa::FloorOp>, + SinkSpecificOp<tosa::LogOp>, SinkSpecificOp<tosa::LogicalNotOp>, + SinkSpecificOp<tosa::NegateOp>, SinkSpecificOp<tosa::ReciprocalOp>, + SinkSpecificOp<tosa::RsqrtOp>, SinkSpecificOp<tosa::SigmoidOp>, + SinkSpecificOp<tosa::SinOp>, SinkSpecificOp<tosa::TanhOp>>( + ctx, /*benefit=*/2); + // elementwise binary (broadcastable) + patterns.add<SinkElementwiseBroadcastableOp<tosa::AddOp>, + SinkElementwiseBroadcastableOp<tosa::ArithmeticRightShiftOp>, + SinkElementwiseBroadcastableOp<tosa::BitwiseAndOp>, + SinkElementwiseBroadcastableOp<tosa::BitwiseOrOp>, + SinkElementwiseBroadcastableOp<tosa::BitwiseXorOp>, + SinkElementwiseBroadcastableOp<tosa::EqualOp>, + SinkElementwiseBroadcastableOp<tosa::GreaterOp>, + SinkElementwiseBroadcastableOp<tosa::GreaterEqualOp>, + SinkElementwiseBroadcastableOp<tosa::IntDivOp>, + SinkElementwiseBroadcastableOp<tosa::LogicalAndOp>, + SinkElementwiseBroadcastableOp<tosa::LogicalLeftShiftOp>, + SinkElementwiseBroadcastableOp<tosa::LogicalOrOp>, + SinkElementwiseBroadcastableOp<tosa::LogicalRightShiftOp>, + SinkElementwiseBroadcastableOp<tosa::LogicalXorOp>, + SinkElementwiseBroadcastableOp<tosa::MaximumOp>, + SinkElementwiseBroadcastableOp<tosa::MinimumOp>, + SinkElementwiseBroadcastableOp<tosa::MulOp>, + SinkElementwiseBroadcastableOp<tosa::PowOp>, + SinkElementwiseBroadcastableOp<tosa::SubOp>>(ctx, + /*benefit=*/2); + // reduce + patterns + .add<SinkReduceOp<tosa::ReduceAllOp>, SinkReduceOp<tosa::ReduceAnyOp>, + SinkReduceOp<tosa::ReduceMaxOp>, SinkReduceOp<tosa::ReduceMinOp>, + SinkReduceOp<tosa::ReduceProdOp>, SinkReduceOp<tosa::ReduceSumOp>>( + ctx, /*benefit=*/2); + // others + patterns.add<SinkMatmulOp>(ctx, /*benefit=*/2); + patterns.add<SinkReshapeOp>(ctx, /*benefit=*/2); + patterns.add<SinkGenericOp>(ctx, /*benefit=*/1, operationFrequency, os); + } + + llvm::StringMap<unsigned> operationFrequency; + llvm::raw_ostream &os = llvm::errs(); +}; +} // namespace + +std::unique_ptr<Pass> +mlir::tosa::createSinkInputOpsThroughConcatPass(llvm::raw_ostream &os) { + return std::make_unique<SinkInputOpsThroughConcat>(os); +} + +std::unique_ptr<Pass> mlir::tosa::createSinkInputOpsThroughConcatPass() { + return std::make_unique<SinkInputOpsThroughConcat>(); +} diff --git a/mlir/test/Dialect/Tosa/sink-input-ops-through-concat.mlir b/mlir/test/Dialect/Tosa/sink-input-ops-through-concat.mlir new file mode 100644 index 000000000000..d422bf3193e8 --- /dev/null +++ 
b/mlir/test/Dialect/Tosa/sink-input-ops-through-concat.mlir @@ -0,0 +1,1086 @@ +// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its affiliates +// RUN: mlir-opt --split-input-file --sink-input-ops-through-concat %s | FileCheck %s + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<3x8x8xf32> +func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.add %arg0, %arg1 : (!in_type, !in_type) -> !in_type + %1 = tosa.add %arg1, %arg0 : (!in_type, !in_type) -> !in_type + %2 = tosa.add %arg0, %arg1 : (!in_type, !in_type) -> !in_type + %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type + return %3 : !out_type +} + +// CHECK-LABEL: func.func @switch_op_concat +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<3x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<3x8x8xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<3x8x8xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.add [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xf32>, tensor<3x8x8xf32>) -> tensor<3x8x8xf32> +// CHECK: return [[VAR_2_]] : tensor<3x8x8xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<2x8x8xf32> +func.func @switch_op_concat_no_op(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %0 : !out_type +} + +// CHECK-LABEL: func.func @switch_op_concat_no_op +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: return [[VAR_0_]] : tensor<2x8x8xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<2x8x8xf32> +func.func @switch_op_concat_not_same(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.add %arg0, %arg1 : (!in_type, !in_type) -> !in_type + %1 = tosa.abs %arg0 : (!in_type) -> !in_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %2 : !out_type +} + +// CHECK-LABEL: func.func @switch_op_concat_not_same +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.add [[PARAM_0_]], [[PARAM_1_]] : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<1x8x8xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.abs [[PARAM_0_]] : (tensor<1x8x8xf32>) -> tensor<1x8x8xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.concat [[VAR_0_]], [[VAR_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: return [[VAR_2_]] : tensor<2x8x8xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x8x8xi32> +!out_type = tensor<2x8x8xi32> +func.func @switch_op_concat_without_input() -> !out_type { + %0 = "tosa.const"() {value = dense<0> : !in_type} : () -> !in_type + %1 = "tosa.const"() {value = dense<2> : !in_type} : () -> !in_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %2 : !out_type +} + +// CHECK-LABEL: func.func @switch_op_concat_without_input +// CHECK-SAME: () -> tensor<2x8x8xi32> { +// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = 
dense<0> : tensor<1x8x8xi32>}> : () -> tensor<1x8x8xi32> +// CHECK: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<2> : tensor<1x8x8xi32>}> : () -> tensor<1x8x8xi32> +// CHECK: [[VAR_2_:%.+]] = tosa.concat [[VAR_0_]], [[VAR_1_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<2x8x8xi32> +// CHECK: return [[VAR_2_]] : tensor<2x8x8xi32> +// CHECK: } +// ----- + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<2x8x8xf32> +func.func @switch_op_concat_mergable_broadcasting(%arg0: !in_type, %arg1: !in_type, %broadcast: tensor<1x1x8xf32>) -> !out_type { + %0 = tosa.add %arg0, %broadcast: (!in_type, tensor<1x1x8xf32>) -> !in_type + %1 = tosa.add %arg1, %broadcast: (!in_type, tensor<1x1x8xf32>) -> !in_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %2 : !out_type +} + +// CHECK-LABEL: func.func @switch_op_concat_mergable_broadcasting +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>, [[PARAM_2_:%.+]]: tensor<1x1x8xf32>) -> tensor<2x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.concat [[PARAM_2_]], [[PARAM_2_]] {axis = 0 : i32} : (tensor<1x1x8xf32>, tensor<1x1x8xf32>) -> tensor<2x1x8xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.add [[VAR_0_]], [[VAR_1_]] : (tensor<2x8x8xf32>, tensor<2x1x8xf32>) -> tensor<2x8x8xf32> +// CHECK: return [[VAR_2_]] : tensor<2x8x8xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<1x16x8xf32> +func.func @switch_op_concat_unmergable_broadcasting1(%arg0: !in_type, %arg1: !in_type, %broadcast: tensor<1x1x8xf32>) -> !out_type { + %0 = tosa.add %arg0, %broadcast: (!in_type, tensor<1x1x8xf32>) -> !in_type + %1 = tosa.add %arg1, %broadcast: (!in_type, tensor<1x1x8xf32>) -> !in_type + %2 = tosa.concat %0, %1 {axis = 1 : i32} : (!in_type, !in_type) -> !out_type + return %2 : !out_type +} + +// CHECK-LABEL: func.func @switch_op_concat_unmergable_broadcasting1 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>, [[PARAM_2_:%.+]]: tensor<1x1x8xf32>) -> tensor<1x16x8xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.add [[PARAM_0_]], [[PARAM_2_]] : (tensor<1x8x8xf32>, tensor<1x1x8xf32>) -> tensor<1x8x8xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.add [[PARAM_1_]], [[PARAM_2_]] : (tensor<1x8x8xf32>, tensor<1x1x8xf32>) -> tensor<1x8x8xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.concat [[VAR_0_]], [[VAR_1_]] {axis = 1 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<1x16x8xf32> +// CHECK: return [[VAR_2_]] : tensor<1x16x8xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<2x8x8xf32> +func.func @switch_op_concat_unmergable_broadcasting2(%arg0: !in_type, %arg1: !in_type, %broadcast: tensor<1x1x8xf32>) -> !out_type { + %0 = tosa.add %arg0, %broadcast: (!in_type, tensor<1x1x8xf32>) -> !in_type + %1 = tosa.add %arg1, %arg1: (!in_type, !in_type) -> !in_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %2 : !out_type +} + +// CHECK-LABEL: func.func @switch_op_concat_unmergable_broadcasting2 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>, [[PARAM_2_:%.+]]: tensor<1x1x8xf32>) -> tensor<2x8x8xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.add [[PARAM_0_]], [[PARAM_2_]] : (tensor<1x8x8xf32>, tensor<1x1x8xf32>) -> tensor<1x8x8xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.add [[PARAM_1_]], 
[[PARAM_1_]] : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<1x8x8xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.concat [[VAR_0_]], [[VAR_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: return [[VAR_2_]] : tensor<2x8x8xf32> +// CHECK: } +// ----- + +!in_type1 = tensor<1x8x3xf32> +!in_type2 = tensor<1x3x6xf32> +!out_type = tensor<1x8x6xf32> +!concat_type = tensor<2x8x6xf32> +func.func @switch_op_concat_matmul1(%arg0: !in_type1, %arg1: !in_type2, %arg2: !in_type2) -> !concat_type { + %0 = tosa.matmul %arg0, %arg1: (!in_type1, !in_type2) -> !out_type + %1 = tosa.matmul %arg0, %arg2: (!in_type1, !in_type2) -> !out_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!out_type, !out_type) -> !concat_type + return %2 : !concat_type +} + +// CHECK-LABEL: func.func @switch_op_concat_matmul1 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x3xf32>, [[PARAM_1_:%.+]]: tensor<1x3x6xf32>, [[PARAM_2_:%.+]]: tensor<1x3x6xf32>) -> tensor<2x8x6xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x3xf32>, tensor<1x8x3xf32>) -> tensor<2x8x3xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_2_]] {axis = 0 : i32} : (tensor<1x3x6xf32>, tensor<1x3x6xf32>) -> tensor<2x3x6xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.matmul [[VAR_0_]], [[VAR_1_]] : (tensor<2x8x3xf32>, tensor<2x3x6xf32>) -> tensor<2x8x6xf32> +// CHECK: return [[VAR_2_]] : tensor<2x8x6xf32> +// CHECK: } +// ----- + +!in_type1 = tensor<1x8x3xf32> +!in_type2 = tensor<1x3x6xf32> +!out_type = tensor<1x8x6xf32> +!concat_type = tensor<1x16x6xf32> +func.func @switch_op_concat_matmul2(%arg0: !in_type1, %arg1: !in_type2, %arg2: !in_type2) -> !concat_type { + %0 = tosa.matmul %arg0, %arg1: (!in_type1, !in_type2) -> !out_type + %1 = tosa.matmul %arg0, %arg2: (!in_type1, !in_type2) -> !out_type + %2 = tosa.concat %0, %1 {axis = 1 : i32} : (!out_type, !out_type) -> !concat_type + return %2 : !concat_type +} + +// CHECK-LABEL: func.func @switch_op_concat_matmul2 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x3xf32>, [[PARAM_1_:%.+]]: tensor<1x3x6xf32>, [[PARAM_2_:%.+]]: tensor<1x3x6xf32>) -> tensor<1x16x6xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.matmul [[PARAM_0_]], [[PARAM_1_]] : (tensor<1x8x3xf32>, tensor<1x3x6xf32>) -> tensor<1x8x6xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.matmul [[PARAM_0_]], [[PARAM_2_]] : (tensor<1x8x3xf32>, tensor<1x3x6xf32>) -> tensor<1x8x6xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.concat [[VAR_0_]], [[VAR_1_]] {axis = 1 : i32} : (tensor<1x8x6xf32>, tensor<1x8x6xf32>) -> tensor<1x16x6xf32> +// CHECK: return [[VAR_2_]] : tensor<1x16x6xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<2x8x8xf32> +func.func @switch_op_concat_attr(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0 : f32, min_int = 0 : i64, max_int = 1 : i64} : (!in_type) -> !in_type + %1 = tosa.clamp %arg1 {min_fp = 0.0 : f32, max_fp = 1.0 : f32, min_int = 0 : i64, max_int = 1 : i64} : (!in_type) -> !in_type + %3 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %3 : !out_type +} + +// CHECK-LABEL: func.func @switch_op_concat_attr +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.clamp [[VAR_0_]] {max_fp = 1.000000e+00 
: f32, max_int = 1 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<2x8x8xf32> +func.func @switch_op_concat_attr_mismatch(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 2.0 : f32, min_int = 0 : i64, max_int = 1 : i64} : (!in_type) -> !in_type + %1 = tosa.clamp %arg1 {min_fp = 0.0 : f32, max_fp = 1.0 : f32, min_int = 0 : i64, max_int = 1 : i64} : (!in_type) -> !in_type + %3 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %3 : !out_type +} + +// CHECK-LABEL: func.func @switch_op_concat_attr_mismatch +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_fp = 2.000000e+00 : f32, max_int = 1 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x8x8xf32>) -> tensor<1x8x8xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.clamp [[PARAM_1_]] {max_fp = 1.000000e+00 : f32, max_int = 1 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x8x8xf32>) -> tensor<1x8x8xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.concat [[VAR_0_]], [[VAR_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: return [[VAR_2_]] : tensor<2x8x8xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<1x1x8xf32> +!concat_type = tensor<2x1x8xf32> +func.func @switch_op_concat_axis(%arg0: !in_type, %arg1: !in_type) -> !concat_type { + %0 = tosa.reduce_max %arg0 {axis = 1 : i32} : (!in_type) -> !out_type + %1 = tosa.reduce_max %arg1 {axis = 1 : i32} : (!in_type) -> !out_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!out_type, !out_type) -> !concat_type + return %2 : !concat_type +} + +// CHECK-LABEL: func.func @switch_op_concat_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x1x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.reduce_max [[VAR_0_]] {axis = 1 : i32} : (tensor<2x8x8xf32>) -> tensor<2x1x8xf32> +// CHECK: return [[VAR_1_]] : tensor<2x1x8xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<1x1x8xf32> +!concat_type = tensor<1x2x8xf32> +func.func @switch_op_concat_axis_mismatch(%arg0: !in_type, %arg1: !in_type) -> !concat_type { + %0 = tosa.reduce_max %arg0 {axis = 1 : i32} : (!in_type) -> !out_type + %1 = tosa.reduce_max %arg1 {axis = 1 : i32} : (!in_type) -> !out_type + %2 = tosa.concat %0, %1 {axis = 1 : i32} : (!out_type, !out_type) -> !concat_type + return %2 : !concat_type +} + +// CHECK-LABEL: func.func @switch_op_concat_axis_mismatch +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<1x2x8xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reduce_max [[PARAM_0_]] {axis = 1 : i32} : (tensor<1x8x8xf32>) -> tensor<1x1x8xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reduce_max [[PARAM_1_]] {axis = 1 : i32} : (tensor<1x8x8xf32>) -> tensor<1x1x8xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.concat [[VAR_0_]], [[VAR_1_]] {axis = 1 : i32} : (tensor<1x1x8xf32>, tensor<1x1x8xf32>) -> tensor<1x2x8xf32> +// CHECK: return [[VAR_2_]] : tensor<1x2x8xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x1x5xf32> +!re_type = 
tensor<1x1x1x5xf32> +!out_type = tensor<2x1x1x5xf32> +func.func @reshape_other_axis(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 5>} : (!in_type) -> !re_type + %1 = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 5>} : (!in_type) -> !re_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!re_type, !re_type) -> !out_type + return %2 : !out_type +} + +// CHECK-LABEL: func.func @reshape_other_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x5xf32>, [[PARAM_1_:%.+]]: tensor<1x1x5xf32>) -> tensor<2x1x1x5xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x1x5xf32>, tensor<1x1x5xf32>) -> tensor<2x1x5xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array<i64: 2, 1, 1, 5>} : (tensor<2x1x5xf32>) -> tensor<2x1x1x5xf32> +// CHECK: return [[VAR_1_]] : tensor<2x1x1x5xf32> +// CHECK: } +// ----- + +!in_type = tensor<6x1x6xf32> +!out_type = tensor<2x3x2x2x3xf32> +!re_type = tensor<2x3x1x2x3xf32> +func.func @reshape_with_product(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.reshape %arg0 {new_shape = array<i64: 2, 3, 1, 2, 3>} : (!in_type) -> !re_type + %1 = tosa.reshape %arg1 {new_shape = array<i64: 2, 3, 1, 2, 3>} : (!in_type) -> !re_type + %2 = tosa.concat %0, %1 {axis = 2 : i32} : (!re_type, !re_type) -> !out_type + return %2 : !out_type +} + +// CHECK-LABEL: func.func @reshape_with_product +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<6x1x6xf32>, [[PARAM_1_:%.+]]: tensor<6x1x6xf32>) -> tensor<2x3x2x2x3xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 1 : i32} : (tensor<6x1x6xf32>, tensor<6x1x6xf32>) -> tensor<6x2x6xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array<i64: 2, 3, 2, 2, 3>} : (tensor<6x2x6xf32>) -> tensor<2x3x2x2x3xf32> +// CHECK: return [[VAR_1_]] : tensor<2x3x2x2x3xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x4x1xf32> +!out_type = tensor<1x8xf32> +!re_type = tensor<1x4xf32> +func.func @reshape_big_concat_axis(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.reshape %arg0 {new_shape = array<i64: 1, 4>} : (!in_type) -> !re_type + %1 = tosa.reshape %arg1 {new_shape = array<i64: 1, 4>} : (!in_type) -> !re_type + %2 = tosa.concat %0, %1 {axis = 1 : i32} : (!re_type, !re_type) -> !out_type + return %2 : !out_type +} + +// CHECK-LABEL: func.func @reshape_big_concat_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4x1xf32>, [[PARAM_1_:%.+]]: tensor<1x4x1xf32>) -> tensor<1x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 1 : i32} : (tensor<1x4x1xf32>, tensor<1x4x1xf32>) -> tensor<1x8x1xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array<i64: 1, 8>} : (tensor<1x8x1xf32>) -> tensor<1x8xf32> +// CHECK: return [[VAR_1_]] : tensor<1x8xf32> +// CHECK: } +// ----- + +!in_type = tensor<6x1x6xf32> +!out_type = tensor<2x6x1x2x3xf32> +!re_type = tensor<2x3x1x2x3xf32> +func.func @reshape_fail1(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.reshape %arg0 {new_shape = array<i64: 2, 3, 1, 2, 3>} : (!in_type) -> !re_type + %1 = tosa.reshape %arg1 {new_shape = array<i64: 2, 3, 1, 2, 3>} : (!in_type) -> !re_type + %2 = tosa.concat %0, %1 {axis = 1 : i32} : (!re_type, !re_type) -> !out_type + return %2 : !out_type +} + +// CHECK-LABEL: func.func @reshape_fail1 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<6x1x6xf32>, [[PARAM_1_:%.+]]: tensor<6x1x6xf32>) -> tensor<2x6x1x2x3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array<i64: 2, 3, 1, 2, 3>} : (tensor<6x1x6xf32>) -> tensor<2x3x1x2x3xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 2, 3, 1, 2, 3>} : (tensor<6x1x6xf32>) ->
tensor<2x3x1x2x3xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.concat [[VAR_0_]], [[VAR_1_]] {axis = 1 : i32} : (tensor<2x3x1x2x3xf32>, tensor<2x3x1x2x3xf32>) -> tensor<2x6x1x2x3xf32> +// CHECK: return [[VAR_2_]] : tensor<2x6x1x2x3xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x6xf32> +!out_type = tensor<6x2xf32> +!re_type = tensor<6x1xf32> +func.func @reshape_fail2(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.reshape %arg0 {new_shape = array<i64: 6, 1>} : (!in_type) -> !re_type + %1 = tosa.reshape %arg1 {new_shape = array<i64: 6, 1>} : (!in_type) -> !re_type + %2 = tosa.concat %0, %1 {axis = 1 : i32} : (!re_type, !re_type) -> !out_type + return %2 : !out_type +} + +// CHECK-LABEL: func.func @reshape_fail2 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x6xf32>, [[PARAM_1_:%.+]]: tensor<1x6xf32>) -> tensor<6x2xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array<i64: 6, 1>} : (tensor<1x6xf32>) -> tensor<6x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 6, 1>} : (tensor<1x6xf32>) -> tensor<6x1xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.concat [[VAR_0_]], [[VAR_1_]] {axis = 1 : i32} : (tensor<6x1xf32>, tensor<6x1xf32>) -> tensor<6x2xf32> +// CHECK: return [[VAR_2_]] : tensor<6x2xf32> +// CHECK: } +// ----- + +!in_type1 = tensor<1x6xf32> +!in_type2 = tensor<2x3xf32> +!out_type = tensor<2x1x6xf32> +!re_type = tensor<1x1x6xf32> +func.func @reshape_fail3(%arg0: !in_type1, %arg1: !in_type2) -> !out_type { + %0 = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 6>} : (!in_type1) -> !re_type + %1 = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 6>} : (!in_type2) -> !re_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!re_type, !re_type) -> !out_type + return %2 : !out_type +} + +// CHECK-LABEL: func.func @reshape_fail3 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x6xf32>, [[PARAM_1_:%.+]]: tensor<2x3xf32>) -> tensor<2x1x6xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array<i64: 1, 1, 6>} : (tensor<1x6xf32>) -> tensor<1x1x6xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 1, 6>} : (tensor<2x3xf32>) -> tensor<1x1x6xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.concat [[VAR_0_]], [[VAR_1_]] {axis = 0 : i32} : (tensor<1x1x6xf32>, tensor<1x1x6xf32>) -> tensor<2x1x6xf32> +// CHECK: return [[VAR_2_]] : tensor<2x1x6xf32> +// CHECK: } +// ----- + +!in_type = tensor<f32> +!out_type = tensor<2xf32> +!re_type = tensor<1xf32> +func.func @reshape_tensor0(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.reshape %arg0 {new_shape = array<i64: 1>} : (!in_type) -> !re_type + %1 = tosa.reshape %arg1 {new_shape = array<i64: 1>} : (!in_type) -> !re_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!re_type, !re_type) -> !out_type + return %2 : !out_type +} + +// CHECK-LABEL: func.func @reshape_tensor0 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<f32>, [[PARAM_1_:%.+]]: tensor<f32>) -> tensor<2xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.concat [[VAR_0_]], [[VAR_1_]] {axis = 0 : i32} : (tensor<1xf32>, tensor<1xf32>) -> tensor<2xf32> +// CHECK: return [[VAR_2_]] : tensor<2xf32> +// CHECK: } +// ----- + +// This excerpt stems from two convolutions with a 1x1 kernel whose outputs feed the same concat on the kernel axis. +// They were rewritten as matmuls so that concat sinking can support them.
+ +func.func @reshape_complex_match(%arg0: tensor<1x42x1x1xbf16>, %arg1: tensor<1x42x1x1xbf16>, %arg2: tensor<12x42x1x1xbf16>, %arg3: tensor<12x42x1x1xbf16>) -> tensor<1x2x1x12xbf16> { + %0 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> + %1 = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}> : () -> tensor<6xi32> + %2 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> + %3 = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1, 1, 42>} : (tensor<1x42x1x1xbf16>) -> tensor<1x1x1x1x1x42xbf16> + %4 = tosa.transpose %3, %1 : (tensor<1x1x1x1x1x42xbf16>, tensor<6xi32>) -> tensor<1x1x1x1x1x42xbf16> + %5 = tosa.reshape %4 {new_shape = array<i64: 1, 1, 42>} : (tensor<1x1x1x1x1x42xbf16>) -> tensor<1x1x42xbf16> + %6 = tosa.reshape %arg2 {new_shape = array<i64: 1, 12, 42>} : (tensor<12x42x1x1xbf16>) -> tensor<1x12x42xbf16> + %7 = tosa.transpose %6, %0 : (tensor<1x12x42xbf16>, tensor<3xi32>) -> tensor<1x42x12xbf16> + %8 = tosa.matmul %5, %7 : (tensor<1x1x42xbf16>, tensor<1x42x12xbf16>) -> tensor<1x1x12xbf16> + %9 = tosa.reshape %8 {new_shape = array<i64: 1, 1, 1, 12>} : (tensor<1x1x12xbf16>) -> tensor<1x1x1x12xbf16> + %10 = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1, 1, 42>} : (tensor<1x42x1x1xbf16>) -> tensor<1x1x1x1x1x42xbf16> + %11 = tosa.transpose %10, %1 : (tensor<1x1x1x1x1x42xbf16>, tensor<6xi32>) -> tensor<1x1x1x1x1x42xbf16> + %12 = tosa.reshape %11 {new_shape = array<i64: 1, 1, 42>} : (tensor<1x1x1x1x1x42xbf16>) -> tensor<1x1x42xbf16> + %13 = tosa.reshape %arg3 {new_shape = array<i64: 1, 12, 42>} : (tensor<12x42x1x1xbf16>) -> tensor<1x12x42xbf16> + %14 = tosa.transpose %13, %0 : (tensor<1x12x42xbf16>, tensor<3xi32>) -> tensor<1x42x12xbf16> + %15 = tosa.matmul %12, %14 : (tensor<1x1x42xbf16>, tensor<1x42x12xbf16>) -> tensor<1x1x12xbf16> + %16 = tosa.reshape %15 {new_shape = array<i64: 1, 1, 1, 12>} : (tensor<1x1x12xbf16>) -> tensor<1x1x1x12xbf16> + %17 = tosa.concat %9, %16 {axis = 1 : i32} : (tensor<1x1x1x12xbf16>, tensor<1x1x1x12xbf16>) -> tensor<1x2x1x12xbf16> + return %17 : tensor<1x2x1x12xbf16> +} + +// CHECK-LABEL: func.func @reshape_complex_match +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x42x1x1xbf16>, [[PARAM_1_:%.+]]: tensor<1x42x1x1xbf16>, [[PARAM_2_:%.+]]: tensor<12x42x1x1xbf16>, [[PARAM_3_:%.+]]: tensor<12x42x1x1xbf16>) -> tensor<1x2x1x12xbf16> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}> : () -> tensor<6xi32> +// CHECK-DAG: [[VAR_2_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array<i64: 1, 1, 1, 1, 1, 42>} : (tensor<1x42x1x1xbf16>) -> tensor<1x1x1x1x1x42xbf16> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]] = tosa.transpose [[VAR_2_]], [[VAR_1_]] : (tensor<1x1x1x1x1x42xbf16>, tensor<6xi32>) -> tensor<1x1x1x1x1x42xbf16> +// CHECK-DAG: [[VAR_4_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1, 12, 42>} : (tensor<12x42x1x1xbf16>) -> tensor<1x12x42xbf16> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_5_:%.+]] = tosa.transpose [[VAR_4_]], [[VAR_0_]] : (tensor<1x12x42xbf16>, tensor<3xi32>) -> tensor<1x42x12xbf16> +// CHECK-DAG: [[VAR_6_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 1, 1, 1, 1, 42>} : (tensor<1x42x1x1xbf16>) -> tensor<1x1x1x1x1x42xbf16> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = tosa.transpose [[VAR_6_]], [[VAR_1_]] : (tensor<1x1x1x1x1x42xbf16>, tensor<6xi32>) -> tensor<1x1x1x1x1x42xbf16> +// CHECK-DAG: [[VAR_8_:%.+]] = tosa.reshape [[PARAM_3_]] {new_shape = array<i64: 1, 12, 42>} : (tensor<12x42x1x1xbf16>) ->
tensor<1x12x42xbf16> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_:%.+]] = tosa.transpose [[VAR_8_]], [[VAR_0_]] : (tensor<1x12x42xbf16>, tensor<3xi32>) -> tensor<1x42x12xbf16> +// CHECK-DAG: [[VAR_10_:%.+]] = tosa.concat [[VAR_3_]], [[VAR_7_]] {axis = 0 : i32} : (tensor<1x1x1x1x1x42xbf16>, tensor<1x1x1x1x1x42xbf16>) -> tensor<2x1x1x1x1x42xbf16> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_11_:%.+]] = tosa.reshape [[VAR_10_]] {new_shape = array<i64: 2, 1, 42>} : (tensor<2x1x1x1x1x42xbf16>) -> tensor<2x1x42xbf16> +// CHECK-DAG: [[VAR_12_:%.+]] = tosa.concat [[VAR_5_]], [[VAR_9_]] {axis = 0 : i32} : (tensor<1x42x12xbf16>, tensor<1x42x12xbf16>) -> tensor<2x42x12xbf16> +// CHECK: [[VAR_13_:%.+]] = tosa.matmul [[VAR_11_]], [[VAR_12_]] : (tensor<2x1x42xbf16>, tensor<2x42x12xbf16>) -> tensor<2x1x12xbf16> +// CHECK: [[VAR_14_:%.+]] = tosa.reshape [[VAR_13_]] {new_shape = array<i64: 1, 2, 1, 12>} : (tensor<2x1x12xbf16>) -> tensor<1x2x1x12xbf16> +// CHECK: return [[VAR_14_]] : tensor<1x2x1x12xbf16> +// CHECK: } +// ----- + +// valid tests for all supported operations + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<2x8x8xf32> +func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.abs %arg0 : (!in_type) -> !in_type + %1 = tosa.abs %arg1 : (!in_type) -> !in_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %2 : !out_type +} +// CHECK-LABEL: func.func @switch_op_concat +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.abs [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32> +// CHECK: } + +// ----- + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<2x8x8xf32> +func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.bitwise_not %arg0 : (!in_type) -> !in_type + %1 = tosa.bitwise_not %arg1 : (!in_type) -> !in_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %2 : !out_type +} +// CHECK-LABEL: func.func @switch_op_concat +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.bitwise_not [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<2x8x8xf32> +func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.ceil %arg0 : (!in_type) -> !in_type + %1 = tosa.ceil %arg1 : (!in_type) -> !in_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %2 : !out_type +} +// CHECK-LABEL: func.func @switch_op_concat +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.ceil [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32> +// CHECK: } +// ----- + +!in_type =
tensor<1x8x8xf32> +!out_type = tensor<2x8x8xf32> +func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.clz %arg0 : (!in_type) -> !in_type + %1 = tosa.clz %arg1 : (!in_type) -> !in_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %2 : !out_type +} +// CHECK-LABEL: func.func @switch_op_concat +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.clz [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<2x8x8xf32> +func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.cos %arg0 : (!in_type) -> !in_type + %1 = tosa.cos %arg1 : (!in_type) -> !in_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %2 : !out_type +} +// CHECK-LABEL: func.func @switch_op_concat +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.cos [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32> +// CHECK: } +// ----- + + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<2x8x8xf32> +func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.erf %arg0 : (!in_type) -> !in_type + %1 = tosa.erf %arg1 : (!in_type) -> !in_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %2 : !out_type +} +// CHECK-LABEL: func.func @switch_op_concat +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.erf [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<2x8x8xf32> +func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.exp %arg0 : (!in_type) -> !in_type + %1 = tosa.exp %arg1 : (!in_type) -> !in_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %2 : !out_type +} +// CHECK-LABEL: func.func @switch_op_concat +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.exp [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32> +// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32> +// CHECK: } +// ----- + +!in_type = tensor<1x8x8xf32> +!out_type = tensor<2x8x8xf32> +func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type { + %0 = tosa.floor %arg0 : (!in_type) -> !in_type + %1 = tosa.floor %arg1 : (!in_type) -> !in_type + %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type + return %2 
: !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.floor [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<2x8x8xf32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.log %arg0 : (!in_type) -> !in_type
+  %1 = tosa.log %arg1 : (!in_type) -> !in_type
+  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type
+  return %2 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.log [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xi1>
+!out_type = tensor<2x8x8xi1>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.logical_not %arg0 : (!in_type) -> !in_type
+  %1 = tosa.logical_not %arg1 : (!in_type) -> !in_type
+  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type
+  return %2 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi1>, [[PARAM_1_:%.+]]: tensor<1x8x8xi1>) -> tensor<2x8x8xi1> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<2x8x8xi1>
+// CHECK: [[VAR_1_:%.+]] = tosa.logical_not [[VAR_0_]] : (tensor<2x8x8xi1>) -> tensor<2x8x8xi1>
+// CHECK: return [[VAR_1_]] : tensor<2x8x8xi1>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<2x8x8xf32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.negate %arg0 : (!in_type) -> !in_type
+  %1 = tosa.negate %arg1 : (!in_type) -> !in_type
+  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type
+  return %2 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.negate [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<2x8x8xf32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.reciprocal %arg0 : (!in_type) -> !in_type
+  %1 = tosa.reciprocal %arg1 : (!in_type) -> !in_type
+  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type
+  return %2 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.reciprocal [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<2x8x8xf32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.rsqrt %arg0 : (!in_type) -> !in_type
+  %1 = tosa.rsqrt %arg1 : (!in_type) -> !in_type
+  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type
+  return %2 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.rsqrt [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<2x8x8xf32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.sigmoid %arg0 : (!in_type) -> !in_type
+  %1 = tosa.sigmoid %arg1 : (!in_type) -> !in_type
+  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type
+  return %2 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.sigmoid [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<2x8x8xf32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.sin %arg0 : (!in_type) -> !in_type
+  %1 = tosa.sin %arg1 : (!in_type) -> !in_type
+  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type
+  return %2 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.sin [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<2x8x8xf32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.tanh %arg0 : (!in_type) -> !in_type
+  %1 = tosa.tanh %arg1 : (!in_type) -> !in_type
+  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type
+  return %2 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.tanh [[VAR_0_]] : (tensor<2x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: return [[VAR_1_]] : tensor<2x8x8xf32>
+// CHECK: }
+// -----
+
+// Binary operations: sinking requires one concat per operand position.
+!in_type = tensor<1x8x8xi1>
+!out_type = tensor<3x8x8xi1>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.equal %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.equal %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.equal %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi1>, [[PARAM_1_:%.+]]: tensor<1x8x8xi1>) -> tensor<3x8x8xi1> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK: [[VAR_2_:%.+]] = tosa.equal [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xi1>, tensor<3x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xi1>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xi1>
+!out_type = tensor<3x8x8xi1>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.greater %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.greater %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.greater %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi1>, [[PARAM_1_:%.+]]: tensor<1x8x8xi1>) -> tensor<3x8x8xi1> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK: [[VAR_2_:%.+]] = tosa.greater [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xi1>, tensor<3x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xi1>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xi1>
+!out_type = tensor<3x8x8xi1>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.greater_equal %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.greater_equal %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.greater_equal %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi1>, [[PARAM_1_:%.+]]: tensor<1x8x8xi1>) -> tensor<3x8x8xi1> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK: [[VAR_2_:%.+]] = tosa.greater_equal [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xi1>, tensor<3x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xi1>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xi1>
+!out_type = tensor<3x8x8xi1>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.logical_and %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.logical_and %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.logical_and %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi1>, [[PARAM_1_:%.+]]: tensor<1x8x8xi1>) -> tensor<3x8x8xi1> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK: [[VAR_2_:%.+]] = tosa.logical_and [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xi1>, tensor<3x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xi1>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xi1>
+!out_type = tensor<3x8x8xi1>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.logical_or %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.logical_or %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.logical_or %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi1>, [[PARAM_1_:%.+]]: tensor<1x8x8xi1>) -> tensor<3x8x8xi1> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK: [[VAR_2_:%.+]] = tosa.logical_or [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xi1>, tensor<3x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xi1>
+// CHECK: }
+// -----
+
+// tosa.select has three operands, so three concats are created.
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<2x8x8xf32>
+!select_type = tensor<1x8x8xi1>
+func.func @switch_op_concat(%arg0: !select_type, %arg1: !in_type, %arg2: !in_type) -> !out_type {
+  %0 = tosa.select %arg0, %arg1, %arg2 : (!select_type, !in_type, !in_type) -> !in_type
+  %1 = tosa.select %arg0, %arg2, %arg1 : (!select_type, !in_type, !in_type) -> !in_type
+  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type
+  return %2 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi1>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>, [[PARAM_2_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<2x8x8xi1>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_2_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK-DAG: [[VAR_2_:%.+]] = tosa.concat [[PARAM_2_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: [[VAR_3_:%.+]] = tosa.select [[VAR_0_]], [[VAR_1_]], [[VAR_2_]] : (tensor<2x8x8xi1>, tensor<2x8x8xf32>, tensor<2x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: return [[VAR_3_]] : tensor<2x8x8xf32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xi32>
+!out_type = tensor<3x8x8xi32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (!in_type, !in_type) -> !in_type
+  %1 = tosa.arithmetic_right_shift %arg1, %arg0 {round = false} : (!in_type, !in_type) -> !in_type
+  %2 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi32>, [[PARAM_1_:%.+]]: tensor<1x8x8xi32>) -> tensor<3x8x8xi32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK: [[VAR_2_:%.+]] = tosa.arithmetic_right_shift [[VAR_0_]], [[VAR_1_]] {round = false} : (tensor<3x8x8xi32>, tensor<3x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xi32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xi32>
+!out_type = tensor<3x8x8xi32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.bitwise_and %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.bitwise_and %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.bitwise_and %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi32>, [[PARAM_1_:%.+]]: tensor<1x8x8xi32>) -> tensor<3x8x8xi32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK: [[VAR_2_:%.+]] = tosa.bitwise_and [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xi32>, tensor<3x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xi32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xi32>
+!out_type = tensor<3x8x8xi32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.bitwise_or %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.bitwise_or %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.bitwise_or %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi32>, [[PARAM_1_:%.+]]: tensor<1x8x8xi32>) -> tensor<3x8x8xi32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK: [[VAR_2_:%.+]] = tosa.bitwise_or [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xi32>, tensor<3x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xi32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xi32>
+!out_type = tensor<3x8x8xi32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.bitwise_xor %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.bitwise_xor %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.bitwise_xor %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi32>, [[PARAM_1_:%.+]]: tensor<1x8x8xi32>) -> tensor<3x8x8xi32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK: [[VAR_2_:%.+]] = tosa.bitwise_xor [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xi32>, tensor<3x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xi32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xi32>
+!out_type = tensor<3x8x8xi32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.int_div %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.int_div %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.int_div %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi32>, [[PARAM_1_:%.+]]: tensor<1x8x8xi32>) -> tensor<3x8x8xi32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK: [[VAR_2_:%.+]] = tosa.int_div [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xi32>, tensor<3x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xi32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xi32>
+!out_type = tensor<3x8x8xi32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.logical_left_shift %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.logical_left_shift %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.logical_left_shift %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi32>, [[PARAM_1_:%.+]]: tensor<1x8x8xi32>) -> tensor<3x8x8xi32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK: [[VAR_2_:%.+]] = tosa.logical_left_shift [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xi32>, tensor<3x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xi32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xi32>
+!out_type = tensor<3x8x8xi32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.logical_right_shift %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.logical_right_shift %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.logical_right_shift %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi32>, [[PARAM_1_:%.+]]: tensor<1x8x8xi32>) -> tensor<3x8x8xi32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi32>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK: [[VAR_2_:%.+]] = tosa.logical_right_shift [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xi32>, tensor<3x8x8xi32>) -> tensor<3x8x8xi32>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xi32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xi1>
+!out_type = tensor<3x8x8xi1>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.logical_xor %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.logical_xor %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.logical_xor %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi1>, [[PARAM_1_:%.+]]: tensor<1x8x8xi1>) -> tensor<3x8x8xi1> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK: [[VAR_2_:%.+]] = tosa.logical_xor [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xi1>, tensor<3x8x8xi1>) -> tensor<3x8x8xi1>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xi1>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<3x8x8xf32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.maximum %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.maximum %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.maximum %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<3x8x8xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<3x8x8xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<3x8x8xf32>
+// CHECK: [[VAR_2_:%.+]] = tosa.maximum [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xf32>, tensor<3x8x8xf32>) -> tensor<3x8x8xf32>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xf32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<3x8x8xf32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.minimum %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.minimum %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.minimum %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<3x8x8xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<3x8x8xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<3x8x8xf32>
+// CHECK: [[VAR_2_:%.+]] = tosa.minimum [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xf32>, tensor<3x8x8xf32>) -> tensor<3x8x8xf32>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xf32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<3x8x8xf32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.pow %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.pow %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.pow %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<3x8x8xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<3x8x8xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<3x8x8xf32>
+// CHECK: [[VAR_2_:%.+]] = tosa.pow [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xf32>, tensor<3x8x8xf32>) -> tensor<3x8x8xf32>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xf32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<3x8x8xf32>
+func.func @switch_op_concat(%arg0: !in_type, %arg1: !in_type) -> !out_type {
+  %0 = tosa.sub %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %1 = tosa.sub %arg1, %arg0 : (!in_type, !in_type) -> !in_type
+  %2 = tosa.sub %arg0, %arg1 : (!in_type, !in_type) -> !in_type
+  %3 = tosa.concat %0, %1, %2 {axis = 0 : i32} : (!in_type, !in_type, !in_type) -> !out_type
+  return %3 : !out_type
+}
+// CHECK-LABEL: func.func @switch_op_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<3x8x8xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<3x8x8xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<3x8x8xf32>
+// CHECK: [[VAR_2_:%.+]] = tosa.sub [[VAR_0_]], [[VAR_1_]] : (tensor<3x8x8xf32>, tensor<3x8x8xf32>) -> tensor<3x8x8xf32>
+// CHECK: return [[VAR_2_]] : tensor<3x8x8xf32>
+// CHECK: }
+// -----
+
+// Reduce operations: the reduction axis (1) differs from the concat axis (0),
+// so the reduce can be applied to the concatenated tensor.
+!in_type = tensor<1x8x8xi1>
+!out_type = tensor<1x1x8xi1>
+!concat_type = tensor<2x1x8xi1>
+func.func @switch_op_concat_axis(%arg0: !in_type, %arg1: !in_type) -> !concat_type {
+  %0 = tosa.reduce_all %arg0 {axis = 1 : i32} : (!in_type) -> !out_type
+  %1 = tosa.reduce_all %arg1 {axis = 1 : i32} : (!in_type) -> !out_type
+  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!out_type, !out_type) -> !concat_type
+  return %2 : !concat_type
+}
+// CHECK-LABEL: func.func @switch_op_concat_axis
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi1>, [[PARAM_1_:%.+]]: tensor<1x8x8xi1>) -> tensor<2x1x8xi1> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<2x8x8xi1>
+// CHECK: [[VAR_1_:%.+]] = tosa.reduce_all [[VAR_0_]] {axis = 1 : i32} : (tensor<2x8x8xi1>) -> tensor<2x1x8xi1>
+// CHECK: return [[VAR_1_]] : tensor<2x1x8xi1>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xi1>
+!out_type = tensor<1x1x8xi1>
+!concat_type = tensor<2x1x8xi1>
+func.func @switch_op_concat_axis(%arg0: !in_type, %arg1: !in_type) -> !concat_type {
+  %0 = tosa.reduce_any %arg0 {axis = 1 : i32} : (!in_type) -> !out_type
+  %1 = tosa.reduce_any %arg1 {axis = 1 : i32} : (!in_type) -> !out_type
+  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!out_type, !out_type) -> !concat_type
+  return %2 : !concat_type
+}
+// CHECK-LABEL: func.func @switch_op_concat_axis
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi1>, [[PARAM_1_:%.+]]: tensor<1x8x8xi1>) -> tensor<2x1x8xi1> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<2x8x8xi1>
+// CHECK: [[VAR_1_:%.+]] = tosa.reduce_any [[VAR_0_]] {axis = 1 : i32} : (tensor<2x8x8xi1>) -> tensor<2x1x8xi1>
+// CHECK: return [[VAR_1_]] : tensor<2x1x8xi1>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<1x1x8xf32>
+!concat_type = tensor<2x1x8xf32>
+func.func @switch_op_concat_axis(%arg0: !in_type, %arg1: !in_type) -> !concat_type {
+  %0 = tosa.reduce_min %arg0 {axis = 1 : i32} : (!in_type) -> !out_type
+  %1 = tosa.reduce_min %arg1 {axis = 1 : i32} : (!in_type) -> !out_type
+  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!out_type, !out_type) -> !concat_type
+  return %2 : !concat_type
+}
+// CHECK-LABEL: func.func @switch_op_concat_axis
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x1x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.reduce_min [[VAR_0_]] {axis = 1 : i32} : (tensor<2x8x8xf32>) -> tensor<2x1x8xf32>
+// CHECK: return [[VAR_1_]] : tensor<2x1x8xf32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<1x1x8xf32>
+!concat_type = tensor<2x1x8xf32>
+func.func @switch_op_concat_axis(%arg0: !in_type, %arg1: !in_type) -> !concat_type {
+  %0 = tosa.reduce_prod %arg0 {axis = 1 : i32} : (!in_type) -> !out_type
+  %1 = tosa.reduce_prod %arg1 {axis = 1 : i32} : (!in_type) -> !out_type
+  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!out_type, !out_type) -> !concat_type
+  return %2 : !concat_type
+}
+// CHECK-LABEL: func.func @switch_op_concat_axis
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x1x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.reduce_prod [[VAR_0_]] {axis = 1 : i32} : (tensor<2x8x8xf32>) -> tensor<2x1x8xf32>
+// CHECK: return [[VAR_1_]] : tensor<2x1x8xf32>
+// CHECK: }
+// -----
+
+!in_type = tensor<1x8x8xf32>
+!out_type = tensor<1x1x8xf32>
+!concat_type = tensor<2x1x8xf32>
+func.func @switch_op_concat_axis(%arg0: !in_type, %arg1: !in_type) -> !concat_type {
+  %0 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (!in_type) -> !out_type
+  %1 = tosa.reduce_sum %arg1 {axis = 1 : i32} : (!in_type) -> !out_type
+  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!out_type, !out_type) -> !concat_type
+  return %2 : !concat_type
+}
+// CHECK-LABEL: func.func @switch_op_concat_axis
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x1x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.reduce_sum [[VAR_0_]] {axis = 1 : i32} : (tensor<2x8x8xf32>) -> tensor<2x1x8xf32>
+// CHECK: return [[VAR_1_]] : tensor<2x1x8xf32>
+// CHECK: }