Skip to content

Commit e93fc76

Browse files
authored
[NFC] Perform supportMMA check during IR verification (#8640)
Since a failure of the supportMMA check cannot be handled gracefully during lowering, move the check into IR verification, where a violation can be detected and reported earlier.
1 parent a9c3322 commit e93fc76

File tree

9 files changed

+79
-54
lines changed

9 files changed

+79
-54
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2761,15 +2761,34 @@ struct TritonGPUInferLayoutInterface
27612761
mlir::dyn_cast<triton::gpu::DotOperandEncodingAttr>(operandEncodingB);
27622762
if (!aEncoding && !bEncoding)
27632763
return mlir::success();
2764-
auto mmaAEncoding =
2765-
mlir::dyn_cast_or_null<NvidiaMmaEncodingAttr>(aEncoding.getParent());
2766-
if (mmaAEncoding && mmaAEncoding.isHopper())
2767-
return success();
2768-
// Verify that the encodings are valid.
27692764
if (!aEncoding || !bEncoding)
27702765
return op->emitError("mismatching encoding between A and B operands");
2766+
// Verify that the encodings are valid.
27712767
if (aEncoding.getKWidth() != bEncoding.getKWidth())
27722768
return op->emitError("mismatching kWidth between A and B operands");
2769+
2770+
// Check if we have already selected an MMA version for Nvidia. If so,
2771+
// validate that the encodings are correct and compatible.
2772+
auto mmaAEncoding =
2773+
dyn_cast_or_null<NvidiaMmaEncodingAttr>(aEncoding.getParent());
2774+
auto mmaBEncoding =
2775+
dyn_cast_or_null<NvidiaMmaEncodingAttr>(bEncoding.getParent());
2776+
auto dotOp = cast<DotOp>(op);
2777+
auto resEnc = dotOp.getResult().getType().getEncoding();
2778+
auto mmaResEncoding = dyn_cast<NvidiaMmaEncodingAttr>(resEnc);
2779+
if (mmaAEncoding || mmaBEncoding || mmaResEncoding) {
2780+
// Check that they are all set and have the same version.
2781+
if (!mmaAEncoding || !mmaBEncoding || !mmaResEncoding)
2782+
return op->emitError("mismatching MMA encoding");
2783+
auto mmaBEncoding = cast<NvidiaMmaEncodingAttr>(bEncoding.getParent());
2784+
if (mmaAEncoding.getVersionMajor() != mmaBEncoding.getVersionMajor() ||
2785+
mmaAEncoding.getVersionMajor() != mmaResEncoding.getVersionMajor()) {
2786+
return op->emitError("mismatched MMA version.");
2787+
}
2788+
// Verify that the operands are supported on the selected MMA version.
2789+
if (!supportMMA(dotOp, mmaResEncoding.getVersionMajor()))
2790+
return op->emitError("unsupported MMA version");
2791+
}
27732792
return success();
27742793
}
27752794

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/IR/BuiltinTypes.h"
2626
#include "mlir/IR/Diagnostics.h"
2727
#include "mlir/Support/LLVM.h"
28+
#include "triton/Analysis/Utility.h"
2829
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
2930
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
3031
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
@@ -91,13 +92,12 @@ LogicalResult WarpGroupDotOp::verify() {
9192
if (retShapePerCTA[1] % 8 != 0)
9293
return emitOpError("WGMMA result N dimension must be divisible by 8");
9394

94-
auto aElemTy = getA().getType().getElementType();
95-
if (!(llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
96-
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
97-
aElemTy.isF32()))
98-
return emitOpError("WGMMA result element type must be F16, BF16, F32, "
99-
"F8E5M2, F8E4M3FN, or integer type");
95+
// Verify MMA version is supported for operands.
96+
int mmaVersion = nvmmaEnc.getVersionMajor();
97+
if (!supportMMA(getA(), mmaVersion) || !supportMMA(getB(), mmaVersion))
98+
return emitOpError("unsupported MMA version for the given operands");
10099

100+
auto aElemTy = getA().getType().getElementType();
101101
if (getMaxNumImpreciseAcc() < 32 &&
102102
(llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
103103
resTy.getElementType().isF32()) {

test/Triton/invalid.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,3 +628,20 @@ tt.func @map_elementwise_store(%ptr: tensor<256x!tt.ptr<i32>>) {
628628
}) : (tensor<256x!tt.ptr<i32>>, tensor<256xi32>) -> (tensor<256xi32>)
629629
tt.return
630630
}
631+
632+
// -----
633+
634+
// Test that DotOp with f32 inputs but without TF32 precision is rejected for MMAv2
635+
// MMAv2 requires TF32 input precision for f32 operands
636+
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
637+
#dot_operand_a = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>
638+
#dot_operand_b = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>
639+
640+
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, ttg.target = "cuda:80"} {
641+
tt.func @dot_f32_without_tf32_mma_v2(%a: tensor<16x16xf32, #dot_operand_a>, %b: tensor<16x16xf32, #dot_operand_b>) {
642+
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
643+
// expected-error @below {{unsupported MMA version}}
644+
%result = tt.dot %a, %b, %cst, inputPrecision = ieee : tensor<16x16xf32, #dot_operand_a> * tensor<16x16xf32, #dot_operand_b> -> tensor<16x16xf32, #mma>
645+
tt.return
646+
}
647+
}

