1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1048,5 +1048,6 @@ def AMDGPU_ScaledMFMAOp :
attr-dict
`:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
}];
let hasCanonicalizer = 1;
}
#endif // AMDGPU
141 changes: 141 additions & 0 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@@ -26,8 +27,10 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <cstdint>
#include <limits>
#include <optional>

@@ -631,6 +634,144 @@ LogicalResult TransposeLoadOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// ScaledMFMAOp
//===----------------------------------------------------------------------===//

namespace {
/// Pack the scales used by a ScaledMFMAOp: replace a scale built from a
/// single extracted element with a 4-element slice of the scale source and
/// select the element via opsel. Packing is skipped if any other scaled MFMA
/// using the same scales value has already been packed (non-zero opsel).
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(ScaledMFMAOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// If this use of a scale has a non-zero opsel, packing has already been
// done.
auto checkIfUnpackable = [](OpOperand &use) {
if (auto smfma = dyn_cast<ScaledMFMAOp>(use.getOwner())) {
switch (use.getOperandNumber()) {
case 3:
return smfma.getScalesIdxA() != 0;
case 4:
return smfma.getScalesIdxB() != 0;
default:
break;
}
}
return true;
};

auto setOpsel = [&](unsigned idx, int64_t val) {
switch (idx) {
case 3:
op.setScalesIdxA(val);
break;
case 4:
op.setScalesIdxB(val);
break;
default:
break;
}
};

// Obtain flat index from offsets and shape.
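// E.g., the element at position [6, 8] of a vector<7x23xf8E8M0FNU> flattens
// to 6 * 23 + 8 = 146.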
auto getIdxFromExtract = [](vector::ExtractOp op) {
ShapedType ty = dyn_cast<ShapedType>(op.getOperand(0).getType());
int64_t cumul = 1;
int64_t idx = 0;
for (auto [offset, size] :
reverse(llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
idx += offset * cumul;
cumul *= size;
}
return idx;
};

// For every scale operand of this ScaledMFMAOp, if the scale is produced by
// the following pattern:
// (f8 here means f8E8M0FNU)
// %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
// %scale = vector.insert %unit, ... : f8 into vector<4xf8>
// amdgpu.scaled_mfma(%scale[0] * ...
//
// rewrite it to:
//
// %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
// %scale = vector.extract_strided_slice %reshaped {offsets = [?], sizes = [4], strides = [1]} : vector<?xf8> to vector<4xf8>
// amdgpu.scaled_mfma(%scale[0-3] * ...
//
// This creates a duplicate shape_cast for every use, but those duplicates
// are removed by CSE.
for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
if (!insertOp) {
return rewriter.notifyMatchFailure(op,
"defining op not a vector.insert");
}
if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
return rewriter.notifyMatchFailure(op,
"some scaled mfma's already packed");
}

auto extractOp =
insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
if (!extractOp) {
return rewriter.notifyMatchFailure(op,
"defining op not a vector.extract");
}

Value scaleSrc = extractOp.getOperand(0);
auto stype = dyn_cast<VectorType>(scaleSrc.getType());
if (!stype) {
return rewriter.notifyMatchFailure(op, "not a vector type");
}
// Dynamic dimensions are not handled yet; assume the input has been padded
// to a static shape.
if (!stype.hasStaticShape()) {
return rewriter.notifyMatchFailure(op,
"dynamic dims not yet supported");
}

int64_t numElements = stype.getNumElements();
if (numElements <= 4) {
return rewriter.notifyMatchFailure(
op, "no packing if # of scales less than four");
}
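// Map the flat index to a 4-element slice and a lane (opsel) within it.
// E.g., idx = 146 in a 161-element source gives offset = 144 and opsel = 2;
// a tail index such as 160 is clamped to the slice [157, 161) with opsel = 3.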
int64_t idx = getIdxFromExtract(extractOp);
int64_t offset = idx - (idx % 4);
int64_t size = std::min<int64_t>(4, numElements - offset);
int64_t opsel = idx - offset;
if (size != 4) {
opsel += 4 - size;
offset = numElements - 4;
size = 4;
}

Type newSrcType = VectorType::get({numElements}, stype.getElementType());
Value newScaleSrc =
rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
auto scaleTy = VectorType::get({4}, stype.getElementType());
Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
loc, newScaleSrc, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{size},
ArrayRef<int64_t>{1});
Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
rewriter.modifyOpInPlace(op, [&] {
op->setOperand(opIdx, scale);
setOpsel(opIdx, opsel);
});
}
return success();
}
};
} // namespace

void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<PackScales>(context);
}

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
1 change: 1 addition & 0 deletions mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
MLIRROCDLDialect
# Needed for GPU address space enum definition
MLIRGPUDialect
MLIRVectorDialect
MLIRIR
MLIRSideEffectInterfaces
MLIRMemRefUtils
114 changes: 114 additions & 0 deletions mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -159,3 +159,117 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
: f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
func.return
}

// -----

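// The scale sources are flattened to vector<16xf8E8M0FNU>, a 4-element slice
// is extracted, and opsel selects the element within that slice.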
// CHECK-LABEL: func @scaled_mfma
// CHECK: %[[SCALE_1:.*]] = vector.extract_strided_slice %0 {offsets = [0], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: %[[SCALE_2:.*]] = vector.extract_strided_slice %2 {offsets = [4], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
// CHECK: %[[SCALE_3:.*]] = vector.extract_strided_slice %5 {offsets = [8], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: %[[SCALE_4:.*]] = vector.extract_strided_slice %7 {offsets = [12], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
%cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
%scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
%scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
}

// -----

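// With only two scales the source is too small to pack, so the
// extract/insert chain is left untouched.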
// CHECK-LABEL: func @scaled_mfma_less_than_4
// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2xf8E8M0FNU>, %scalesB: vector<2xf8E8M0FNU>) -> vector<4xf32> {
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
%cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
%scaleA = vector.extract %scalesA[0] : f8E8M0FNU from vector<2xf8E8M0FNU>
%sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%scaleB = vector.extract %scalesB[1] : f8E8M0FNU from vector<2xf8E8M0FNU>
%sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
return %res_0 : vector<4xf32>
}


// -----

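// Scale sources whose element count is not a multiple of four: slices that
// would run past the end are clamped to the last four elements and opsel is
// shifted to compensate.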
// CHECK-LABEL: func @scaled_mfma_ugly_shapes
// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[0] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
%cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
%scaleA_0_0 = vector.extract %scalesA[0, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
%scaleA_0_1 = vector.extract %scalesA[1, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
%scaleA_0_2 = vector.extract %scalesA[2, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
%scaleA_0_3 = vector.extract %scalesA[3, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
%scaleA_0_4 = vector.extract %scalesA[4, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
%scaleA_0_5 = vector.extract %scalesA[4, 1] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
%scaleA_0_6 = vector.extract %scalesA[4, 2] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
%scaleA_0_7 = vector.extract %scalesA[4, 3] : f8E8M0FNU from vector<5x5xf8E8M0FNU>

// idx = 138 + 8 = 146 => opsel = 2
%scaleB_6_8 = vector.extract %scalesB[6, 8] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
// idx = 147 => opsel = 3
%scaleB_6_9 = vector.extract %scalesB[6, 9] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
// idx = 148 => opsel = 0
%scaleB_6_10 = vector.extract %scalesB[6, 10] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
// idx = 149 => opsel = 1
%scaleB_6_11 = vector.extract %scalesB[6, 11] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
// idx = 160 => opsel = 3 (last idx of last 4 bytes)
%scaleB_6_22 = vector.extract %scalesB[6, 22] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
// idx = 159 => opsel = 3
%scaleB_6_21 = vector.extract %scalesB[6, 21] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
// idx = 158 => opsel = 2
%scaleB_6_20 = vector.extract %scalesB[6, 20] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
// idx = 157 => opsel = 1
%scaleB_6_19 = vector.extract %scalesB[6, 19] : f8E8M0FNU from vector<7x23xf8E8M0FNU>

%sA_0_0 = vector.insert %scaleA_0_0, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sA_0_1 = vector.insert %scaleA_0_1, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sA_0_2 = vector.insert %scaleA_0_2, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sA_0_3 = vector.insert %scaleA_0_3, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sA_0_4 = vector.insert %scaleA_0_4, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sA_0_5 = vector.insert %scaleA_0_5, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sA_0_6 = vector.insert %scaleA_0_6, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sA_0_7 = vector.insert %scaleA_0_7, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>

%sB_6_8 = vector.insert %scaleB_6_8, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sB_6_9 = vector.insert %scaleB_6_9, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sB_6_10 = vector.insert %scaleB_6_10, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sB_6_11 = vector.insert %scaleB_6_11, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sB_6_22 = vector.insert %scaleB_6_22, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>

%res_0 = amdgpu.scaled_mfma(%sA_0_0[0] * %opA) * (%sB_6_8[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
%res_1 = amdgpu.scaled_mfma(%sA_0_1[0] * %opA) * (%sB_6_9[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
%res_2 = amdgpu.scaled_mfma(%sA_0_2[0] * %opA) * (%sB_6_10[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
%res_3 = amdgpu.scaled_mfma(%sA_0_3[0] * %opA) * (%sB_6_11[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
%res_4 = amdgpu.scaled_mfma(%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
%res_5 = amdgpu.scaled_mfma(%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
%res_6 = amdgpu.scaled_mfma(%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
%res_7 = amdgpu.scaled_mfma(%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
return %res_0, %res_1, %res_2, %res_3, %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
}