
Commit e1ac1c4

sahas3 authored and Lallapallooza committed
[Tosa] : Match accumulator type with torch for lowering aten.mm to tosa.matmul (#4264)
For `i8` input, the accumulator was also being set to `i8`, which produces an invalid `tosa.matmul` op. `fp16` input used an `fp16` accumulator, which, while valid per the TOSA spec, fails numerical verification. This change uses the accumulator type that PyTorch uses (tosa-to-linalg does the same).
1 parent dd079b9 commit e1ac1c4
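
For illustration, the accumulator choice this lowering now relies on follows PyTorch's matmul accumulation behavior: 8-bit integer inputs accumulate in i32, and f16/bf16 inputs accumulate in f32. The sketch below is a hypothetical stand-in for that mapping, not the actual getDefaultAccType implementation in torch-mlir, which may cover more cases:

// Hypothetical sketch of the accumulator-type choice described above; the
// real getDefaultAccType helper in torch-mlir may differ in coverage/details.
static mlir::Type chooseMatmulAccType(mlir::PatternRewriter &rewriter,
                                      mlir::Type inputElemTy) {
  // 8-bit integer inputs accumulate in i32 (matches PyTorch and tosa-to-linalg).
  if (inputElemTy.isInteger(8))
    return rewriter.getIntegerType(32);
  // Half-precision float inputs accumulate in f32; an f16 accumulator is
  // TOSA-legal but fails numerical verification against PyTorch.
  if (inputElemTy.isF16() || inputElemTy.isBF16())
    return rewriter.getF32Type();
  // Other element types keep their own type as the accumulator.
  return inputElemTy;
}

The op result is still produced in the original element type: the lowering casts the widened matmul output back, which is the tosa.cast checked by the new @torch.aten.mm$si8 and @torch.aten.mm$f16 tests below.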

File tree

4 files changed: +194, -34 lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 33 additions & 33 deletions
@@ -1849,25 +1849,19 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {

     SmallVector<int64_t> matmulOutputShape(
         {matmulLhsShape[0], matmulLhsShape[1], matmulRhsShape[2]});
-    Type outputElemTy;

     bool isInputElemTyQInt8 = false;
-    if (isa<mlir::quant::UniformQuantizedType>(lhsElemTy)) {
-      mlir::quant::UniformQuantizedType inputQTy =
-          dyn_cast<mlir::quant::UniformQuantizedType>(lhsElemTy);
+    Type inputElemTy{lhsElemTy};
+    if (auto inputQTy =
+            dyn_cast<mlir::quant::UniformQuantizedType>(lhsElemTy)) {
       if (inputQTy.getStorageTypeIntegralWidth() == 8)
         isInputElemTyQInt8 = true;
+      inputElemTy = inputQTy.getStorageType();
     }

-    if (isInputElemTyQInt8) {
-      // qint8 emits i32 matmul output
-      outputElemTy = rewriter.getIntegerType(32);
-    } else {
-      outputElemTy = lhsElemTy;
-    }
-
+    auto accElemTy = getDefaultAccType(rewriter, inputElemTy);
     auto mmOutputTy = RankedTensorType::get(
-        makeShapeLLVMCompatible(matmulOutputShape), outputElemTy);
+        makeShapeLLVMCompatible(matmulOutputShape), accElemTy);

     Value mmOpResult;
     if (!isInputElemTyQInt8) {
@@ -1997,7 +1991,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {

     // Perform reshape
     auto reshapedOpType = RankedTensorType::get(
-        makeShapeLLVMCompatible(reshapedOpShape), outputElemTy);
+        makeShapeLLVMCompatible(reshapedOpShape), accElemTy);
     auto reshapedOp = rewriter.create<tosa::ReshapeOp>(
         op->getLoc(),
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
@@ -2007,7 +2001,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {

     if (opNeedsTranspose) {
       auto transposedOpType = RankedTensorType::get(
-          makeShapeLLVMCompatible(transposedOpShape), outputElemTy);
+          makeShapeLLVMCompatible(transposedOpShape), accElemTy);
       output = rewriter
                    .create<tosa::TransposeOp>(
                        op->getLoc(),
@@ -2043,12 +2037,14 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
      return rewriter.notifyMatchFailure(op,
                                         "Failed to perform matmul operation");

-    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+    rewriter.replaceOp(
         op,
-        cast<RankedTensorType>(
-            OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-                op.getType())),
-        output);
+        {tosa::tosaCastTensorToType(
+             rewriter, output,
+             cast<RankedTensorType>(
+                 OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+                     op.getType())))
+             .value()});

     return success();
   }
@@ -2165,10 +2161,6 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
     auto bias = adaptor.getBias();
     auto biasTy = bias.getType();

-    if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, bias).failed())
-      return rewriter.notifyMatchFailure(
-          op, "Failed to equalize ranks among operands and result");
-
     // TOSA does not mandate that elementwise op tensors need to be ranked.
     if (!isa<Torch::NoneType>(biasTy) && !isa<TensorType>(biasTy))
       return rewriter.notifyMatchFailure(
@@ -2207,22 +2199,30 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
      return rewriter.notifyMatchFailure(op,
                                         "Failed to perform matmul operation");

-    Value matmulPlusBias = matmulOutput;
+    Value matmulPlusBias =
+        tosa::tosaCastTensorToType(
+            rewriter, matmulOutput,
+            cast<RankedTensorType>(
+                OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+                    op.getType())))
+            .value();
+
     if (!isa<Torch::NoneType>(biasTy)) {
-      // Bias addition broadcasts to the matmul output shape.
+      // Broadcast bias to the matmul output shape for addition
+      if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), matmulPlusBias,
+                                    bias)
+              .failed())
+        return rewriter.notifyMatchFailure(
+            op, "Failed to equalize ranks among operands and result");
+
       matmulPlusBias =
           rewriter
-              .create<tosa::AddOp>(op->getLoc(), matmulOutput.getType(),
-                                   matmulOutput, bias)
+              .create<tosa::AddOp>(op->getLoc(), matmulPlusBias.getType(),
+                                   matmulPlusBias, bias)
              .getResult();
     }

-    rewriter.replaceOpWithNewOp<tensor::CastOp>(
-        op,
-        cast<RankedTensorType>(
-            OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-                op.getType())),
-        matmulPlusBias);
+    rewriter.replaceOp(op, {matmulPlusBias});

     return success();
   }

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions
@@ -85,6 +85,8 @@
     "TraceModule_empty",
     # Crashes due to copy to a smaller destination buffer than the source buffer.
     "SliceCopyStartGreaterThanDimSize_Module_basic",
+    # unimplemented: for conversion to byte or char type dstOriginalDtype has to be passed to convertScalarToDtype
+    "AtenMmInt8Types_basic",
 }

 TORCHDYNAMO_XFAIL_SET = {
@@ -641,6 +643,7 @@
     "AtenMatmulQint8VM_basic",
     "AtenMatmulQint8VV_basic",
     "AtenMatmulQint8_basic",
+    "AtenMmF16Types_basic",
     "AtenMmQMixedSigni8_basic",
     "AtenMmQint8_basic",
     "AtenMmQuint8_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py

Lines changed: 45 additions & 0 deletions
@@ -336,6 +336,51 @@ def AtenMmIntTypes_basic(module, tu: TestUtils):
     module.forward(tu.randint(16, 4, high=100), tu.randint(4, 16, high=100))


+# ==============================================================================
+
+
+class AtenMmInt8Types(torch.nn.Module):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.int8, True),
+            ([-1, -1], torch.int8, True),
+        ]
+    )
+    def forward(self, a, b):
+        return torch.ops.aten.mm(a, b)
+
+
+@register_test_case(module_factory=lambda: AtenMmInt8Types())
+def AtenMmInt8Types_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(16, 4, high=100).to(torch.int8),
+        tu.randint(4, 16, high=100).to(torch.int8),
+    )
+
+
+# ==============================================================================
+
+
+class AtenMmF16Types(torch.nn.Module):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float16, True),
+            ([-1, -1], torch.float16, True),
+        ]
+    )
+    def forward(self, a, b):
+        return torch.ops.aten.mm(a, b)
+
+
+@register_test_case(module_factory=lambda: AtenMmF16Types())
+def AtenMmF16Types_basic(module, tu: TestUtils):
+    module.forward(tu.rand(16, 4).to(torch.float16), tu.rand(4, 16).to(torch.float16))
+
+
 # ==============================================================================
 # For DQ-Q fake quantization ops
 import torch.ao.quantization.fx._decomposed

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 113 additions & 1 deletion
@@ -4186,6 +4186,7 @@ func.func @torch.aten.convolution$si8(%arg0: !torch.vtensor<[2,2,6,6],si8>, %arg
   return %4 : !torch.vtensor<[2,8,4,4],si32>
 }

+// -----
 // CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
 // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
@@ -4226,7 +4227,6 @@ func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,
 }

 // -----
-
 // CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
 // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
@@ -4264,3 +4264,115 @@ func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,
   %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
   return %3 : !torch.vtensor<[1,512,10],f32>
 }
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.mm$f32(
+// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[1,22],f32>,
+// CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[22,10],f32>) -> !torch.vtensor<[1,10],f32> {
+// CHECK: %[[WTS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[WTS]] : !torch.vtensor<[22,10],f32> -> tensor<22x10xf32>
+// CHECK: %[[INP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INP]] : !torch.vtensor<[1,22],f32> -> tensor<1x22xf32>
+// CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 22]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[INP_RESHAPE:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<1x22xf32>, !tosa.shape<3>) -> tensor<1x1x22xf32>
+// CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 22, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[WTS_RESHAPE:.*]] = tosa.reshape %[[WTS_TENSOR]], %[[WTS_SHAPE]] : (tensor<22x10xf32>, !tosa.shape<3>) -> tensor<1x22x10xf32>
+// CHECK: %[[INP_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPE]], %[[WTS_RESHAPE]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x1x22xf32>, tensor<1x22x10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x10xf32>
+// CHECK: %[[RES_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 10]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK: %[[RES_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x1x10xf32>, !tosa.shape<2>) -> tensor<1x10xf32>
+// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[RES_RESHAPE]] : tensor<1x10xf32> -> !torch.vtensor<[1,10],f32>
+// CHECK: return %[[RES]]
+func.func @torch.aten.mm$f32(%arg0: !torch.vtensor<[1,22],f32>, %arg1: !torch.vtensor<[22,10],f32>) -> !torch.vtensor<[1,10],f32> {
+  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1,22],f32>, !torch.vtensor<[22,10],f32> -> !torch.vtensor<[1,10],f32>
+  return %0 : !torch.vtensor<[1,10],f32>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.mm$si8
+// CHECK: tosa.matmul
+// CHECK-SAME: (tensor<1x1x22xi8>, tensor<1x22x10xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1x10xi32>
+// CHECK-NOT: torch.aten.mm
+// CHECK: tosa.cast
+// CHECK-SAME: (tensor<1x10xi32>) -> tensor<1x10xi8>
+func.func @torch.aten.mm$si8(%arg0: !torch.vtensor<[1,22],si8>, %arg1: !torch.vtensor<[22,10],si8>) -> !torch.vtensor<[1,10],si8> {
+  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1,22],si8>, !torch.vtensor<[22,10],si8> -> !torch.vtensor<[1,10],si8>
+  return %0 : !torch.vtensor<[1,10],si8>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.mm$f16
+// CHECK: tosa.matmul
+// CHECK-SAME: (tensor<1x1x22xf16>, tensor<1x22x10xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x1x10xf32>
+// CHECK-NOT: torch.aten.mm
+// CHECK: tosa.cast
+// CHECK-SAME: (tensor<1x10xf32>) -> tensor<1x10xf16>
+func.func @torch.aten.mm$f16(%arg0: !torch.vtensor<[1,22],f16>, %arg1: !torch.vtensor<[22,10],f16>) -> !torch.vtensor<[1,10],f16> {
+  %4 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1,22],f16>, !torch.vtensor<[22,10],f16> -> !torch.vtensor<[1,10],f16>
+  return %4 : !torch.vtensor<[1,10],f16>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.mm$bf16
+// CHECK: tosa.matmul
+// CHECK-SAME: (tensor<1x1x22xbf16>, tensor<1x22x10xbf16>, tensor<1xbf16>, tensor<1xbf16>) -> tensor<1x1x10xf32>
+// CHECK-NOT: torch.aten.mm
+// CHECK: tosa.cast
+// CHECK-SAME: (tensor<1x10xf32>) -> tensor<1x10xbf16>
+func.func @torch.aten.mm$bf16(%arg0: !torch.vtensor<[1,22],bf16>, %arg1: !torch.vtensor<[22,10],bf16>) -> !torch.vtensor<[1,10],bf16> {
+  %4 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1,22],bf16>, !torch.vtensor<[22,10],bf16> -> !torch.vtensor<[1,10],bf16>
+  return %4 : !torch.vtensor<[1,10],bf16>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.matmul$broadcast(
+// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[10,3,4],f32>,
+// CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[4],f32>) -> !torch.vtensor<[10,3],f32> {
+// CHECK: %[[WTS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[WTS]] : !torch.vtensor<[4],f32> -> tensor<4xf32>
+// CHECK: %[[INP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INP]] : !torch.vtensor<[10,3,4],f32> -> tensor<10x3x4xf32>
+// CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[WTS_RESHAPE:.*]] = tosa.reshape %[[WTS_TENSOR]], %[[WTS_SHAPE]] : (tensor<4xf32>, !tosa.shape<3>) -> tensor<1x4x1xf32>
+// CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 30, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[INP_RESHAPE:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<10x3x4xf32>, !tosa.shape<3>) -> tensor<1x30x4xf32>
+// CHECK: %[[WTS_TRANSPOSE:.*]] = tosa.transpose %[[WTS_RESHAPE]] {perms = array<i32: 1, 0, 2>} : (tensor<1x4x1xf32>) -> tensor<4x1x1xf32>
+// CHECK: %[[WTS_SHAPE_2:.*]] = tosa.const_shape {values = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[WTS_RESHAPE_2:.*]] = tosa.reshape %[[WTS_TRANSPOSE]], %[[WTS_SHAPE_2]] : (tensor<4x1x1xf32>, !tosa.shape<3>) -> tensor<1x4x1xf32>
+// CHECK: %[[INP_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPE]], %[[WTS_RESHAPE_2]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x30x4xf32>, tensor<1x4x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x30x1xf32>
+// CHECK: %[[RES_SHAPE:.*]] = tosa.const_shape {values = dense<[10, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK: %[[RES_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x30x1xf32>, !tosa.shape<2>) -> tensor<10x3xf32>
+// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[RES_RESHAPE]] : tensor<10x3xf32> -> !torch.vtensor<[10,3],f32>
+// CHECK: return %[[RES]]
+func.func @torch.aten.matmul$broadcast(%arg0: !torch.vtensor<[10,3,4],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[10,3],f32> {
+  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[10,3,4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[10,3],f32>
+  return %0 : !torch.vtensor<[10,3],f32>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.linear$f16(
+// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[2,4],f16>,
+// CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[3,4],f16>,
+// CHECK-SAME: %[[BIAS:.*]]: !torch.vtensor<[3],f16>) -> !torch.vtensor<[2,3],f16> {
+// CHECK: %[[BIAS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[BIAS]] : !torch.vtensor<[3],f16> -> tensor<3xf16>
+// CHECK: %[[WTS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[WTS]] : !torch.vtensor<[3,4],f16> -> tensor<3x4xf16>
+// CHECK: %[[INP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INP]] : !torch.vtensor<[2,4],f16> -> tensor<2x4xf16>
+// CHECK: %[[WTS_TRANSPOSE:.*]] = tosa.transpose %[[WTS_TENSOR]] {perms = array<i32: 1, 0>} : (tensor<3x4xf16>) -> tensor<4x3xf16>
+// CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 2, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[INP_RESHAPE:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<2x4xf16>, !tosa.shape<3>) -> tensor<1x2x4xf16>
+// CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[WTS_RESHAPE:.*]] = tosa.reshape %[[WTS_TRANSPOSE]], %[[WTS_SHAPE]] : (tensor<4x3xf16>, !tosa.shape<3>) -> tensor<1x4x3xf16>
+// CHECK: %[[INP_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16>
+// CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16>
+// CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPE]], %[[WTS_RESHAPE]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x2x4xf16>, tensor<1x4x3xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x2x3xf32>
+// CHECK: %[[RES_SHAPE:.*]] = tosa.const_shape {values = dense<[2, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK: %[[RES_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x2x3xf32>, !tosa.shape<2>) -> tensor<2x3xf32>
+// CHECK: %[[CAST:.*]] = tosa.cast %[[RES_RESHAPE]] : (tensor<2x3xf32>) -> tensor<2x3xf16>
+// CHECK: %[[BIAS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK: %[[BIAS_RESHAPE:.*]] = tosa.reshape %[[BIAS_TENSOR]], %[[BIAS_SHAPE]] : (tensor<3xf16>, !tosa.shape<2>) -> tensor<1x3xf16>
+// CHECK: %[[ADD:.*]] = tosa.add %[[CAST]], %[[BIAS_RESHAPE]] : (tensor<2x3xf16>, tensor<1x3xf16>) -> tensor<2x3xf16>
+// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[ADD]] : tensor<2x3xf16> -> !torch.vtensor<[2,3],f16>
+// CHECK: return %[[RES]]
+func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch.vtensor<[3,4],f16>, %arg2: !torch.vtensor<[3],f16>) -> !torch.vtensor<[2,3],f16> {
+  %0 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[2,4],f16>, !torch.vtensor<[3,4],f16>, !torch.vtensor<[3],f16> -> !torch.vtensor<[2,3],f16>
+  return %0 : !torch.vtensor<[2,3],f16>
+}
