Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
"TRITON_INTEL_ENABLE_INSTR_SCHED",
"TRITON_INTEL_ENABLE_POST_PROCESS_LLIR",
"TRITON_INTEL_REDUCE_TRANSPOSE"
"TRITON_INTEL_REDUCE_TRANSPOSE",
"TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"
// clang-format on
};

Expand Down
74 changes: 46 additions & 28 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"

#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
Expand Down Expand Up @@ -50,15 +51,21 @@ LogicalResult UpcastMXFPOp::verify() {
return success();
}

/// TODO: Temporarily disabled this check to allow for the blocked encoding.
/// Enable once we have the dot op encoding UpcastMXFPOp lowering.
auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
if (!dotEncoding) {
if (mlir::triton::tools::getBoolEnv(
"TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING") &&
!dotEncoding) {
return emitOpError("Expected a DotOperandEncodingAttr for values");
}
if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
return emitOpError(
"Expected a BlockOperandEncoding or LinearOperandEncoding "
"for scales");
}
if (!dotEncoding)
return success();

if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
// Necessary to keep all of the scales of a given block of values in the
Expand Down Expand Up @@ -114,34 +121,45 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
newShape.back() *= 2;
retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
} else {
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);

const int opIdx = oldEncoding.getOpIdx();
const bool hasBatch = xShape.size() == 3;
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
newShape[kIdx] *= 2;
Type elemType = FloatType::getBF16(ctx);

// Note: For Intel the dot operands layout's kWidth parameter must match
// the parent's DPAS layout opsPerChannel so we need to materialize a new
// DPAS layout.
Attribute newVEncoding;
if (auto dpasEncoding =
dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
auto newDpasEncoding = intel::DpasEncodingAttr::get(
ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
dpasEncoding.getExecutionSize(),
intel::DpasEncodingAttr::getOpsPerChannel(elemType),
dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
dpasEncoding.getSubGroupSize());
newVEncoding = DotOperandEncodingAttr::get(
ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
} else {
// Figure out the K dimension for the input A/B, given that the return
// type is upcasted A/B type so we need to update the proper dim size.
newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
oldEncoding.getParent(),
oldEncoding.getKWidth() * 2);
Attribute newVEncoding = nullptr;
if (auto oldEncoding = dyn_cast<DotOperandEncodingAttr>(encoding)) {
const int opIdx = oldEncoding.getOpIdx();
const bool hasBatch = xShape.size() == 3;
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
newShape[kIdx] *= 2;

// Note: For Intel the dot operands layout's kWidth parameter must match
// the parent's DPAS layout opsPerChannel so we need to materialize a
// new DPAS layout.
if (auto dpasEncoding =
dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
auto newDpasEncoding = intel::DpasEncodingAttr::get(
ctx, dpasEncoding.getRepeatCount(),
dpasEncoding.getSystolicDepth(), dpasEncoding.getExecutionSize(),
intel::DpasEncodingAttr::getOpsPerChannel(elemType),
dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
dpasEncoding.getSubGroupSize());
newVEncoding = DotOperandEncodingAttr::get(
ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
} else {
// Figure out the K dimension for the input A/B, given that the return
// type is upcasted A/B type so we need to update the proper dim size.
newVEncoding = DotOperandEncodingAttr::get(
ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
oldEncoding.getKWidth() * 2);
}
} else if (auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding)) {
// TODO: Temporary code, remove once upcast_mxfp support dot encoding.
assert(!tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"));
SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
int opIdx = sizePerThread.back() == 1 ? 1 : 0;
sizePerThread[!opIdx] *= 2;
newShape[!opIdx] *= 2;
newVEncoding = BlockedEncodingAttr::get(
ctx, sizePerThread, oldEncoding.getThreadsPerWarp(),
oldEncoding.getWarpsPerCTA(), oldEncoding.getCTAOrder(),
oldEncoding.getCTALayout());
}
retTy = RankedTensorType::get(newShape, elemType, newVEncoding);
}
Expand Down
5 changes: 4 additions & 1 deletion python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3441,7 +3441,10 @@ def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, nu
if mma == 16 and K == 64:
pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
if is_xpu():
pytest.skip("scaled_dot isn't supported on XPU")
if M == 128 and N == 128 and K == 64 and not col_a and not col_b and rhs_scale and normal_type == "e4m3" and mxfp_type == "bf16":
pytest.skip(
f"FIXME: {M}x{N}x{K} col_a={col_a} col_b={col_b} rhs_scale={rhs_scale} normal_type={normal_type} mxfp_type={mxfp_type}"
)

@triton.jit
def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,
Expand Down
Loading