Commit e4fa38e
Merge commit 'cc25374fa480c0b3e51cf218ed6fe7eb4c50a5bb'
2 parents: 40392dc + cc25374

29 files changed (+1060, -862 lines)

.gitignore
Lines changed: 3 additions & 0 deletions

@@ -11,6 +11,9 @@ python/*.whl
 python/triton/_C/*.pyd
 python/triton/_C/*.so
 python/triton/_C/*.dylib
+python/triton/_C/*.pdb
+python/triton/_C/*.exe
+python/triton/_C/*.ilk
 
 benchmarks/dist
 benchmarks/*.egg-info/

include/triton/Dialect/Triton/IR/TritonOps.td
Lines changed: 6 additions & 4 deletions

@@ -86,7 +86,8 @@ def TT_BitcastOp : TT_Op<"bitcast", [Elementwise,
   // TODO: Add verifier
 }
 
-def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
+def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise,
+                                     SameOperandsAndResultShape,
                                      SameOperandsAndResultEncoding,
                                      Pure,
                                      /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {

@@ -675,6 +676,7 @@ def TT_DotOp : TT_Op<"dot", [Pure,
 // DotScaled Op
 //
 def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
+                                          AttrSizedOperandSegments,
                                           DotLike,
                                           TypesMatchWith<"result's type matches accumulator's type",
                                                          "d", "c", "$_self">]> {

@@ -692,7 +694,7 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
     RankedTensorOf<[TT_Float,I8]>:$lhs,
     RankedTensorOf<[TT_Float,I8]>:$rhs,
     TT_FloatTensor:$c,
-    RankedTensorOf<[I8]>:$lhs_scale,
+    Optional<RankedTensorOf<[I8]>>:$lhs_scale,
     Optional<RankedTensorOf<[I8]>>:$rhs_scale,
     TT_ScaleDotElemTypeAttr:$lhs_type,
     TT_ScaleDotElemTypeAttr:$rhs_type

@@ -702,8 +704,8 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
 
   // Not sure why I need to fully specify the optional group, but otherwise it complains when loading the mlir file
   let assemblyFormat = [{
-    $lhs `,` $lhs_scale `,` $rhs (`,`) : (`,` $rhs_scale^ `,`)? $c `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict
-    `:` type($lhs) `,` type($lhs_scale) `*` type($rhs) (`,` type($rhs_scale)^)? `->` type($d)
+    $lhs (`scale` $lhs_scale^)? `,` $rhs (`scale` $rhs_scale^)? `,` $c `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict
+    `:` type($lhs) (`,` type($lhs_scale)^)? `*` type($rhs) (`,` type($rhs_scale)^)? `->` type($d)
   }];
 }
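For illustration, a minimal sketch of the new optional-scale assembly with a scale on the lhs only; the shapes, the omitted encodings, and the e2m1/bf16 keywords are assumptions for the example rather than taken from this diff (e2m1 packs two fp4 values per i8 byte, so K is 64 here):

%d = tt.dot_scaled %a scale %a_scale, %b, %c lhs = e2m1 rhs = bf16
     : tensor<128x32xi8>, tensor<128x2xi8> * tensor<64x128xbf16> -> tensor<128x128xf32>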

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Lines changed: 3 additions & 4 deletions

@@ -371,10 +371,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getType();
 
-    // TODO (Keren): Currently, we handle general mma/blocked/slice ->
-    // mma/blocked/slice conversions.
-    // The following tasks must be completed before we can remove the layoutIsOK
-    // check:
+    // TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere)
+    // -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be
+    // completed before we can remove the layoutIsOK check:
     // 1. Support for AMD's MFMA and WMMA
     std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
       if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Lines changed: 4 additions & 1 deletion

@@ -140,8 +140,11 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
   // Do for all DotOperandEncodingAttr once we have LLs for all of them
   static bool isSupportedDotOpLayout(Attribute layout) {
     if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
+      // Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
+      // - kWidth == 8
       if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
-        return mma.isAmpere() && dot.getKWidth() == 8;
+        bool legacyLoweringIsBuggy = dot.getKWidth() >= 8;
+        return legacyLoweringIsBuggy && mma.isAmpere();
       }
       if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
         return true;

lib/Dialect/TritonGPU/IR/Ops.cpp
Lines changed: 27 additions & 17 deletions

@@ -52,14 +52,23 @@ LogicalResult UpcastMXFPOp::verify() {
         "all dimensions except the last must match between operands");
   }
 
-  auto dotEncoding =
-      dyn_cast_or_null<DotOperandEncodingAttr>(xTy.getEncoding());
+  auto layoutX = xTy.getEncoding();
+  auto layoutScale = scaleTy.getEncoding();
+  if (bool(layoutX) != bool(layoutScale)) {
+    return emitOpError(
+        "Expected either both or neither operands to have an encoding");
+  }
+  // Nothing to check if no encoding. This is used to infer the return type in
+  // AccelerateMatmul.cpp
+  if (!layoutX) {
+    return success();
+  }
+
+  auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
   if (!dotEncoding) {
     return emitOpError("Expected a DotOperandEncodingAttr for values");
   }
-
-  auto blockedScale =
-      dyn_cast_or_null<BlockedEncodingAttr>(scaleTy.getEncoding());
+  auto blockedScale = dyn_cast<BlockedEncodingAttr>(layoutScale);
   if (!blockedScale) {
     return emitOpError("Expected a BlockOperandEncoding for scales");
   }

@@ -86,22 +95,23 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
   auto xShape = xTy.getShape();
 
   auto encoding = xTy.getEncoding();
-  if (!encoding) {
-    return emitOptionalError(loc, "expected an encoding");
-  }
-  if (!mlir::isa<DotOperandEncodingAttr>(encoding)) {
-    return emitOptionalError(loc, "expected a dotOperand encoding");
-  }
 
   if (typeEncoded == ScaleDotElemType::E2M1) {
-    auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
-    auto newVEncoding = DotOperandEncodingAttr::get(
-        ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
-        oldEncoding.getKWidth() * 2);
+    RankedTensorType retTy;
+
     auto newShape = SmallVector<int64_t>(xShape);
     newShape.back() *= 2;
-    inferredReturnTypes.push_back(
-        RankedTensorType::get(newShape, FloatType::getBF16(ctx), newVEncoding));
+    if (!encoding) {
+      retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
+    } else {
+      auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
+      auto newVEncoding = DotOperandEncodingAttr::get(
+          ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
+          oldEncoding.getKWidth() * 2);
+      retTy = RankedTensorType::get(newShape, FloatType::getBF16(ctx),
+                                    newVEncoding);
+    }
+    inferredReturnTypes.push_back(retTy);
   } else {
     inferredReturnTypes.push_back(xTy);
   }
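As a worked sketch of the relaxed inferReturnTypes for an E2M1 input (hypothetical shapes, placeholder encoding names): with a dot-operand encoding the result doubles the last dimension and the encoding's kWidth, while the encoding-free path used by AccelerateMatmul.cpp keeps the input shape and only switches the element type to bf16.

// Hypothetical E2M1 inputs to UpcastMXFPOp::inferReturnTypes:
//   encoded   : tensor<128x32xi8, #dot>  ->  tensor<128x64xbf16, #dot_2x>  (kWidth doubled)
//   unencoded : tensor<128x32xi8>        ->  tensor<128x32xbf16>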