test/TritonGPU/combine.mlir

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2776,7 +2776,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32}
27762776
%cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
27772777
%cst_2 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
27782778
%cst_3 = arith.constant dense<1.230000e+02> : tensor<32x16xf32, #mma1>
2779-
%0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
2779+
%0 = tt.dot %cst_0, %cst_1, %cst, inputPrecision = tf32 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
27802780
%1 = ttg.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
27812781
%2 = "tt.reduce" (%1) ({
27822782
^bb0(%arg1: f32, %arg2: f32):
@@ -2786,7 +2786,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32}
27862786
%4 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xf32, #blocked>
27872787
%5 = tt.broadcast %4 : tensor<32x1xf32, #blocked> -> tensor<32x16xf32, #blocked>
27882788
%6 = ttg.convert_layout %5 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>>
2789-
%7 = tt.dot %cst_2, %6, %cst_3 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<32x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<32x16xf32, #mma1>
2789+
%7 = tt.dot %cst_2, %6, %cst_3, inputPrecision = tf32 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<32x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<32x16xf32, #mma1>
27902790
%addr = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x16x!tt.ptr<f32>, #blocked>
27912791
%8 = ttg.convert_layout %7 : tensor<32x16xf32, #mma1> -> tensor<32x16xf32, #blocked>
27922792
tt.store %addr, %8 : tensor<32x16x!tt.ptr<f32>, #blocked>
@@ -2992,7 +2992,7 @@ tt.func @hoist_multiple_conditional(
29922992
}
29932993
%2 = arith.addf %0, %1 : tensor<128x32xf32, #blocked>
29942994
%3 = ttg.convert_layout %2 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
2995-
%4 = tt.dot %3, %arg4, %arg5 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
2995+
%4 = tt.dot %3, %arg4, %arg5, inputPrecision = tf32 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
29962996
tt.return %4 : tensor<128x128xf32, #mma>
29972997
}
29982998

