Skip to content

Commit a3ace81

Browse files
pchen7e2 authored and meta-codesync[bot] committed
[4/N][TLX-2cta] Special logic for DotOp verification when tlx is in 2cta mode (#642)
Summary: By default, this is the verifier for output dims generated by `DotOpInterfaceTrait` https://github.com/facebookexperimental/triton/blob/70aa21cb8602116e1feedbf8348609b4f4b568b9/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td#L62-L79 In our implementation, we chose to maintain shapes for A and D but shrink tensor B by half so we need to override the verifier logic if 2cta flag is ON for the mma op. ``` % make test-lit (all passing) % third_party/tlx/run_all.sh Need to build triton in this script? {y|n}n Run all LITs? {y|n}n Run core Triton python unit tests? {y|n}n Run all TLX unit tests? {y|n}y Running TLX Unit Tests ... (all passing or skipped) Run TLX tutorial kernels (correctness|performance|no)? {c|p|n} c Verifying correctness of TLX tutorial kernels (all passing) ``` Pull Request resolved: #642 Reviewed By: htyu Differential Revision: D86337216 Pulled By: pchen7e2 fbshipit-source-id: be2286cc1258d20efa22a0eb2eb92fb6e38b7fc8
1 parent 3d18916 commit a3ace81

File tree

6 files changed

+35
-10
lines changed

6 files changed

+35
-10
lines changed

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
7474
auto bShape = bTy.getShape();
7575
auto cShape = cTy.getShape();
7676
return cShape[cShape.size() - 2] == aShape[aShape.size() - 2] &&
77-
cShape[cShape.size() - 1] == bShape[aShape.size() - 1];
77+
cShape[cShape.size() - 1] == bShape[bShape.size() - 1];
7878
}]>
7979
];
8080

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {
489489

490490
def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
491491
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
492-
DeclareOpInterfaceMethods<DotOpInterface>,
492+
DeclareOpInterfaceMethods<DotOpInterface, ["verifyOutputDims"]>,
493493
DeclareOpInterfaceMethods<MMAv5OpInterface>,
494494
AttrSizedOperandSegments
495495
]> {

lib/Dialect/Triton/IR/OpInterfaces.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ LogicalResult verifyDotOpInterface(Operation *op) {
6767
if (!dotOp.verifyOutputDims())
6868
return dotOp->emitOpError(
6969
"expected the output shape to be the concatenation of the last "
70-
"dimension of the first operand and the last dimension of the "
71-
"second ");
70+
"dimension of the first operand and (the last dimension of the "
71+
"second if 1cta; or 2 times the last dimension of the second operand "
72+
"if 2cta TLX)");
7273
return success();
7374
}
7475

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,29 @@ bool TCGen5MMAOp::verifyDims() {
392392
return aShape[aShape.size() - 1] == bShape[aShape.size() - 2];
393393
}
394394

395+
bool TCGen5MMAOp::verifyOutputDims() {
396+
397+
if (getTwoCtas()) {
398+
// Here we have to relax the verification to support two possibilities
399+
// - For TLX 2CTA:
400+
// - Full MMA shape: [2M, K] x [K, N] -> [2M, N]
401+
// - Each CTA: [M, K] x [K, N/2] -> [M, N]. We're verifying each CTA here.
402+
// - For non TLX 2CTA: each CTA has [M, K] x [K, N] -> [M, N]
403+
// We cannot rely on module attr to differentiate them here because this
404+
// verification can run before Fixup pass. If we want to be as accurate as
405+
// possible, we should have a tlxTwoCTAs flag on MMA Op in the future
406+
auto aShape = getA().getType().getShape();
407+
auto bShape = getB().getType().getShape();
408+
auto dShape = getD().getType().getShape();
409+
return dShape[dShape.size() - 2] == aShape[aShape.size() - 2] &&
410+
(dShape[dShape.size() - 1] == bShape[bShape.size() - 1] /* non TLX*/
411+
|| dShape[dShape.size() - 1] ==
412+
2 * bShape[bShape.size() - 1] /* TLX 2CTA*/);
413+
}
414+
// 1cta case still delegates to default verifiers
415+
return DotOpInterfaceTrait::verifyOutputDims();
416+
}
417+
395418
Value TCGen5MMAOp::useAccumulator() { return getUseD(); }
396419

397420
void TCGen5MMAOp::setUseAccumulator(Value flag) {

test/TLX/attach-metadata.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,15 +180,15 @@ module attributes {tlx.has_warp_spec_ops = true, "ttg.num-ctas" = 1 : i32, "ttg.
180180
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
181181
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
182182
tt.func @tc_gen5_mma(%a: !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>,
183-
%b: !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory>,
183+
%b: !ttg.memdesc<128x64xf16, #shared1, #ttg.shared_memory>,
184184
%c: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
185185
%useAcc: i1,
186186
%pred: i1,
187187
%barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>,
188188
%barrierPred: i1) {
189189
ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] {is_async, two_ctas}:
190190
!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>,
191-
!ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory>,
191+
!ttg.memdesc<128x64xf16, #shared1, #ttg.shared_memory>,
192192
!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
193193
!ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>
194194
tt.return

test/TLX/tlx-verifier.mlir

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,20 @@ module attributes {tlx.has_warp_spec_ops = true, "ttg.num-ctas" = 1 : i32, "ttg.
3434
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
3535
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
3636
tt.func @tc_gen5_mma(%a: !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>,
37-
%b: !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory>,
37+
%b1: !ttg.memdesc<128x64xf16, #shared1, #ttg.shared_memory>,
38+
%b2: !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory>,
3839
%c: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
3940
%useAcc: i1,
4041
%pred: i1,
4142
%barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>,
4243
%barrierPred: i1) {
43-
ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] {is_async, two_ctas}:
44+
ttng.tc_gen5_mma %a, %b1, %c, %useAcc, %pred, %barrier[%barrierPred] {is_async, two_ctas}:
4445
!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>,
45-
!ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory>,
46+
!ttg.memdesc<128x64xf16, #shared1, #ttg.shared_memory>,
4647
!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
4748
!ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>
4849
// expected-error @+1 {{Expecting all dot ops to be 2cta together}}
49-
ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] {is_async}:
50+
ttng.tc_gen5_mma %a, %b2, %c, %useAcc, %pred, %barrier[%barrierPred] {is_async}:
5051
!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>,
5152
!ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory>,
5253
!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,

0 commit comments

Comments (0)