
Commit b9da9cc

Use DotOp layout for UpcastMXFPOp Lowering (#3057)
This pull request adds dot-layout codegen support for the upcast_mxfp operation, which can be more efficient than the previous blocked-layout implementation. The two skipped tests fail with an L0 runtime error; they will be addressed in a separate PR, #2968.
1 parent caf24cc commit b9da9cc
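
As a rough illustration of the shape side of this lowering (a minimal sketch, not code from this commit; the helper name and the 2-D, batch-free shapes are assumptions): an e2m1 operand arrives packed two fp4 values per i8, so upcasting to bf16 doubles the K dimension of the value tensor, where K is the inner dimension of the A operand (opIdx 0) and the outer dimension of the B operand (opIdx 1).

#include <array>
#include <cstdint>

// Hypothetical helper: given the packed i8 shape of an e2m1 dot operand,
// return the shape of the upcasted bf16 value tensor. Two fp4 values live
// in each i8, so the K dimension doubles; opIdx picks which dim is K.
std::array<int64_t, 2> upcastedE2M1Shape(std::array<int64_t, 2> packedShape,
                                         int opIdx /* 0 = A, 1 = B */) {
  const int kIdx = (opIdx == 0) ? 1 : 0; // K: inner dim of A, outer dim of B
  packedShape[kIdx] *= 2;                // 2 e2m1 values per packed i8
  return packedShape;
}

// Example: the A operand packed as i8 with shape 128x32 upcasts to a
// bf16 value tensor of shape 128x64.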

File tree

6 files changed: +287, -146 lines


include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 2 deletions
@@ -39,8 +39,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
     "TRITON_INTEL_ENABLE_INSTR_SCHED",
     "TRITON_INTEL_ENABLE_POST_PROCESS_LLIR",
-    "TRITON_INTEL_REDUCE_TRANSPOSE",
-    "TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"
+    "TRITON_INTEL_REDUCE_TRANSPOSE"
     // clang-format on
 };


lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 31 additions & 45 deletions
@@ -5,7 +5,6 @@
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
-#include "triton/Tools/Sys/GetEnv.hpp"

 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
@@ -339,21 +338,15 @@ LogicalResult UpcastMXFPOp::verify() {
     return success();
   }

-  /// TODO: Temporarily disabled this check to allow for the blocked encoding.
-  /// Enable once we have the dot op encoding UpcastMXFPOp lowering.
   auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
-  if (mlir::triton::tools::getBoolEnv(
-          "TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING") &&
-      !dotEncoding) {
+  if (!dotEncoding) {
     return emitOpError("Expected a DotOperandEncodingAttr for values");
   }
   if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
     return emitOpError(
         "Expected a BlockOperandEncoding or LinearOperandEncoding "
         "for scales");
   }
-  if (!dotEncoding)
-    return success();

   if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
     // Necessary to keep all of the scales of a given block of values in the
@@ -411,43 +404,36 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
   } else {
     Type elemType = FloatType::getBF16(ctx);
     Attribute newVEncoding = nullptr;
-    if (auto oldEncoding = dyn_cast<DotOperandEncodingAttr>(encoding)) {
-      const int opIdx = oldEncoding.getOpIdx();
-      const bool hasBatch = xShape.size() == 3;
-      const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
-      newShape[kIdx] *= 2;
-
-      // Note: For Intel the dot operands layout's kWidth parameter must match
-      // the parent's DPAS layout opsPerChannel so we need to materialize a
-      // new DPAS layout.
-      if (auto dpasEncoding =
-              dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
-        auto newDpasEncoding = intel::DpasEncodingAttr::get(
-            ctx, dpasEncoding.getRepeatCount(),
-            dpasEncoding.getSystolicDepth(), dpasEncoding.getExecutionSize(),
-            intel::DpasEncodingAttr::getOpsPerChannel(elemType),
-            dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
-            product<unsigned>(dpasEncoding.getThreadsPerWarp()));
-        newVEncoding = DotOperandEncodingAttr::get(
-            ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
-      } else {
-        // Figure out the K dimension for the input A/B, given that the return
-        // type is upcasted A/B type so we need to update the proper dim size.
-        newVEncoding = DotOperandEncodingAttr::get(
-            ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
-            oldEncoding.getKWidth() * 2);
-      }
-    } else if (auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding)) {
-      // TODO: Temporary code, remove once upcast_mxfp support dot encoding.
-      assert(!tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"));
-      SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
-      int opIdx = sizePerThread.back() == 1 ? 1 : 0;
-      sizePerThread[!opIdx] *= 2;
-      newShape[!opIdx] *= 2;
-      newVEncoding = BlockedEncodingAttr::get(
-          ctx, sizePerThread, oldEncoding.getThreadsPerWarp(),
-          oldEncoding.getWarpsPerCTA(), oldEncoding.getCTAOrder(),
-          oldEncoding.getCTALayout());
+    auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
+    const int opIdx = oldEncoding.getOpIdx();
+    const bool hasBatch = xShape.size() == 3;
+    const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+    newShape[kIdx] *= 2;
+
+    // Note: For Intel the dot operands layout's kWidth parameter must match
+    // the parent's DPAS layout opsPerChannel so we need to materialize a
+    // new DPAS layout.
+    if (auto dpasEncoding =
+            dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
+      unsigned opsPerChannel =
+          intel::DpasEncodingAttr::getOpsPerChannel(elemType);
+      // e2m1 is packed 2 elements per int8, we must handle continuous 2
+      // elements when upcasting to bf16
+      if (xTy.getElementType() == IntegerType::get(ctx, 8))
+        opsPerChannel *= 2;
+      auto newDpasEncoding = intel::DpasEncodingAttr::get(
+          ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
+          dpasEncoding.getExecutionSize(), opsPerChannel,
+          dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
+          product<unsigned>(dpasEncoding.getThreadsPerWarp()));
+      newVEncoding = DotOperandEncodingAttr::get(
+          ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
+    } else {
+      // Figure out the K dimension for the input A/B, given that the return
+      // type is upcasted A/B type so we need to update the proper dim size.
+      newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
+                                                 oldEncoding.getParent(),
+                                                 oldEncoding.getKWidth() * 2);
     }
     retTy = RankedTensorType::get(newShape, elemType, newVEncoding);
   }
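
The opsPerChannel adjustment above follows from the packing the new comment describes: each i8 element of the value tensor carries two e2m1 (fp4) numbers that stay adjacent when expanded to bf16, which is why the rebuilt DPAS layout's opsPerChannel (and hence the dot operand's kWidth) is doubled for i8 input. A minimal, self-contained sketch of that unpacking (assumed helper names, float standing in for bf16, low-nibble-first order is an assumption; this is not code from the commit):

#include <cstdint>
#include <utility>

// Decode one e2m1 (fp4) nibble: 1 sign bit, 2 exponent bits (bias 1),
// 1 mantissa bit. Representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
static float decodeE2M1(uint8_t nibble) {
  const float sign = (nibble & 0x8) ? -1.0f : 1.0f;
  const unsigned exp = (nibble >> 1) & 0x3;
  const unsigned man = nibble & 0x1;
  if (exp == 0) // subnormal: man * 2^-1
    return sign * 0.5f * static_cast<float>(man);
  return sign * static_cast<float>(1u << (exp - 1)) * (1.0f + 0.5f * man);
}

// One packed i8 expands to two adjacent upcasted values, matching the
// "continuous 2 elements" handling mentioned in the diff above.
static std::pair<float, float> unpackE2M1Pair(uint8_t packed) {
  return {decodeE2M1(packed & 0xF),
          decodeE2M1(static_cast<uint8_t>(packed >> 4))};
}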

python/test/unit/language/test_core.py

Lines changed: 6 additions & 1 deletion
@@ -3537,7 +3537,12 @@ def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, nu
     if mma == 16 and K == 64:
         pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
     if is_xpu():
-        if M == 128 and N == 128 and K == 64 and not col_a and not col_b and rhs_scale and normal_type == "e4m3" and mxfp_type == "bf16":
+        # skip cases: test_scaled_dot[32-64-128-False-False-True-e5m2-bf16-4-16-1]
+        #             test_scaled_dot[64-32-128-False-False-True-e4m3-bf16-4-16-1]
+        # for L0 runtime error
+        if ((M == 32 and N == 64 and K == 128 and not col_a and not col_b and rhs_scale and normal_type == "e5m2"
+             and mxfp_type == "bf16") or (M == 64 and N == 32 and K == 128 and not col_a and not col_b and rhs_scale
+                                          and normal_type == "e4m3" and mxfp_type == "bf16")):
             pytest.skip(
                 f"FIXME: {M}x{N}x{K} col_a={col_a} col_b={col_b} rhs_scale={rhs_scale} normal_type={normal_type} mxfp_type={mxfp_type}"
             )
