Skip to content

Commit 7ea98fc

Browse files
committed
conversion tritonGPUToLLVM elementwiseOpToLLVMBase decoupling
1 parent 913a93b commit 7ea98fc

File tree

3 files changed

+11
-1
lines changed

3 files changed

+11
-1
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
#include "triton/Analysis/iluvatar_AxisInfo.h"
22
#include "triton/Analysis/iluvatar_Membar.h"
33
#include "triton/Analysis/iluvatar_Utility.h"
4+
#include "triton/Conversion/TritonGPUToLLVM/iluvatar_ElementwiseOpToLLVMBase.h"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#ifndef ILUVATAR_TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H
2+
#define ILUVATAR_TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H
3+
4+
#define FLAGTREE_SPEC_ElementwiseOpConversionBase_maybeDeduplicate
5+
#define FLAGTREE_SPEC_ElementwiseOpConversionBase_matchAndRewrite
6+
7+
#endif // ILUVATAR_TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H

third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,14 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
102102
// test_core::test_fp8_dot_acc
103103
return resultVals;
104104
}
105+
#ifdef FLAGTREE_SPEC_ElementwiseOpConversionBase_maybeDeduplicate
105106
if (isa<IluvatarMmaEncodingAttr, DotOperandEncodingAttr>(baseEncoding)) {
106107
// TODO: this logic seems incorrect for mma layout. Skip for now.
107108
// The following test crashes and some other miscompile:
108109
// test_core::test_fp8_dot_acc
109110
return resultVals;
110111
}
112+
#endif
111113

112114
SmallVector<unsigned> elemsPerThread = getElemsPerThread(rtType);
113115
int rank = elemsPerThread.size();
@@ -188,7 +190,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
188190
// element type
189191
auto resultElementTy = getElementTypeOrSelf(resultTy);
190192
Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
191-
#ifdef __ILUVATAR__
193+
#ifdef FLAGTREE_SPEC_ElementwiseOpConversionBase_matchAndRewrite
192194
auto srcType = this->getTypeConverter()->convertType(resultTy);
193195
if (auto structTy = dyn_cast<LLVM::LLVMStructType>(srcType))
194196
elemTy = structTy.getBody()[0];

0 commit comments

Comments
 (0)