Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
"TRITON_INTEL_ENABLE_INSTR_SCHED",
"TRITON_INTEL_ENABLE_POST_PROCESS_LLIR",
"TRITON_INTEL_REDUCE_TRANSPOSE"
"TRITON_INTEL_REDUCE_TRANSPOSE",
"TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"
// clang-format on
};

Expand Down
74 changes: 46 additions & 28 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"

#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
Expand Down Expand Up @@ -50,15 +51,21 @@ LogicalResult UpcastMXFPOp::verify() {
return success();
}

/// TODO: Temporarily disabled this check to allow for the blocked encoding.
/// Enable once we have the dot op encoding UpcastMXFPOp lowering.
auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
if (!dotEncoding) {
if (mlir::triton::tools::getBoolEnv(
"TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING") &&
!dotEncoding) {
return emitOpError("Expected a DotOperandEncodingAttr for values");
}
if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
return emitOpError(
"Expected a BlockOperandEncoding or LinearOperandEncoding "
"for scales");
}
if (!dotEncoding)
return success();

if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
// Necessary to keep all of the scales of a given block of values in the
Expand Down Expand Up @@ -114,34 +121,45 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
newShape.back() *= 2;
retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
} else {
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);

const int opIdx = oldEncoding.getOpIdx();
const bool hasBatch = xShape.size() == 3;
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
newShape[kIdx] *= 2;
Type elemType = FloatType::getBF16(ctx);

