Skip to content

Commit db07b9e

Browse files
Improve codegen for GEMM kernel with exponential function on one of the inputs of the tt.dot operation (#2360)
Fixes #2346. Provides a ~16% improvement in performance (for a 4Kx4Kx4K shape).
---------
Signed-off-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Whitney Tsang <[email protected]>
1 parent 633c005 commit db07b9e

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
#include "mlir/Interfaces/SideEffectInterfaces.h"
77
#include "triton/Analysis/AxisInfo.h"
88
#include "triton/Dialect/Triton/IR/Dialect.h"
9+
#include "llvm/Support/Casting.h"
910
#include "llvm/Support/Debug.h"
1011

1112
#define DEBUG_TYPE "tritonintelgpu-pipeline"
@@ -57,12 +58,8 @@ static ttg::DotOperandEncodingAttr getDotEncodingFromUser(Operation *user) {
5758
if (isa<ttg::SharedEncodingAttr>(tensorType.getEncoding()))
5859
return allTransitiveUsesHaveDotEncoding(res);
5960

60-
if (auto op = dyn_cast<ttg::ConvertLayoutOp>(user))
61-
if (auto tensorType =
62-
dyn_cast<RankedTensorType>(op->getResult(0).getType()))
63-
return dyn_cast<ttg::DotOperandEncodingAttr>(tensorType.getEncoding());
64-
65-
return nullptr;
61+
return llvm::dyn_cast_or_null<ttg::DotOperandEncodingAttr>(
62+
tensorType.getEncoding());
6663
}
6764

6865
/// If all the transitive uses of the given value are used by a convert to the

0 commit comments

Comments
 (0)