
Commit c878427

leonling-ll and etiotto authored
Initial support for tritongpu.upcast_mxfp lowering (#2951)
This PR adds initial codegen support for the new upcast_mxfp operation. Currently, codegen works only when the source operand has a blocked layout; this is a temporary limitation (we will eventually want to support the dot layout for that operand as well). One test in test_core.py fails with this change and is skipped in this PR; we will address that problem in a separate PR.

Co-authored-by: Tiotto, Ettore <[email protected]>
1 parent e538f26 commit c878427
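For orientation, the following is a minimal NumPy sketch of the value-level semantics that upcast_mxfp targets in the fp4 (e2m1) case, and of why the inferred return shape doubles along K: each source byte packs two e2m1 values, and one e8m0 scale covers a block of values. This is not code from this PR; the function name, the nibble order, and the 32-element scale block are assumptions based on the OCP MX convention.

import numpy as np

# The 16 e2m1 (fp4) bit patterns: 1 sign bit, 2 exponent bits, 1 mantissa bit.
E2M1_LUT = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
                    dtype=np.float32)

def upcast_mxfp4(packed: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """packed: (M, K//2) uint8, two e2m1 values per byte; scales: (M, K//32)
    uint8 in e8m0. Returns (M, K) float32 values standing in for the bf16
    result. Assumes K is a multiple of 32."""
    lo = E2M1_LUT[packed & 0x0F]          # assumed nibble order: low nibble is the even element
    hi = E2M1_LUT[(packed >> 4) & 0x0F]
    vals = np.stack([lo, hi], axis=-1).reshape(packed.shape[0], -1)  # the K extent doubles here
    scale = np.exp2(scales.astype(np.float32) - 127.0)  # e8m0 decodes as 2**(x - 127); NaN code ignored
    return vals * np.repeat(scale, 32, axis=-1)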

File tree

13 files changed: +1651 -103 lines changed

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 2 additions & 1 deletion
@@ -38,7 +38,8 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
     "TRITON_INTEL_ENABLE_INSTR_SCHED",
     "TRITON_INTEL_ENABLE_POST_PROCESS_LLIR",
-    "TRITON_INTEL_REDUCE_TRANSPOSE"
+    "TRITON_INTEL_REDUCE_TRANSPOSE",
+    "TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"
     // clang-format on
 };
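Because the new variable is listed in CACHE_INVALIDATING_ENV_VARS, it becomes part of the compilation cache key, so changing it does not silently reuse previously compiled kernels. The sketch below shows one way to run the affected tests with the stricter dot-operand-encoding check re-enabled; only the variable name comes from this diff, the pytest invocation is an assumption.

import os
import subprocess

# Illustrative only: re-enable the (currently relaxed) DotOperandEncodingAttr
# check in UpcastMXFPOp::verify() for a test run.
env = dict(os.environ, TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING="1")
subprocess.run(
    ["pytest", "python/test/unit/language/test_core.py", "-k", "scaled_dot"],
    env=env,
    check=True,
)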

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 46 additions & 28 deletions
@@ -4,6 +4,7 @@
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Tools/Sys/GetEnv.hpp"
 
 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
@@ -50,15 +51,21 @@ LogicalResult UpcastMXFPOp::verify() {
     return success();
   }
 
+  /// TODO: Temporarily disabled this check to allow for the blocked encoding.
+  /// Enable once we have the dot op encoding UpcastMXFPOp lowering.
   auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
-  if (!dotEncoding) {
+  if (mlir::triton::tools::getBoolEnv(
+          "TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING") &&
+      !dotEncoding) {
     return emitOpError("Expected a DotOperandEncodingAttr for values");
   }
   if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
     return emitOpError(
         "Expected a BlockOperandEncoding or LinearOperandEncoding "
         "for scales");
   }
+  if (!dotEncoding)
+    return success();
 
   if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
     // Necessary to keep all of the scales of a given block of values in the
@@ -114,34 +121,45 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
     newShape.back() *= 2;
     retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
   } else {
-    auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
-
-    const int opIdx = oldEncoding.getOpIdx();
-    const bool hasBatch = xShape.size() == 3;
-    const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
-    newShape[kIdx] *= 2;
     Type elemType = FloatType::getBF16(ctx);
-
-    // Note: For Intel the dot operands layout's kWidth parameter must match
-    // the parent's DPAS layout opsPerChannel so we need to materialize a new
-    // DPAS layout.
-    Attribute newVEncoding;
-    if (auto dpasEncoding =
-            dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
-      auto newDpasEncoding = intel::DpasEncodingAttr::get(
-          ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
-          dpasEncoding.getExecutionSize(),
-          intel::DpasEncodingAttr::getOpsPerChannel(elemType),
-          dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
-          dpasEncoding.getSubGroupSize());
-      newVEncoding = DotOperandEncodingAttr::get(
-          ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
-    } else {
-      // Figure out the K dimension for the input A/B, given that the return
-      // type is upcasted A/B type so we need to update the proper dim size.
-      newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
-                                                 oldEncoding.getParent(),
-                                                 oldEncoding.getKWidth() * 2);
+    Attribute newVEncoding = nullptr;
+    if (auto oldEncoding = dyn_cast<DotOperandEncodingAttr>(encoding)) {
+      const int opIdx = oldEncoding.getOpIdx();
+      const bool hasBatch = xShape.size() == 3;
+      const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+      newShape[kIdx] *= 2;
+
+      // Note: For Intel the dot operands layout's kWidth parameter must match
+      // the parent's DPAS layout opsPerChannel so we need to materialize a
+      // new DPAS layout.
+      if (auto dpasEncoding =
+              dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
+        auto newDpasEncoding = intel::DpasEncodingAttr::get(
+            ctx, dpasEncoding.getRepeatCount(),
+            dpasEncoding.getSystolicDepth(), dpasEncoding.getExecutionSize(),
+            intel::DpasEncodingAttr::getOpsPerChannel(elemType),
+            dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
+            dpasEncoding.getSubGroupSize());
+        newVEncoding = DotOperandEncodingAttr::get(
+            ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
+      } else {
+        // Figure out the K dimension for the input A/B, given that the return
+        // type is upcasted A/B type so we need to update the proper dim size.
+        newVEncoding = DotOperandEncodingAttr::get(
+            ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
+            oldEncoding.getKWidth() * 2);
+      }
+    } else if (auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding)) {
+      // TODO: Temporary code, remove once upcast_mxfp support dot encoding.
+      assert(!tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"));
+      SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
+      int opIdx = sizePerThread.back() == 1 ? 1 : 0;
+      sizePerThread[!opIdx] *= 2;
+      newShape[!opIdx] *= 2;
+      newVEncoding = BlockedEncodingAttr::get(
+          ctx, sizePerThread, oldEncoding.getThreadsPerWarp(),
+          oldEncoding.getWarpsPerCTA(), oldEncoding.getCTAOrder(),
+          oldEncoding.getCTALayout());
     }
     retTy = RankedTensorType::get(newShape, elemType, newVEncoding);
   }
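In inferReturnTypes the upcast doubles the K extent of the operand, whichever axis that is; the temporary blocked-encoding branch infers the operand index from sizePerThread and doubles both sizePerThread and the shape along that axis (in the 2-D case). The small sketch below only mirrors the kIdx index arithmetic from the diff, as an illustration; it is not part of the change.

# Mirrors the kIdx computation above (illustration only): for operand A
# (opIdx == 0) the K dimension is the last non-batch axis, for operand B
# (opIdx == 1) it is the first non-batch axis.
def k_axis(op_idx: int, has_batch: bool) -> int:
    return (1 if op_idx == 0 else 0) + int(has_batch)

assert k_axis(0, False) == 1  # A is (M, K): K is axis 1
assert k_axis(1, False) == 0  # B is (K, N): K is axis 0
assert k_axis(0, True) == 2   # batched A is (B, M, K): K is axis 2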

python/test/unit/language/test_core.py

Lines changed: 4 additions & 1 deletion
@@ -3441,7 +3441,10 @@ def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, nu
     if mma == 16 and K == 64:
         pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
     if is_xpu():
-        pytest.skip("scaled_dot isn't supported on XPU")
+        if M == 128 and N == 128 and K == 64 and not col_a and not col_b and rhs_scale and normal_type == "e4m3" and mxfp_type == "bf16":
+            pytest.skip(
+                f"FIXME: {M}x{N}x{K} col_a={col_a} col_b={col_b} rhs_scale={rhs_scale} normal_type={normal_type} mxfp_type={mxfp_type}"
+            )
 
     @triton.jit
     def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,