// Note: For Intel the dot operands layout's kWidth parameter must match
// the parent's DPAS layout opsPerChannel so we need to materialize a new
// DPAS layout.
Attribute newVEncoding;
if (auto dpasEncoding =
dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
auto newDpasEncoding = intel::DpasEncodingAttr::get(
ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
dpasEncoding.getExecutionSize(),
intel::DpasEncodingAttr::getOpsPerChannel(elemType),
dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
dpasEncoding.getSubGroupSize());
newVEncoding = DotOperandEncodingAttr::get(
ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
} else {
// Figure out the K dimension for the input A/B, given that the return
// type is upcasted A/B type so we need to update the proper dim size.
newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
oldEncoding.getParent(),
oldEncoding.getKWidth() * 2);
Attribute newVEncoding = nullptr;
if (auto oldEncoding = dyn_cast<DotOperandEncodingAttr>(encoding)) {
const int opIdx = oldEncoding.getOpIdx();
const bool hasBatch = xShape.size() == 3;
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
newShape[kIdx] *= 2;

// Note: For Intel the dot operands layout's kWidth parameter must match
// the parent's DPAS layout opsPerChannel so we need to materialize a
// new DPAS layout.
if (auto dpasEncoding =
dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
auto newDpasEncoding = intel::DpasEncodingAttr::get(
ctx, dpasEncoding.getRepeatCount(),
dpasEncoding.getSystolicDepth(), dpasEncoding.getExecutionSize(),
intel::DpasEncodingAttr::getOpsPerChannel(elemType),
dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
dpasEncoding.getSubGroupSize());
newVEncoding = DotOperandEncodingAttr::get(
ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
} else {
// Figure out the K dimension for the input A/B, given that the return
// type is upcasted A/B type so we need to update the proper dim size.
newVEncoding = DotOperandEncodingAttr::get(
ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
oldEncoding.getKWidth() * 2);
}
} else if (auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding)) {
// TODO: Temporary code, remove once upcast_mxfp support dot encoding.
assert(!tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"));
SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
int opIdx = sizePerThread.back() == 1 ? 1 : 0;
sizePerThread[!opIdx] *= 2;
newShape[!opIdx] *= 2;
newVEncoding = BlockedEncodingAttr::get(
ctx, sizePerThread, oldEncoding.getThreadsPerWarp(),
oldEncoding.getWarpsPerCTA(), oldEncoding.getCTAOrder(),
oldEncoding.getCTALayout());
}
retTy = RankedTensorType::get(newShape, elemType, newVEncoding);
}
Expand Down
5 changes: 4 additions & 1 deletion python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3441,7 +3441,10 @@ def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, nu
if mma == 16 and K == 64:
pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
if is_xpu():
pytest.skip("scaled_dot isn't supported on XPU")
if M == 128 and N == 128 and K == 64 and not col_a and not col_b and rhs_scale and normal_type == "e4m3" and mxfp_type == "bf16":
pytest.skip(
f"FIXME: {M}x{N}x{K} col_a={col_a} col_b={col_b} rhs_scale={rhs_scale} normal_type={normal_type} mxfp_type={mxfp_type}"
)

@triton.jit
def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,
Expand Down
2 changes: 1 addition & 1 deletion test/TritonIntelGPU/accelerate-matmul-pvc.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file --tritonintelgpu-accelerate-matmul | FileCheck %s
// RUN: TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING=1 triton-opt %s -split-input-file --tritonintelgpu-accelerate-matmul | FileCheck %s

// CHECK: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [4, 1], A = [32, 16], B = [16, 16], C = [32, 16]}>
// CHECK: #[[$DPAS_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
Expand Down
27 changes: 21 additions & 6 deletions third_party/intel/include/Analysis/DPAS.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ class DPASAnalysis {
BF16_BF16_BF16_BF16,
U32_U32_U8_U8,
S32_S32_S8_S8,
// data types for dot scaled.
FP32_FP32_BF16_FP8,
FP32_FP32_BF16_FP4,
FP32_FP32_FP8_BF16,
FP32_FP32_FP8_FP8,
FP32_FP32_FP8_FP4,
FP32_FP32_FP4_BF16,
FP32_FP32_FP4_FP8,
NOT_APPLICABLE
};

Expand All @@ -39,17 +47,24 @@ class DPASAnalysis {
/// (aka threads per warp) size.
Result canUseDPAS(FunctionOpInterface funcOp) const;

/// Given a DotOp operation, return its DPAS engine type.
static DPASEngineType getDPASType(DotOp op);
/// Given a 'DotOp' or 'ScaledDot' operation, return its DPAS engine type.
static DPASEngineType getDPASType(Operation *op);

// clang-format off
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to turn off clang-format here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because it formats the template horribly

template <typename OpTy>
typename std::enable_if<llvm::is_one_of<OpTy, DotOp, DotScaledOp>::value,
DPASAnalysis::DPASEngineType>::type
static getDPASType(OpTy);
// clang-format on

private:
mlir::ModuleOp mod;

/// Tracks Dot operations and their DPAS engine type.
std::map<DotOp, DPASEngineType> dotToDPASEngineMap;
/// Tracks Dot/DotScaled operations and their DPAS engine type.
std::map<Operation *, DPASEngineType> dotToDPASEngineMap;

/// Tracks the Dot operations contained in a function.
std::map<FunctionOpInterface, SmallVector<DotOp>> funcToDotMap;
/// Tracks the Dot/DotScaled operations contained in a function.
std::map<FunctionOpInterface, SmallVector<Operation *>> funcToDotMap;
};

} // namespace mlir::triton::gpu::intel
Expand Down
147 changes: 106 additions & 41 deletions third_party/intel/lib/Analysis/DPAS.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
#include "intel/include/Analysis/DPAS.h"
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/Support/Casting.h"
#include <iostream>
#include <type_traits>

