3 changes: 2 additions & 1 deletion include/triton/Tools/Sys/GetEnv.hpp
@@ -38,7 +38,8 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
"TRITON_INTEL_ENABLE_INSTR_SCHED",
"TRITON_INTEL_ENABLE_POST_PROCESS_LLIR",
"TRITON_INTEL_REDUCE_TRANSPOSE"
"TRITON_INTEL_REDUCE_TRANSPOSE",
"TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"
// clang-format on
};
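For context, the Ops.cpp hunk below consumes this new variable through the getBoolEnv helper declared in this header. A minimal sketch of that gating pattern (the variable name comes from this diff; the surrounding snippet is illustrative only):

#include "triton/Tools/Sys/GetEnv.hpp"

// Read the flag through the GetEnv.hpp helper; listing it in
// CACHE_INVALIDATING_ENV_VARS above means toggling it also changes the
// compilation cache key.
bool useDotOpEncoding =
    mlir::triton::tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING");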

140 changes: 80 additions & 60 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -4,6 +4,7 @@
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"

#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
@@ -50,47 +51,51 @@ LogicalResult UpcastMXFPOp::verify() {
return success();
}

auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
if (!dotEncoding) {
return emitOpError("Expected a DotOperandEncodingAttr for values");
}
/// TODO: Temporarily disabled this check to allow for the blocked encoding.
/// we need to re-enable this check once we have the dot op encoding
/// UpcastMXFPOp lowering
// auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
// if (!dotEncoding) {
// return emitOpError("Expected a DotOperandEncodingAttr for values");
// }
if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
return emitOpError(
"Expected a BlockOperandEncoding or LinearOperandEncoding "
"for scales");
}

if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
// Necessary to keep all of the scales of a given block of values in the
// same warp
auto threadsPerWarp =
cast<DistributedEncodingTrait>(layoutScale).getThreadsPerWarp();
if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
return emitOpError("Expected threads per warp to be {16, 2}");
}
}

// Change to support fp8 types
const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
// Figure out the K dimension for the input A/B. For A/B scale, the K
// dimension is always the last dimension.
const int opIdx = dotEncoding.getOpIdx();
const bool hasBatch = xShape.size() == 3;
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;

if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
return emitOpError("K dimension of first operand must be 16 times "
"larger than last/K dimension of the second operand");
}

// Check other dimensions match too. For input A/B, we need to figure out the
// index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
if (hasBatch && xShape[0] != scaleShape[0])
return emitOpError("batch dimension must match between operands");
if (xShape[mnIdx] != scaleShape[hasBatch]) {
return emitOpError("M/N dimension must match between operands");
}
// if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
// // Necessary to keep all of the scales of a given block of values in the
// // same warp
// auto threadsPerWarp =
// cast<DistributedEncodingTrait>(layoutScale).getThreadsPerWarp();
// if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
// return emitOpError("Expected threads per warp to be {16, 2}");
// }
// }

// // Change to support fp8 types
// const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
// // Figure out the K dimension for the input A/B. For A/B scale, the K
// // dimension is always the last dimension.
// const int opIdx = dotEncoding.getOpIdx();
// const bool hasBatch = xShape.size() == 3;
// const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;

// if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
// return emitOpError("K dimension of first operand must be 16 times "
// "larger than last/K dimension of the second operand");
// }

// // Check other dimensions match too. For input A/B, we need to figure out
// the
// // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
// const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
// if (hasBatch && xShape[0] != scaleShape[0])
// return emitOpError("batch dimension must match between operands");
// if (xShape[mnIdx] != scaleShape[hasBatch]) {
// return emitOpError("M/N dimension must match between operands");
// }

return success();
}
@@ -105,6 +110,8 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
auto xShape = xTy.getShape();

auto encoding = xTy.getEncoding();
bool upcastMXFPUseDotOpEnc =
mlir::triton::tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING");

if (typeEncoded == ScaleDotElemType::E2M1) {
RankedTensorType retTy;
Expand All @@ -114,34 +121,47 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
newShape.back() *= 2;
retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
} else {
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);

const int opIdx = oldEncoding.getOpIdx();
const bool hasBatch = xShape.size() == 3;
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
newShape[kIdx] *= 2;
Type elemType = FloatType::getBF16(ctx);

// Note: For Intel the dot operands layout's kWidth parameter must match
// the parent's DPAS layout opsPerChannel so we need to materialize a new
// DPAS layout.
Attribute newVEncoding;
if (auto dpasEncoding =
dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
auto newDpasEncoding = intel::DpasEncodingAttr::get(
ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
dpasEncoding.getExecutionSize(),
intel::DpasEncodingAttr::getOpsPerChannel(elemType),
dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
dpasEncoding.getSubGroupSize());
newVEncoding = DotOperandEncodingAttr::get(
ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
if (upcastMXFPUseDotOpEnc) {
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);

const int opIdx = oldEncoding.getOpIdx();
const bool hasBatch = xShape.size() == 3;
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
newShape[kIdx] *= 2;

// Note: For Intel the dot operands layout's kWidth parameter must match
// the parent's DPAS layout opsPerChannel so we need to materialize a
// new DPAS layout.
if (auto dpasEncoding =
dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
auto newDpasEncoding = intel::DpasEncodingAttr::get(
ctx, dpasEncoding.getRepeatCount(),
dpasEncoding.getSystolicDepth(), dpasEncoding.getExecutionSize(),
intel::DpasEncodingAttr::getOpsPerChannel(elemType),
dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
dpasEncoding.getSubGroupSize());
newVEncoding = DotOperandEncodingAttr::get(
ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
} else {
// Figure out the K dimension for the input A/B, given that the return
// type is upcasted A/B type so we need to update the proper dim size.
newVEncoding = DotOperandEncodingAttr::get(
ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
oldEncoding.getKWidth() * 2);
}
} else {
// Figure out the K dimension for the input A/B, given that the return
// type is upcasted A/B type so we need to update the proper dim size.
newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
oldEncoding.getParent(),
oldEncoding.getKWidth() * 2);
auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding);
assert(oldEncoding &&
"Expected a blocked encoding for UpcastMXFP op result.");
newShape.back() *= 2;
SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
sizePerThread.back() *= 2;
newVEncoding = BlockedEncodingAttr::get(
ctx, sizePerThread, oldEncoding.getThreadsPerWarp(),
oldEncoding.getWarpsPerCTA(), oldEncoding.getCTAOrder(),
oldEncoding.getCTALayout());
}
retTy = RankedTensorType::get(newShape, elemType, newVEncoding);
}
3 changes: 2 additions & 1 deletion python/test/unit/language/test_core.py
@@ -3441,7 +3441,8 @@ def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, nu
if mma == 16 and K == 64:
pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
if is_xpu():
pytest.skip("scaled_dot isn't supported on XPU")
if rhs_scale:
pytest.skip("scaled_dot with rhs_scale not supported on XPU")

@triton.jit
def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,
2 changes: 1 addition & 1 deletion test/TritonIntelGPU/accelerate-matmul-pvc.mlir
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file --tritonintelgpu-accelerate-matmul | FileCheck %s
// RUN: TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING=1 triton-opt %s -split-input-file --tritonintelgpu-accelerate-matmul | FileCheck %s

// CHECK: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [4, 1], A = [32, 16], B = [16, 16], C = [32, 16]}>
// CHECK: #[[$DPAS_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
18 changes: 13 additions & 5 deletions third_party/intel/include/Analysis/DPAS.h
@@ -24,6 +24,14 @@ class DPASAnalysis {
FP32_FP32_TF32_TF32,
FP16_FP16_FP16_FP16,
BF16_BF16_BF16_BF16,
// data types for dot scaled.
FP32_FP32_BF16_FP8,
FP32_FP32_BF16_FP4,
FP32_FP32_FP8_BF16,
FP32_FP32_FP8_FP8,
FP32_FP32_FP8_FP4,
FP32_FP32_FP4_BF16,
FP32_FP32_FP4_FP8,
U32_U32_U8_U8,
S32_S32_S8_S8,
NOT_APPLICABLE
@@ -40,16 +48,16 @@
Result canUseDPAS(FunctionOpInterface funcOp) const;

/// Given a DotOp operation, return its DPAS engine type.
static DPASEngineType getDPASType(DotOp op);
static DPASEngineType getDPASType(Operation *op);

private:
mlir::ModuleOp mod;

/// Tracks Dot operations and their DPAS engine type.
std::map<DotOp, DPASEngineType> dotToDPASEngineMap;
/// Tracks Dot/DotScaled operations and their DPAS engine type.
std::map<Operation *, DPASEngineType> dotToDPASEngineMap;

/// Tracks the Dot operations contained in a function.
std::map<FunctionOpInterface, SmallVector<DotOp>> funcToDotMap;
/// Tracks the Dot/DotScaled operations contained in a function.
std::map<FunctionOpInterface, SmallVector<Operation *>> funcToDotMap;
};

} // namespace mlir::triton::gpu::intel
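With getDPASType now taking an Operation * instead of DotOp, a single entry point can serve both DotOp and the scaled-dot operation referenced by the updated "Dot/DotScaled" comments. A hypothetical sketch of how a caller might use the widened signature (the dispatch below is an assumption for illustration; the actual implementation is not part of this excerpt, and the DPAS.h include path may differ by build setup):

#include "mlir/IR/Operation.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "Analysis/DPAS.h"

using namespace mlir::triton;

// Illustrative only: query the DPAS engine type for regular and scaled dot
// operations through the new Operation* entry point; anything else maps to
// NOT_APPLICABLE.
static gpu::intel::DPASAnalysis::DPASEngineType
queryDPASEngineType(mlir::Operation *op) {
  if (mlir::isa<DotOp, DotScaledOp>(op))
    return gpu::intel::DPASAnalysis::getDPASType(op);
  return gpu::intel::DPASAnalysis::DPASEngineType::NOT_APPLICABLE;
}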