Skip to content

Commit bc32cd2

Browse files
committed
Address code review comments
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 9ab90fa commit bc32cd2

File tree

2 files changed

+46
-23
lines changed

2 files changed

+46
-23
lines changed

third_party/intel/include/Analysis/DPAS.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,16 @@ class DPASAnalysis {
4747
/// (aka threads per warp) size.
4848
Result canUseDPAS(FunctionOpInterface funcOp) const;
4949

50-
/// Given a DotOp operation, return its DPAS engine type.
50+
/// Given a 'DotOp' or 'DotScaledOp' operation, return its DPAS engine type.
5151
static DPASEngineType getDPASType(Operation *op);
5252

53+
// clang-format off
54+
template <typename OpTy>
55+
typename std::enable_if<llvm::is_one_of<OpTy, DotOp, DotScaledOp>::value,
56+
DPASAnalysis::DPASEngineType>::type
57+
static getDPASType(OpTy);
58+
// clang-format on
59+
5360
private:
5461
mlir::ModuleOp mod;
5562

third_party/intel/lib/Analysis/DPAS.cpp

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "triton/Dialect/Triton/IR/Dialect.h"
55
#include "llvm/Support/Casting.h"
66
#include <iostream>
7+
#include <type_traits>
78

89
namespace mlir::triton::gpu::intel {
910

@@ -23,6 +24,7 @@ DPASAnalysis::DPASAnalysis(Operation *root) {
2324
funcOp.walk([&](Operation *op) {
2425
if (!isa<DotOp, DotScaledOp>(op))
2526
return;
27+
2628
if (it != funcToDotMap.end())
2729
it->second.push_back(op);
2830
else
@@ -72,21 +74,36 @@ DPASAnalysis::canUseDPAS(FunctionOpInterface funcOp) const {
7274
}
7375

7476
DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType(Operation *op) {
75-
RankedTensorType aTy, bTy, cTy, dTy;
76-
Type aElemTy, bElemTy, cElemTy, dElemTy;
77+
if (auto dotOp = dyn_cast<DotOp>(op))
78+
return DPASAnalysis::getDPASType<DotOp>(dotOp);
79+
if (auto dotScaledOp = dyn_cast<DotScaledOp>(op))
80+
return DPASAnalysis::getDPASType(dotScaledOp);
81+
return DPASEngineType::NOT_APPLICABLE;
82+
}
83+
84+
// This function determines the DPAS engine type for the given operation.
85+
// It checks the element types of the tensors involved in the operation
86+
// and returns the appropriate DPAS engine type based on the type combinations.
87+
template <typename OpTy>
88+
typename std::enable_if<llvm::is_one_of<OpTy, DotOp, DotScaledOp>::value,
89+
DPASAnalysis::DPASEngineType>::type
90+
DPASAnalysis::getDPASType(OpTy op) {
91+
auto cTy = cast<RankedTensorType>(op.getC().getType());
92+
auto dTy = cast<RankedTensorType>(op.getD().getType());
93+
Type cElemTy = cTy.getElementType();
94+
Type dElemTy = dTy.getElementType();
95+
96+
assert(cElemTy == dElemTy && "Unexpected element type mismatch");
7797

78-
if (auto dotOp = dyn_cast<DotOp>(op)) {
98+
RankedTensorType aTy, bTy;
99+
Type aElemTy, bElemTy;
100+
101+
if constexpr (std::is_same_v<OpTy, DotOp>) {
79102
// d = a * b + c
80-
aTy = cast<RankedTensorType>(dotOp.getA().getType());
81-
bTy = cast<RankedTensorType>(dotOp.getB().getType());
82-
cTy = cast<RankedTensorType>(dotOp.getC().getType());
83-
dTy = cast<RankedTensorType>(dotOp.getD().getType());
103+
aTy = cast<RankedTensorType>(op.getA().getType());
104+
bTy = cast<RankedTensorType>(op.getB().getType());
84105
aElemTy = aTy.getElementType();
85106
bElemTy = bTy.getElementType();
86-
cElemTy = cTy.getElementType();
87-
dElemTy = dTy.getElementType();
88-
89-
assert(cElemTy == dElemTy && "Unexpected element type mismatch");
90107

91108
if (aElemTy != bElemTy)
92109
return DPASEngineType::NOT_APPLICABLE;
@@ -105,8 +122,7 @@ DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType(Operation *op) {
105122
return DPASEngineType::FP32_FP32_FP16_FP16;
106123
if (aElemTy.isBF16())
107124
return DPASEngineType::FP32_FP32_BF16_BF16;
108-
if (aElemTy.isF32() &&
109-
dotOp.getInputPrecision() == InputPrecision::TF32)
125+
if (aElemTy.isF32() && op.getInputPrecision() == InputPrecision::TF32)
110126
return DPASEngineType::FP32_FP32_TF32_TF32;
111127
// For FP8XFP8->FP32, upcast to FP16
112128
if (aElemTy.isFloat8E5M2())
@@ -123,17 +139,11 @@ DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType(Operation *op) {
123139
}
124140
}
125141

126-
if (auto scaledDot = dyn_cast<DotScaledOp>(op)) {
127-
aTy = cast<RankedTensorType>(scaledDot.getLhs().getType());
128-
bTy = cast<RankedTensorType>(scaledDot.getRhs().getType());
129-
cTy = cast<RankedTensorType>(scaledDot.getC().getType());
130-
dTy = cast<RankedTensorType>(scaledDot.getD().getType());
142+
if constexpr (std::is_same_v<OpTy, DotScaledOp>) {
143+
aTy = cast<RankedTensorType>(op.getLhs().getType());
144+
bTy = cast<RankedTensorType>(op.getRhs().getType());
131145
aElemTy = aTy.getElementType();
132146
bElemTy = bTy.getElementType();
133-
cElemTy = cTy.getElementType();
134-
dElemTy = dTy.getElementType();
135-
136-
assert(cElemTy == dElemTy && "Unexpected element type mismatch");
137147

138148
if (isa<FloatType>(dElemTy)) {
139149
if (dElemTy.isF32()) {
@@ -163,4 +173,10 @@ DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType(Operation *op) {
163173
return DPASEngineType::NOT_APPLICABLE;
164174
}
165175

176+
// Explicit instantiations.
177+
template DPASAnalysis::DPASEngineType
178+
DPASAnalysis::getDPASType<DotOp>(DotOp op);
179+
template DPASAnalysis::DPASEngineType
180+
DPASAnalysis::getDPASType<DotScaledOp>(DotScaledOp op);
181+
166182
} // namespace mlir::triton::gpu::intel

0 commit comments

Comments
 (0)