44#include " triton/Dialect/Triton/IR/Dialect.h"
55#include " llvm/Support/Casting.h"
66#include < iostream>
7+ #include < type_traits>
78
89namespace mlir ::triton::gpu::intel {
910
@@ -23,6 +24,7 @@ DPASAnalysis::DPASAnalysis(Operation *root) {
2324 funcOp.walk ([&](Operation *op) {
2425 if (!isa<DotOp, DotScaledOp>(op))
2526 return ;
27+
2628 if (it != funcToDotMap.end ())
2729 it->second .push_back (op);
2830 else
@@ -72,21 +74,36 @@ DPASAnalysis::canUseDPAS(FunctionOpInterface funcOp) const {
7274}
7375
7476DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType (Operation *op) {
75- RankedTensorType aTy, bTy, cTy, dTy;
76- Type aElemTy, bElemTy, cElemTy, dElemTy;
77+ if (auto dotOp = dyn_cast<DotOp>(op))
78+ return DPASAnalysis::getDPASType<DotOp>(dotOp);
79+ if (auto dotScaledOp = dyn_cast<DotScaledOp>(op))
80+ return DPASAnalysis::getDPASType (dotScaledOp);
81+ return DPASEngineType::NOT_APPLICABLE;
82+ }
83+
84+ // This function determines the DPAS engine type for the given operation.
85+ // It checks the element types of the tensors involved in the operation
86+ // and returns the appropriate DPAS engine type based on the type combinations.
87+ template <typename OpTy>
88+ typename std::enable_if<llvm::is_one_of<OpTy, DotOp, DotScaledOp>::value,
89+ DPASAnalysis::DPASEngineType>::type
90+ DPASAnalysis::getDPASType (OpTy op) {
91+ auto cTy = cast<RankedTensorType>(op.getC ().getType ());
92+ auto dTy = cast<RankedTensorType>(op.getD ().getType ());
93+ Type cElemTy = cTy.getElementType ();
94+ Type dElemTy = dTy.getElementType ();
95+
96+ assert (cElemTy == dElemTy && " Unexpected element type mismatch" );
7797
78- if (auto dotOp = dyn_cast<DotOp>(op)) {
98+ RankedTensorType aTy, bTy;
99+ Type aElemTy, bElemTy;
100+
101+ if constexpr (std::is_same_v<OpTy, DotOp>) {
79102 // d = a * b + c
80- aTy = cast<RankedTensorType>(dotOp.getA ().getType ());
81- bTy = cast<RankedTensorType>(dotOp.getB ().getType ());
82- cTy = cast<RankedTensorType>(dotOp.getC ().getType ());
83- dTy = cast<RankedTensorType>(dotOp.getD ().getType ());
103+ aTy = cast<RankedTensorType>(op.getA ().getType ());
104+ bTy = cast<RankedTensorType>(op.getB ().getType ());
84105 aElemTy = aTy.getElementType ();
85106 bElemTy = bTy.getElementType ();
86- cElemTy = cTy.getElementType ();
87- dElemTy = dTy.getElementType ();
88-
89- assert (cElemTy == dElemTy && " Unexpected element type mismatch" );
90107
91108 if (aElemTy != bElemTy)
92109 return DPASEngineType::NOT_APPLICABLE;
@@ -105,8 +122,7 @@ DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType(Operation *op) {
105122 return DPASEngineType::FP32_FP32_FP16_FP16;
106123 if (aElemTy.isBF16 ())
107124 return DPASEngineType::FP32_FP32_BF16_BF16;
108- if (aElemTy.isF32 () &&
109- dotOp.getInputPrecision () == InputPrecision::TF32)
125+ if (aElemTy.isF32 () && op.getInputPrecision () == InputPrecision::TF32)
110126 return DPASEngineType::FP32_FP32_TF32_TF32;
111127 // For FP8XFP8->FP32, upcast to FP16
112128 if (aElemTy.isFloat8E5M2 ())
@@ -123,17 +139,11 @@ DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType(Operation *op) {
123139 }
124140 }
125141
126- if (auto scaledDot = dyn_cast<DotScaledOp>(op)) {
127- aTy = cast<RankedTensorType>(scaledDot.getLhs ().getType ());
128- bTy = cast<RankedTensorType>(scaledDot.getRhs ().getType ());
129- cTy = cast<RankedTensorType>(scaledDot.getC ().getType ());
130- dTy = cast<RankedTensorType>(scaledDot.getD ().getType ());
142+ if constexpr (std::is_same_v<OpTy, DotScaledOp>) {
143+ aTy = cast<RankedTensorType>(op.getLhs ().getType ());
144+ bTy = cast<RankedTensorType>(op.getRhs ().getType ());
131145 aElemTy = aTy.getElementType ();
132146 bElemTy = bTy.getElementType ();
133- cElemTy = cTy.getElementType ();
134- dElemTy = dTy.getElementType ();
135-
136- assert (cElemTy == dElemTy && " Unexpected element type mismatch" );
137147
138148 if (isa<FloatType>(dElemTy)) {
139149 if (dElemTy.isF32 ()) {
@@ -163,4 +173,10 @@ DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType(Operation *op) {
163173 return DPASEngineType::NOT_APPLICABLE;
164174}
165175
176+ // Explicit instantiations.
177+ template DPASAnalysis::DPASEngineType
178+ DPASAnalysis::getDPASType<DotOp>(DotOp op);
179+ template DPASAnalysis::DPASEngineType
180+ DPASAnalysis::getDPASType<DotScaledOp>(DotScaledOp op);
181+
166182} // namespace mlir::triton::gpu::intel
0 commit comments