Commit 56cdd98

Allow norm + bias fusion even if there is a reshape between them
Signed-off-by: Rickert, Jonas <[email protected]>
1 parent 71948db commit 56cdd98
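
In effect, the bias-propagation pattern for LayerNormalization now also matches when an onnx.Reshape sits between the normalization and the bias Add. A minimal before/after sketch (value names are illustrative; shapes and attributes follow the first new test case below):

Before:
  %none = "onnx.NoValue"() {value} : () -> none
  %y, %mean, %isd = "onnx.LayerNormalization"(%x, %scale, %none) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
  %shape = "onnx.Constant"() {value = dense<[384, 768]> : tensor<2xi64>} : () -> tensor<2xi64>
  %r = "onnx.Reshape"(%y, %shape) : (tensor<1x384x768xf32>, tensor<2xi64>) -> tensor<384x768xf32>
  %out = "onnx.Add"(%bias, %r) : (tensor<768xf32>, tensor<384x768xf32>) -> tensor<384x768xf32>

After (the bias feeds the norm directly and the reshape produces the final result):
  %y, %mean, %isd = "onnx.LayerNormalization"(%x, %scale, %bias) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none)
  %shape = "onnx.Constant"() {value = dense<[384, 768]> : tensor<2xi64>} : () -> tensor<2xi64>
  %out = "onnx.Reshape"(%y, %shape) : (tensor<1x384x768xf32>, tensor<2xi64>) -> tensor<384x768xf32>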

2 files changed, 245 additions (+), 44 deletions (-)

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 168 additions & 44 deletions
@@ -17,7 +17,9 @@
 //===----------------------------------------------------------------------===//
 
 #include <math.h>
+#include <numeric>
 
+#include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -1634,7 +1636,25 @@ struct RecomposeConcatPattern : public OpRewritePattern<ONNXConcatOp> {
 // Rewrite pattern LayerNormalization
 // =============================================================================
 namespace {
-bool isValueNoneOrConstZero(Value value) {
+
+// Checks if B is unidirectionally broadcastable to A. Requires static shapes.
+[[nodiscard]] bool areUnidirectionalBroadcastCompatible(Type a, Type b) {
+  auto aShaped = dyn_cast<ShapedType>(a);
+  auto bShaped = dyn_cast<ShapedType>(b);
+  if (!aShaped || !bShaped || !aShaped.hasStaticShape() ||
+      !bShaped.hasStaticShape()) {
+    return false;
+  }
+  SmallVector<int64_t> broadcastedShape;
+  if (!OpTrait::util::getBroadcastedShape(
+          aShaped.getShape(), bShaped.getShape(), broadcastedShape)) {
+    return false;
+  }
+  // For unidirectional broadcasting, a and the resulting shape need to match.
+  return aShaped.getShape() == ArrayRef<int64_t>(broadcastedShape);
+}
+
+[[nodiscard]] bool isValueNoneOrConstZero(Value value) {
   if (!value) {
     return false;
   }
@@ -1727,63 +1747,166 @@ struct PropagateBiasIntoLayerNormRewritePattern
 
   LogicalResult matchAndRewrite(
       ONNXAddOp addOp, PatternRewriter &rewriter) const final {
+    PatternRewriter::InsertionGuard guard(rewriter);
     using namespace onnx_mlir;
     Value y, bias;
-    Operation *yLayerNormOp;
-    Operation *ywbAddOp = addOp.getOperation();
+    Operation *yLayerNormOp = nullptr;
+    Operation *reshapeOp = nullptr;
+    SmallVector<int64_t> newBiasShape; // only used if there is a reshape
+
     // Match
     // %noBias = "onnx.NoValue"()
     // %y, %mean, %invStdDev = "onnx.LayerNormalization"(%x, %scale, %noBias)
     // {axis = 2 : si64, epsilon = 9.994E-6 : f32, stash_type = 1 : si64}
+    // optional reshape between norm and add
     // %yBias = "onnx.Add"(%y, %bias)
-    if (!onnx_mlir::operandOfOpDefinedBy<OP_TYPE>(
-            yLayerNormOp, ywbAddOp, y, bias, 0) &&
-        !onnx_mlir::operandOfOpDefinedBy<OP_TYPE>(
-            yLayerNormOp, ywbAddOp, bias, y, 1))
-      return reportFailure("missing y, layer norm op");
+
+    if (onnx_mlir::operandOfOpDefinedBy<ONNXReshapeOp>(
+            reshapeOp, addOp, y, bias, 0) ||
+        onnx_mlir::operandOfOpDefinedBy<ONNXReshapeOp>(
+            reshapeOp, addOp, bias, y, 1)) {
+      yLayerNormOp = reshapeOp->getOperand(0).getDefiningOp<OP_TYPE>();
+      if (!yLayerNormOp) {
+        return rewriter.notifyMatchFailure(
+            reshapeOp, "reshape op does not have a layer norm as input");
+      }
+      if (!reshapeOp->hasOneUse()) {
+        return rewriter.notifyMatchFailure(
+            reshapeOp, "reshape op does not have a single use");
+      }
+    } else {
+      if (!onnx_mlir::operandOfOpDefinedBy<OP_TYPE>(
+              yLayerNormOp, addOp, y, bias, 0) &&
+          !onnx_mlir::operandOfOpDefinedBy<OP_TYPE>(
+              yLayerNormOp, addOp, bias, y, 1))
+        return rewriter.notifyMatchFailure(addOp, "missing y, layer norm op");
+    }
+
     // Study layer norm op; make sure it is used only once and that bias is
     // not used.
-    if (!yLayerNormOp->hasOneUse())
-      return reportFailure("y/layer norm has too many uses");
+    assert(yLayerNormOp && "yLayerNormOp should not be null");
+    if (!yLayerNormOp->hasOneUse()) {
+      return rewriter.notifyMatchFailure(
+          yLayerNormOp, "y/layer norm has too many uses");
+    }
     auto lnOp = mlir::cast<OP_TYPE>(yLayerNormOp);
     if (!isValueNoneOrConstZero(lnOp.getB()))
-      return reportFailure("layer norm already has a bias");
+      return rewriter.notifyMatchFailure(lnOp, "layer norm already has a bias");
+
+    if (reshapeOp) {
+      // If we have a reshape, check that the add is not changing the shape by
+      // broadcasting.
+      auto reshapeResultType =
+          dyn_cast<ShapedType>(reshapeOp->getResult(0).getType());
+      auto addResultType = dyn_cast<ShapedType>(addOp->getResult(0).getType());
+      if (!reshapeResultType || !addResultType ||
+          !reshapeResultType.hasStaticShape() ||
+          !addResultType.hasStaticShape() ||
+          reshapeResultType.getShape() != addResultType.getShape()) {
+        return rewriter.notifyMatchFailure(
+            addOp, "incompatible shapes, add is broadcasting");
+      }
+      // Check that the bias is only on a single dimension that is not affected
+      // by the reshape. The bias could be multi-dimensional, but this increases
+      // the complexity and was not seen in models.
+      auto biasType = dyn_cast<ShapedType>(bias.getType());
+      if (!biasType || !biasType.hasStaticShape()) {
+        return rewriter.notifyMatchFailure(
+            addOp, "bias has not a static shape");
+      }
+
+      SmallVector<int64_t> biasRankFixedShape;
+      biasRankFixedShape.append(
+          addResultType.getRank() - biasType.getRank(), 1);
+      biasRankFixedShape.append(
+          biasType.getShape().begin(), biasType.getShape().end());
+
+      // biasRankFixedShape should have exactly one dimension that is not one.
+      std::optional<int64_t> biasDim;
+      for (auto [idx, dimSize] : enumerate(biasRankFixedShape)) {
+        if (dimSize != 1) {
+          if (biasDim) {
+            return rewriter.notifyMatchFailure(
+                addOp, "bias has more than one non-one dimension");
+          }
+          biasDim = idx;
+        }
+      }
+      if (!biasDim) {
+        return rewriter.notifyMatchFailure(
+            addOp, "bias has no non-one dimension");
+      }
+
+      const auto biasShapeUntilDim =
+          ArrayRef<int64_t>(reshapeResultType.getShape())
+              .slice(0, *biasDim + 1);
+      const uint64_t biasShapeRelevantSize =
+          std::accumulate(biasShapeUntilDim.begin(), biasShapeUntilDim.end(), 1,
+              std::multiplies<uint64_t>());
+
+      // The bias dim should not be affected by the reshape. We need to map it
+      // back through it.
+      size_t reshapeInBiasDim;
+      auto reshapeInType =
+          dyn_cast<ShapedType>(reshapeOp->getOperand(0).getType());
+      if (!reshapeInType || !reshapeInType.hasStaticShape()) {
+        return rewriter.notifyMatchFailure(
+            addOp, "reshape input has not a static shape");
+      }
+      const auto reshapeInShape = reshapeInType.getShape();
+
+      // Trace the dim through the reshape.
+      uint64_t acc = 1;
+      for (auto [idx, dimSize] : enumerate(reshapeInShape)) {
+        acc *= dimSize;
+        if (acc == biasShapeRelevantSize) {
+          if (dimSize != biasRankFixedShape[*biasDim]) {
+            return rewriter.notifyMatchFailure(
+                addOp, "bias shape is not compatible with reshape input");
+          }
+          reshapeInBiasDim = idx;
+          break;
+        }
+        if (acc > biasShapeRelevantSize) {
+          return rewriter.notifyMatchFailure(
+              addOp, "bias shape is not compatible with reshape input");
+        }
+      }
+
+      newBiasShape.push_back(reshapeInShape[reshapeInBiasDim]);
+      newBiasShape.append(reshapeInShape.size() - reshapeInBiasDim - 1, 1);
+    }
 
     // Norms only support unidirectional broadcasting from bias to y
-    const auto yType = dyn_cast<ShapedType>(y.getType());
-    const auto addType = dyn_cast<ShapedType>(addOp.getType());
-    if (!yType || !addType || !yType.hasStaticShape() ||
-        !addType.hasStaticShape() || yType.getShape() != addType.getShape()) {
-      return rewriter.notifyMatchFailure(addOp, "incompatible shapes");
+    if (!reshapeOp && !areUnidirectionalBroadcastCompatible(
+                          lnOp.getX().getType(), bias.getType())) {
+      return rewriter.notifyMatchFailure(addOp,
+          "layer norm and bias are not unidirectional broadcast compatible");
     }
-    // We are fine.
-    Value x = lnOp.getX();
-    Value scale = lnOp.getScale();
-    FloatAttr epsilon = lnOp.getEpsilonAttr();
-    int64_t axis = lnOp.getAxis();
-    LLVM_DEBUG(llvm::dbgs() << "LayerNorm from add, axis : " << axis << "\n");
-
-    // Replace
-    MultiDialectBuilder<OnnxBuilder> create(
-        rewriter, rewriter.getFusedLoc({lnOp.getLoc(), addOp->getLoc()}));
-    Type xType = x.getType();
-    Value res;
-    if constexpr (std::is_same<OP_TYPE, ONNXLayerNormalizationOp>::value)
-      res = create.onnx.layerNorm(xType, x, scale, bias, axis, epsilon);
-    else if constexpr (std::is_same<OP_TYPE,
-                           ONNXRMSLayerNormalizationOp>::value)
-      res = create.onnx.RMSLayerNorm(xType, x, scale, bias, axis, epsilon);
-    else
-      llvm_unreachable("unsupported op");
-    rewriter.replaceOp(addOp, res);
-    return success();
-  }
 
-private:
-  LogicalResult reportFailure(std::string msg) const {
-    // Can disable line below if not needed.
-    LLVM_DEBUG(llvm::dbgs() << "LayerNorm failure:" << msg << "\n");
-    return failure();
+    rewriter.moveOpAfter(
+        lnOp, addOp); // Make sure we can use the const of the mul
+    rewriter.setInsertionPoint(addOp);
+    if (reshapeOp) {
+      onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(
+          rewriter, reshapeOp->getLoc());
+      const auto newShapeConst = create.onnx.constantInt64(newBiasShape);
+      bias = create.onnx.reshape(
+          RankedTensorType::get(
+              newBiasShape, cast<ShapedType>(bias.getType()).getElementType()),
+          bias, newShapeConst);
+    }
+    rewriter.modifyOpInPlace(lnOp, [&] {
+      lnOp.setOperand(/*bias*/ 2, bias);
+      lnOp->setLoc(rewriter.getFusedLoc({lnOp.getLoc(), addOp->getLoc()}));
+    });
+    if (reshapeOp) {
+      rewriter.moveOpAfter(reshapeOp, lnOp);
+      rewriter.replaceOp(addOp, reshapeOp->getResult(0));
+    } else {
+      rewriter.replaceOp(addOp, lnOp.getY());
+    }
+    return success();
   }
 };
 
@@ -1930,7 +2053,8 @@ struct RemoveInstanceNormPattern
         rewriter, instanceNormOp.getLoc());
     int64_t axis = nonSpacialRank;
     int64_t numInNorm = inputRank - axis;
-    // Unsqueeze scale/bias from [C] to [C x 1 x 1 x ... x 1] with numInNorm 1s.
+    // Unsqueeze scale/bias from [C] to [C x 1 x 1 x ... x 1] with numInNorm
+    // 1s.
    llvm::SmallVector<int64_t, 4> axesList, biasScaleShape;
     biasScaleShape.emplace_back(C);
     for (int64_t i = 1; i <= numInNorm; ++i) {
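
To see how the new bias-dimension tracing works, consider the second test case below: the norm output tensor<1x384x768xf32> is reshaped to tensor<1x384x2x384xf32> before a tensor<384x1x1xf32> bias is added. Padded to the add's rank, the bias shape is [1, 384, 1, 1], so its only non-one dimension is index 1. The product of the reshaped shape up to and including that index is 1 * 384 = 384; walking the reshape input shape [1, 384, 768], the running product reaches exactly 384 at a dimension whose size 384 matches the bias dimension, so the bias dimension is unaffected by the reshape and maps back to input dimension 1. The bias is therefore reshaped to [384, 1] (the traced dimension followed by trailing ones) before it is plugged into the norm, as the CHECK lines of that test expect. Sketched in MLIR (value names are illustrative):

  %newShape = "onnx.Constant"() {value = dense<[384, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
  %newBias = "onnx.Reshape"(%bias, %newShape) : (tensor<384x1x1xf32>, tensor<2xi64>) -> tensor<384x1xf32>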

test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 77 additions & 0 deletions
@@ -2508,3 +2508,80 @@ func.func @rmslayernorm_with_neutral_scale(%arg0: tensor<1x384x768xf32>, %arg1:
 // CHECK: }
 }
 
+// -----
+
+func.func @layernorm_with_reshape_without_bias_simple_reshape(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %bias: tensor<768xf32>) -> tensor<384x768xf32> {
+  %0 = "onnx.NoValue"() {value} : () -> none
+  %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
+  %Shape = "onnx.Constant"() {value = dense<[384, 768]> : tensor<2xi64>} : () -> tensor<2xi64>
+  %Reshaped = "onnx.Reshape"(%NormScaled, %Shape) : (tensor<1x384x768xf32>, tensor<2xi64>) -> tensor<384x768xf32>
+  %Y = "onnx.Add"(%bias, %Reshaped) : (tensor<768xf32>, tensor<384x768xf32>) -> tensor<384x768xf32>
+  return %Y : tensor<384x768xf32>
+// CHECK-LABEL: func.func @layernorm_with_reshape_without_bias_simple_reshape
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<384x768xf32> {
+// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[384, 768]> : tensor<2xi64>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none)
+// CHECK: [[VAR_1_:%.+]] = "onnx.Reshape"([[VAR_Y_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<1x384x768xf32>, tensor<2xi64>) -> tensor<384x768xf32>
+// CHECK: return [[VAR_1_]] : tensor<384x768xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @layernorm_with_reshape_without_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %bias: tensor<384x1x1xf32>) -> tensor<1x384x2x384xf32> {
+  %0 = "onnx.NoValue"() {value} : () -> none
+  %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
+  %Shape = "onnx.Constant"() {value = dense<[1, 384, 2, 384]> : tensor<4xi64>} : () -> tensor<4xi64>
+  %Reshaped = "onnx.Reshape"(%NormScaled, %Shape) : (tensor<1x384x768xf32>, tensor<4xi64>) -> tensor<1x384x2x384xf32>
+  %Y = "onnx.Add"(%bias, %Reshaped) : (tensor<384x1x1xf32>, tensor<1x384x2x384xf32>) -> tensor<1x384x2x384xf32>
+  return %Y : tensor<1x384x2x384xf32>
+// CHECK-LABEL: func.func @layernorm_with_reshape_without_bias
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<384x1x1xf32>) -> tensor<1x384x2x384xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1, 384, 2, 384]> : tensor<4xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[384, 1]> : tensor<2xi64>
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Reshape"([[PARAM_2_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<384x1x1xf32>, tensor<2xi64>) -> tensor<384x1xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<384x1xf32>) -> (tensor<1x384x768xf32>, none, none)
+// CHECK: [[VAR_3_:%.+]] = "onnx.Reshape"([[VAR_Y_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<1x384x768xf32>, tensor<4xi64>) -> tensor<1x384x2x384xf32>
+// CHECK: return [[VAR_3_]] : tensor<1x384x2x384xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @layernorm_with_reshape_multi_use(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>) -> tensor<384x768xf32> {
+  %0 = "onnx.NoValue"() {value} : () -> none
+  %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
+  %Shape = "onnx.Constant"() {value = dense<[384, 768]> : tensor<2xi64>} : () -> tensor<2xi64>
+  %Reshaped = "onnx.Reshape"(%NormScaled, %Shape) : (tensor<1x384x768xf32>, tensor<2xi64>) -> tensor<384x768xf32>
+  %Y = "onnx.Add"(%Reshaped, %Reshaped) : (tensor<384x768xf32>, tensor<384x768xf32>) -> tensor<384x768xf32>
+  return %Y : tensor<384x768xf32>
+// CHECK-LABEL: func.func @layernorm_with_reshape_multi_use
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>) -> tensor<384x768xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[384, 768]> : tensor<2xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_1_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
+// CHECK: [[VAR_2_:%.+]] = "onnx.Reshape"([[VAR_Y_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<1x384x768xf32>, tensor<2xi64>) -> tensor<384x768xf32>
+// CHECK: [[VAR_3_:%.+]] = "onnx.Add"([[VAR_2_]], [[VAR_2_]]) : (tensor<384x768xf32>, tensor<384x768xf32>) -> tensor<384x768xf32>
+// CHECK: return [[VAR_3_]] : tensor<384x768xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @layernorm_with_reshape_split_dim(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %bias: tensor<384xf32>) -> tensor<1x384x2x384xf32> {
+  %0 = "onnx.NoValue"() {value} : () -> none
+  %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
+  %Shape = "onnx.Constant"() {value = dense<[1, 384, 2, 384]> : tensor<4xi64>} : () -> tensor<4xi64>
+  %Reshaped = "onnx.Reshape"(%NormScaled, %Shape) : (tensor<1x384x768xf32>, tensor<4xi64>) -> tensor<1x384x2x384xf32>
+  %Y = "onnx.Add"(%bias, %Reshaped) : (tensor<384xf32>, tensor<1x384x2x384xf32>) -> tensor<1x384x2x384xf32>
+  return %Y : tensor<1x384x2x384xf32>
+// CHECK-LABEL: func.func @layernorm_with_reshape_split_dim
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<384xf32>) -> tensor<1x384x2x384xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1, 384, 2, 384]> : tensor<4xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_1_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
+// CHECK: [[VAR_2_:%.+]] = "onnx.Reshape"([[VAR_Y_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<1x384x768xf32>, tensor<4xi64>) -> tensor<1x384x2x384xf32>
+// CHECK: [[VAR_3_:%.+]] = "onnx.Add"([[PARAM_2_]], [[VAR_2_]]) : (tensor<384xf32>, tensor<1x384x2x384xf32>) -> tensor<1x384x2x384xf32>
+// CHECK: return [[VAR_3_]] : tensor<1x384x2x384xf32>
+// CHECK: }
+}