@@ -3021,7 +3021,7 @@ tt.func @hoist_across_loop(
30213021
}
30223022
// CHECK-NOT: ttg.convert_layout
30233023
%2 = ttg.convert_layout %1 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
3024-
%3 = tt.dot %2, %arg2, %acc : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
3024+
%3 = tt.dot %2, %arg2, %acc, inputPrecision = tf32 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
30253025
scf.yield %1, %3 : tensor<128x32xf32, #blocked>, tensor<128x128xf32, #mma>
30263026
}
30273027
tt.return %0#1 : tensor<128x128xf32, #mma>
@@ -3335,7 +3335,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
33353335
// CHECK-DAG: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>>
33363336
// CHECK-DAG: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
33373337
// CHECK-DAG: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
3338-
// CHECK-DAG: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
3338+
// CHECK-DAG: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}}, inputPrecision = tf32 : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
33393339
tt.func @push_convert_both_operands(
33403340
%pa: tensor<16x16x!tt.ptr<f16>, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
33413341
%pb: tensor<16x16x!tt.ptr<f16>, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
@@ -3346,7 +3346,7 @@ tt.func @push_convert_both_operands(
33463346
%be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB>
33473347
%al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
33483348
%bl = ttg.convert_layout %be : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
3349-
%r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
3349+
%r = tt.dot %al, %bl, %c, inputPrecision = tf32 : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
33503350
tt.return %r : tensor<16x16xf32, #mma>
33513351
}
33523352

@@ -3372,7 +3372,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
33723372
// CHECK-DAG: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
33733373
// CHECK-DAG: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
33743374
// CHECK-DAG: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
3375-
// CHECK-DAG: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
3375+
// CHECK-DAG: tt.dot %[[AEXT]], %[[ADD]], %{{.*}}, inputPrecision = tf32 : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
33763376
tt.func @update_kwidth_slice(
33773377
%pa: tensor<16x16x!tt.ptr<f16>, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
33783378
%pb: tensor<16x16x!tt.ptr<f16>, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
@@ -3385,7 +3385,7 @@ tt.func @update_kwidth_slice(
33853385
%add = arith.addf %be, %cst : tensor<16x16xf32, #blockedB>
33863386
%al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
33873387
%bl = ttg.convert_layout %add : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
3388-
%r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
3388+
%r = tt.dot %al, %bl, %c, inputPrecision = tf32 : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
33893389
tt.return %r : tensor<16x16xf32, #mma>
33903390
}
33913391
}
@@ -3403,7 +3403,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
34033403
%cst2 = arith.constant dense<1.000000e+00> : tensor<64x32xf32, #mma>
34043404
%0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
34053405
%1 = ttg.convert_layout %0 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
3406-
%2 = tt.dot %cst1, %1, %cst2 : tensor<64x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma>
3406+
%2 = tt.dot %cst1, %1, %cst2, inputPrecision = tf32 : tensor<64x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma>
34073407
tt.return %2 : tensor<64x32xf32, #mma>
34083408
}
34093409
}
@@ -3484,7 +3484,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
34843484
%a = tt.load %pa2 : tensor<16x16x!tt.ptr<f16>, #blocked>
34853485
%ae = arith.extf %a : tensor<16x16xf16, #blocked> to tensor<16x16xf32, #blocked>
34863486
%ac = ttg.convert_layout %ae : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
3487-
%r = tt.dot %ac, %b, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
3487+
%r = tt.dot %ac, %b, %c, inputPrecision = tf32 : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
34883488
tt.return %r : tensor<16x16xf32, #mma>
34893489
}
34903490
}
@@ -3581,7 +3581,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
35813581
%aa = arith.addf %ab, %a2 : tensor<16x16xf16, #blocked>
35823582
%ae = arith.extf %aa : tensor<16x16xf16, #blocked> to tensor<16x16xf32, #blocked>
35833583
%ac = ttg.convert_layout %ae : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
3584-
%r = tt.dot %ac, %b, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
3584+
%r = tt.dot %ac, %b, %c, inputPrecision = tf32 : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
35853585
tt.return %r : tensor<16x16xf32, #mma>
35863586
}
35873587
}

test/TritonGPU/loop-pipeline-cuda.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
144144
%72 = ttg.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !ttg.memdesc<32x64xf32, #shared, #smem>
145145
%73 = ttg.memdesc_trans %72 {order=array<i32: 1,0>} : !ttg.memdesc<32x64xf32, #shared, #smem> -> !ttg.memdesc<64x32xf32, #shared1, #smem>
146146
%74 = ttg.local_load %73 : !ttg.memdesc<64x32xf32, #shared1, #smem> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
147-
%75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma>
147+
%75 = tt.dot %71, %74, %cst, inputPrecision = tf32 : tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma>
148148
%76 = tt.load %61 : tensor<32x32x!tt.ptr<f32>, #blocked1>
149149
%77 = ttg.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
150150
%78 = ttg.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
151-
%79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma>
151+
%79 = tt.dot %77, %78, %arg7, inputPrecision = tf32 : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma>
152152
scf.yield %79 : tensor<64x32xf32, #mma>
153153
}
154154
%64 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x32xi64, #blocked>

0 commit comments

Comments
 (0)