namespace mlir::triton::gpu::intel {

Expand All @@ -16,19 +21,22 @@ DPASAnalysis::DPASAnalysis(Operation *root) {
mod.walk([&](FunctionOpInterface funcOp) {
auto it = funcToDotMap.find(funcOp);

funcOp.walk([&](DotOp dotOp) {
funcOp.walk([&](Operation *op) {
if (!isa<DotOp, DotScaledOp>(op))
return;

if (it != funcToDotMap.end())
it->second.push_back(dotOp);
it->second.push_back(op);
else
funcToDotMap[funcOp] = {dotOp};
funcToDotMap[funcOp] = {op};

DPASEngineType dpasEngineType = supportDPAS
? DPASAnalysis::getDPASType(dotOp)
? DPASAnalysis::getDPASType(op)
: DPASEngineType::NOT_APPLICABLE;
if (dpasEngineType == DPASEngineType::FP32_FP32_TF32_TF32 &&
dotOp.getInputPrecision() != InputPrecision::TF32)
cast<DotOp>(op).getInputPrecision() != InputPrecision::TF32)
dpasEngineType = DPASEngineType::NOT_APPLICABLE;
dotToDPASEngineMap[dotOp] = dpasEngineType;
dotToDPASEngineMap[op] = dpasEngineType;
});
});
}
Expand All @@ -44,7 +52,7 @@ DPASAnalysis::canUseDPAS(FunctionOpInterface funcOp) const {

// Ensure all dot operations in the function can be lowered to DPAS
// instructions.
for (const DotOp &dotOp : it->second) {
for (Operation *dotOp : it->second) {
DPASEngineType dpasEngineType = dotToDPASEngineMap.at(dotOp);
if (dpasEngineType == DPASEngineType::NOT_APPLICABLE)
return Result::False;
Expand All @@ -65,53 +73,110 @@ DPASAnalysis::canUseDPAS(FunctionOpInterface funcOp) const {
return (threadsPerWarp == minSGSize) ? Result::True : Result::False;
}

DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType(DotOp op) {
// d = a * b + c
auto aTy = cast<RankedTensorType>(op.getA().getType());
auto bTy = cast<RankedTensorType>(op.getB().getType());
/// Dispatch to the templated overload matching the concrete operation kind.
/// Returns NOT_APPLICABLE for any operation that is neither a 'tt.dot' nor a
/// 'tt.dot_scaled'.
DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType(Operation *op) {
  if (auto dot = dyn_cast<DotOp>(op))
    return getDPASType<DotOp>(dot);
  if (auto scaledDot = dyn_cast<DotScaledOp>(op))
    return getDPASType<DotScaledOp>(scaledDot);
  return DPASEngineType::NOT_APPLICABLE;
}

// This function determines the DPAS engine type for the given operation.
// It checks the element types of the tensors involved in the operation
// and returns the appropriate DPAS engine type based on the type combinations.
template <typename OpTy>
typename std::enable_if<llvm::is_one_of<OpTy, DotOp, DotScaledOp>::value,
DPASAnalysis::DPASEngineType>::type
DPASAnalysis::getDPASType(OpTy op) {
auto cTy = cast<RankedTensorType>(op.getC().getType());
auto dTy = cast<RankedTensorType>(op.getD().getType());
Type aElemTy = aTy.getElementType();
Type bElemTy = bTy.getElementType();
Type cElemTy = cTy.getElementType();
Type dElemTy = dTy.getElementType();

assert(cElemTy == dElemTy && "Unexpected element type mismatch");

if (aElemTy != bElemTy)
return DPASEngineType::NOT_APPLICABLE;
RankedTensorType aTy, bTy;
Type aElemTy, bElemTy;

if constexpr (std::is_same_v<OpTy, DotOp>) {
// d = a * b + c
aTy = cast<RankedTensorType>(op.getA().getType());
bTy = cast<RankedTensorType>(op.getB().getType());
aElemTy = aTy.getElementType();
bElemTy = bTy.getElementType();

if (aElemTy != bElemTy)
return DPASEngineType::NOT_APPLICABLE;

if (dElemTy.isIntOrIndex()) {
if (dElemTy.getIntOrFloatBitWidth() == 32 &&
aElemTy.getIntOrFloatBitWidth() == 8)
return dElemTy.isSignedInteger() ? DPASEngineType::S32_S32_S8_S8
: DPASEngineType::U32_U32_U8_U8;
return DPASEngineType::NOT_APPLICABLE;
}

if (dElemTy.isIntOrIndex()) {
if (dElemTy.getIntOrFloatBitWidth() == 32 &&
aElemTy.getIntOrFloatBitWidth() == 8)
return dElemTy.isSignedInteger() ? DPASEngineType::S32_S32_S8_S8
: DPASEngineType::U32_U32_U8_U8;
return DPASEngineType::NOT_APPLICABLE;
if (isa<FloatType>(dElemTy)) {
if (dElemTy.isF32()) {
if (aElemTy.isF16())
return DPASEngineType::FP32_FP32_FP16_FP16;
if (aElemTy.isBF16())
return DPASEngineType::FP32_FP32_BF16_BF16;
if (aElemTy.isF32() && op.getInputPrecision() == InputPrecision::TF32)
return DPASEngineType::FP32_FP32_TF32_TF32;
// For FP8XFP8->FP32, upcast to FP16
if (aElemTy.isFloat8E5M2())
return DPASEngineType::FP32_FP32_FP16_FP16;
if (aElemTy.isFloat8E4M3FN())
return DPASEngineType::FP32_FP32_FP16_FP16;
} else if (dElemTy.isF16()) {
if (aElemTy.isF16())
return DPASEngineType::FP16_FP16_FP16_FP16;
} else if (dElemTy.isBF16()) {
if (aElemTy.isBF16())
return DPASEngineType::BF16_BF16_BF16_BF16;
}
}
}

if (isa<FloatType>(dElemTy)) {
if (dElemTy.isF32()) {
if (aElemTy.isF16())
return DPASEngineType::FP32_FP32_FP16_FP16;
if (aElemTy.isBF16())
return DPASEngineType::FP32_FP32_BF16_BF16;
if (aElemTy.isF32() && op.getInputPrecision() == InputPrecision::TF32)
return DPASEngineType::FP32_FP32_TF32_TF32;
// For FP8XFP8->FP32, upcast to FP16
if (aElemTy.isFloat8E5M2())
return DPASEngineType::FP32_FP32_FP16_FP16;
if (aElemTy.isFloat8E4M3FN())
return DPASEngineType::FP32_FP32_FP16_FP16;
} else if (dElemTy.isF16()) {
if (aElemTy.isF16())
return DPASEngineType::FP16_FP16_FP16_FP16;
} else if (dElemTy.isBF16()) {
if (aElemTy.isBF16())
return DPASEngineType::BF16_BF16_BF16_BF16;
if constexpr (std::is_same_v<OpTy, DotScaledOp>) {
aTy = cast<RankedTensorType>(op.getLhs().getType());
bTy = cast<RankedTensorType>(op.getRhs().getType());
aElemTy = aTy.getElementType();
bElemTy = bTy.getElementType();

if (isa<FloatType>(dElemTy)) {
if (dElemTy.isF32()) {
if (aElemTy.isBF16() &&
(bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
return DPASEngineType::FP32_FP32_BF16_FP8;
// 2 E2M1 are packed into 1 int8
if (aElemTy.isBF16() && bElemTy.isInteger(8))
return DPASEngineType::FP32_FP32_BF16_FP4;
if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
bElemTy.isBF16())
return DPASEngineType::FP32_FP32_FP8_BF16;
if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
(bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
return DPASEngineType::FP32_FP32_FP8_FP8;
if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
bElemTy.isInteger(8))
return DPASEngineType::FP32_FP32_FP8_FP4;
if (aElemTy.isInteger(8) && bElemTy.isBF16())
return DPASEngineType::FP32_FP32_FP4_BF16;
if (aElemTy.isInteger(8) &&
(bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
return DPASEngineType::FP32_FP32_FP4_FP8;
}
}
}

return DPASEngineType::NOT_APPLICABLE;
}

// Explicit instantiations.
template DPASAnalysis::DPASEngineType
DPASAnalysis::getDPASType<DotOp>(DotOp op);
template DPASAnalysis::DPASEngineType
DPASAnalysis::getDPASType<DotScaledOp>(DotScaledOp op);

} // namespace mlir::triton::gpu::intel
1 change: 1 addition & 0 deletions third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_triton_library(TritonIntelGPUToLLVM
TritonGPUToLLVM.cpp
TritonOpsToLLVM.cpp
TypeConverter.cpp
UpcastMXFPToLLVM.cpp
Utility.cpp
ViewOpToLLVM.cpp

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ void populateElementwiseOpToLLVMPatterns(
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
PatternBenefit benefit);

void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfo &targetInfo,
PatternBenefit benefit);

void populateBF16CastsLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);
Expand Down
Loading
Loading