Skip to content

Commit e063f17

Browse files
committed
utility decoupling
1 parent c031abd commit e063f17

File tree

6 files changed

+91
-78
lines changed

6 files changed

+91
-78
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
#include "triton/Analysis/iluvatar_AxisInfo.h"
22
#include "triton/Analysis/iluvatar_Membar.h"
3+
#include "triton/Analysis/iluvatar_Utility.h"
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#ifndef ILUVATAR_TRITON_ANALYSIS_UTILITY_H
#define ILUVATAR_TRITON_ANALYSIS_UTILITY_H

// Marks that this backend provides the FlagTree specialization of the
// Analysis/Utility functions (see triton/Analysis/Utility.h, which gates
// the extra declarations on this macro).
#define FLAGTREE_SPEC_Utility_Function

#endif // ILUVATAR_TRITON_ANALYSIS_UTILITY_H

third_party/iluvatar/backend/flagtree_backend_specialization/lib/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_triton_library(FlagTree_iluvatar_TritonAnalysis
22
AxisInfo.cpp
33
Membar.cpp
4+
Utility.cpp
45

56
DEPENDS
67
TritonTableGen
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#include "triton/Analysis/Utility.h"
2+
3+
namespace mlir {
4+
5+
/// Returns true when a conversion from \p srcTy (Iluvatar MMA layout) to
/// \p dstTy (dot-operand layout whose parent is an Iluvatar MMA layout)
/// must take the slow path instead of the register shortcut.
///
/// All of the following must hold for a `true` result:
///  - the fast isMmaToDotShortcut path does NOT apply,
///  - both source and destination MMA layouts have versionMajor == 1,
///  - source and destination agree on warpsPerCTA along dim 0,
///  - the destination is dot operand A (opIdx == 0),
///  - the element type is not f32.
bool isMmaToDotSlowShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
  auto srcLayout = srcTy.getEncoding();
  auto dstLayout = dstTy.getEncoding();
  if (!srcLayout.isa<triton::gpu::IluvatarMmaEncodingAttr>())
    return false;
  auto mmaLayout = srcLayout.cast<triton::gpu::IluvatarMmaEncodingAttr>();
  if (!dstLayout.isa<triton::gpu::DotOperandEncodingAttr>())
    return false;
  auto dotOperandLayout = dstLayout.cast<triton::gpu::DotOperandEncodingAttr>();
  auto dstParLayout = dotOperandLayout.getParent();
  if (!dstParLayout.isa<triton::gpu::IluvatarMmaEncodingAttr>())
    return false;
  // The isa<> guard above makes this cast infallible, so use cast<> (which
  // asserts) rather than dyn_cast<>, whose null result was never checked.
  auto dstMmaLayout = dstParLayout.cast<triton::gpu::IluvatarMmaEncodingAttr>();
  return !isMmaToDotShortcut(srcTy, dstTy) &&
         mmaLayout.getVersionMajor() == 1 &&
         dstMmaLayout.getVersionMajor() == 1 &&
         mmaLayout.getWarpsPerCTA()[0] == dstMmaLayout.getWarpsPerCTA()[0] &&
         dotOperandLayout.getOpIdx() == 0 && !srcTy.getElementType().isF32();
}
26+
27+
/// Recursively accumulates, into \p backwardSlice, the transitive defining
/// ops of \p op's operands (post-order), and finally \p op itself. Unlike
/// the upstream MLIR slice helper, the single-region/single-block assertion
/// on block-argument parents is deliberately left disabled (commented out).
void getBackwardSliceImplCorex(Operation *op,
                               SetVector<Operation *> *backwardSlice,
                               TransitiveFilter filter,
                               bool omitBlockArguments) {
  // Stop at null ops and at isolation boundaries.
  if (op == nullptr || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
    return;

  // Honor the user-supplied scoping filter: a rejected op contributes
  // neither itself nor its transitive defs to the slice.
  if (filter && !filter(op))
    return;

  for (Value operand : op->getOperands()) {
    Operation *def = operand.getDefiningOp();
    if (def != nullptr) {
      // Recurse only into ops not already collected.
      if (!backwardSlice->count(def))
        getBackwardSliceImplCorex(def, backwardSlice, filter,
                                  omitBlockArguments);
      continue;
    }

    auto blockArg = operand.dyn_cast<BlockArgument>();
    if (!blockArg)
      llvm_unreachable("No definingOp and not a block argument.");
    if (omitBlockArguments)
      continue;

    Operation *parentOp = blockArg.getOwner()->getParentOp();
    // TODO: determine whether we want to recurse backward into the other
    // blocks of parentOp, which are not technically backward unless they flow
    // into us. For now, just bail.
    if (parentOp && !backwardSlice->count(parentOp)) {
      // assert(parentOp->getNumRegions() == 1 &&
      //        parentOp->getRegion(0).getBlocks().size() == 1);
      getBackwardSliceImplCorex(parentOp, backwardSlice, filter,
                                omitBlockArguments);
    }
  }

  // Post-order insert: op follows everything it transitively depends on.
  backwardSlice->insert(op);
}
68+
69+
/// Public entry point: fills \p backwardSlice with the transitive backward
/// slice of \p op, excluding \p op itself.
void getBackwardSliceCorex(Operation *op, SetVector<Operation *> *backwardSlice,
                           TransitiveFilter filter, bool omitBlockArguments) {
  getBackwardSliceImplCorex(op, backwardSlice, filter, omitBlockArguments);
  // The queried op is not part of its own backward slice; drop it from the
  // result set.
  backwardSlice->remove(op);
}
77+
78+
}

third_party/iluvatar/include/triton/Analysis/Utility.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include "triton/Dialect/Triton/IR/Dialect.h"
88
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
99

10+
#include "triton/../../backend/flagtree_backend_specialization/include/flagtree_spec.h"
11+
1012
namespace mlir {
1113

1214
inline bool isZeroConst(Value v) {
@@ -192,8 +194,6 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
192194

193195
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
194196

195-
bool isMmaToDotSlowShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
196-
197197
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
198198

199199
// Return true if the src and dst layout match.
@@ -212,7 +212,9 @@ bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);
212212
SetVector<Operation *>
213213
multiRootTopologicalSort(const SetVector<Operation *> &toSort);
214214

215-
#ifdef __ILUVATAR__
215+
#ifdef FLAGTREE_SPEC_Utility_Function
216+
bool isMmaToDotSlowShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
217+
216218
/// This function doesn't use assertion checks.
217219
void getBackwardSliceCorex(Operation *op, SetVector<Operation *> *backwardSlice,
218220
TransitiveFilter filter = nullptr,

third_party/iluvatar/lib/Analysis/Utility.cpp

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -676,28 +676,6 @@ bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
676676
#endif
677677
}
678678

679-
bool isMmaToDotSlowShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
680-
681-
auto srcLayout = srcTy.getEncoding();
682-
auto dstLayout = dstTy.getEncoding();
683-
if (!srcLayout.isa<triton::gpu::IluvatarMmaEncodingAttr>())
684-
return false;
685-
auto mmaLayout = srcLayout.cast<triton::gpu::IluvatarMmaEncodingAttr>();
686-
if (!dstLayout.isa<triton::gpu::DotOperandEncodingAttr>())
687-
return false;
688-
auto dotOperandLayout = dstLayout.cast<triton::gpu::DotOperandEncodingAttr>();
689-
auto dstParLayout = dotOperandLayout.getParent();
690-
if (!dstParLayout.isa<triton::gpu::IluvatarMmaEncodingAttr>())
691-
return false;
692-
auto dstMmaLayout =
693-
dstParLayout.dyn_cast<triton::gpu::IluvatarMmaEncodingAttr>();
694-
return !isMmaToDotShortcut(srcTy, dstTy) &&
695-
mmaLayout.getVersionMajor() == 1 &&
696-
dstMmaLayout.getVersionMajor() == 1 &&
697-
mmaLayout.getWarpsPerCTA()[0] == dstMmaLayout.getWarpsPerCTA()[0] &&
698-
dotOperandLayout.getOpIdx() == 0 && !srcTy.getElementType().isF32();
699-
}
700-
701679
namespace {
702680

703681
/// A data structure similar to SetVector but maintains
@@ -830,59 +808,6 @@ multiRootTopologicalSort(const SetVector<Operation *> &toSort) {
830808
return res;
831809
}
832810

833-
#ifdef __ILUVATAR__
834-
void getBackwardSliceImplCorex(Operation *op,
835-
SetVector<Operation *> *backwardSlice,
836-
TransitiveFilter filter,
837-
bool omitBlockArguments) {
838-
if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
839-
return;
840-
841-
// Evaluate whether we should keep this def.
842-
// This is useful in particular to implement scoping; i.e. return the
843-
// transitive backwardSlice in the current scope.
844-
if (filter && !filter(op))
845-
return;
846-
847-
for (const auto &en : llvm::enumerate(op->getOperands())) {
848-
auto operand = en.value();
849-
if (auto *definingOp = operand.getDefiningOp()) {
850-
if (backwardSlice->count(definingOp) == 0)
851-
getBackwardSliceImplCorex(definingOp, backwardSlice, filter,
852-
omitBlockArguments);
853-
} else if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
854-
if (omitBlockArguments)
855-
continue;
856-
857-
Block *block = blockArg.getOwner();
858-
Operation *parentOp = block->getParentOp();
859-
// TODO: determine whether we want to recurse backward into the other
860-
// blocks of parentOp, which are not technically backward unless they flow
861-
// into us. For now, just bail.
862-
if (parentOp && backwardSlice->count(parentOp) == 0) {
863-
// assert(parentOp->getNumRegions() == 1 &&
864-
// parentOp->getRegion(0).getBlocks().size() == 1);
865-
getBackwardSliceImplCorex(parentOp, backwardSlice, filter,
866-
omitBlockArguments);
867-
}
868-
} else {
869-
llvm_unreachable("No definingOp and not a block argument.");
870-
}
871-
}
872-
873-
backwardSlice->insert(op);
874-
}
875-
876-
void getBackwardSliceCorex(Operation *op, SetVector<Operation *> *backwardSlice,
877-
TransitiveFilter filter, bool omitBlockArguments) {
878-
getBackwardSliceImplCorex(op, backwardSlice, filter, omitBlockArguments);
879-
880-
// Don't insert the top level operation, we just queried on it and don't
881-
// want it in the results.
882-
backwardSlice->remove(op);
883-
}
884-
#endif
885-
886811
SetVector<Operation *> multiRootGetSlice(Operation *op,
887812
TransitiveFilter backwardFilter,
888813
TransitiveFilter forwardFilter,

0 commit comments

Comments
 (0)