
Commit d6a7139

[Dialect][Nvidia] Verify WGMMA kWidth (triton-lang#8107)
1 parent f4789ef commit d6a7139

2 files changed: 26 additions, 1 deletion


lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 11 additions & 0 deletions
@@ -25,6 +25,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/Support/LLVM.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
@@ -101,6 +102,16 @@ LogicalResult WarpGroupDotOp::verify() {
     return emitOpError("Cannot use F32 as the accumulator element type when "
                        "the max_num_imprecise_acc is less than 32");
   }
+
+  if (auto aTensorTy = dyn_cast<RankedTensorType>(getA().getType())) {
+    auto aDotOpEnc = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
+    unsigned kWidth = 32 / aTensorTy.getElementTypeBitWidth();
+    if (aDotOpEnc.getKWidth() != kWidth) {
+      return emitOpError("in-register LHS operand must have a kWidth of ")
+             << kWidth << " but got " << aDotOpEnc.getKWidth();
+    }
+  }
+
   return success();
 }

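The new check encodes a simple arithmetic rule: when the LHS (A) operand of ttng.warp_group_dot is an in-register tensor, its dot-operand encoding's kWidth must equal 32 divided by the operand's element bit width. The following is a minimal standalone C++ sketch of that rule; the helper name and the sample bit widths are illustrative only and are not part of the patch.

#include <cstdio>

// Expected kWidth for an in-register WGMMA LHS operand, mirroring the
// verifier above: 32 divided by the element bit width (illustrative helper).
static unsigned expectedKWidth(unsigned elementBitWidth) {
  return 32 / elementBitWidth;
}

int main() {
  // Sample element bit widths: 32 (f32), 16 (f16/bf16), 8 (8-bit float types).
  const unsigned widths[] = {32, 16, 8};
  for (unsigned bits : widths)
    std::printf("element bit width %2u -> expected kWidth %u\n", bits,
                expectedKWidth(bits));
  // The f16 test case below therefore expects kWidth = 32 / 16 = 2,
  // which is why kWidth = 1 is rejected by the verifier.
  return 0;
}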
test/TritonNvidiaGPU/invalid.mlir

Lines changed: 15 additions & 1 deletion
@@ -65,7 +65,7 @@ tt.func @async_tma_gather(%desc: !tt.tensordesc<tensor<1x128xbf16, #shared>>, %x
 
 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-warps" = 4 : i32} {
 tt.func @async_tma_gather(%desc: !tt.tensordesc<tensor<1x128xbf16, #shared>>, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32,
                           %bar: !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>,
                           %result: !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory>,
@@ -75,3 +75,17 @@ tt.func @async_tma_gather(%desc: !tt.tensordesc<tensor<1x128xbf16, #shared>>, %x
   tt.return
 }
 }
+
+// -----
+
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+
+module attributes {"ttg.num-warps" = 4 : i32} {
+tt.func @wgmma(%a: tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, %b: !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, %c: tensor<128x128xf16, #mma>) {
+  // expected-error @below {{in-register LHS operand must have a kWidth of 2 but got 1}}
+  %0 = ttng.warp_group_dot %a, %b, %c : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf16, #mma>
+  tt.return
+}
+}
