Skip to content

Commit 34fc7d8

Browse files
authored
Create a Fuser utility designed to fuse operations in the same DefUseChain (#5332)
The Triton XPU compiler has implemented a couple of transformations patterns to fuse "adjacent" operations (e.g. fuse a transpose with a load, fuse a reshape with a load). Both of these transformation leverage the `DefUseChain` infrastructure already built. This PR consolidate common functionality into a generic Fuser class so that derived classes can specialize the fusing transformation and reuse the base class. --------- Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 8e84b4e commit 34fc7d8

File tree

6 files changed

+309
-373
lines changed

6 files changed

+309
-373
lines changed

test/TritonIntelGPU/dot-operands.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
356356
}
357357

358358
// -----
359+
359360
#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [0, 32]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [0, 0]], block = []}>
360361
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
361362
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} {

third_party/intel/include/Dialect/Triton/Transforms/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,14 @@ def TritonIntelFuseReshape
5555
{order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
5656
%load = tt.load %ptr {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x512x64xf16>>
5757
%A = tt.reshape %load : tensor<1x512x64xf16> -> tensor<512x64xf16>
58-
%dot %A, ... : tensor<512x64xf16> x tensor<64x32xf16> -> tensor<512x32xf16>
58+
dot %A, ... : tensor<512x64xf16> x tensor<64x32xf16> -> tensor<512x32xf16>
5959

6060
The transformation drops the reshape operation, and generates:
6161
%div = %a / %b
6262
%ptr = tt.make_tensor_ptr %base_ptr, [%s0 * %div + %s1, %s2], [%b, %c], [%x * %div + %y, %z]
6363
{order = array<i32: 1, 0>} : <tensor<512x64xf16>>
6464
%A = tt.load %ptr {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<512x64xf16>>
65-
%dot %A, ... : tensor<512x64xf16> x tensor<64x32xf16> -> tensor<512x32xf16>
65+
dot %A, ... : tensor<512x64xf16> x tensor<64x32xf16> -> tensor<512x32xf16>
6666
}];
6767

6868
let dependentDialects = [

third_party/intel/include/Utils/DefUseChain.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#ifndef TRITON_INTEL_UTILS_DEFUSECHAIN_H
22
#define TRITON_INTEL_UTILS_DEFUSECHAIN_H
33

4+
#include "Utils/Utility.h"
45
#include "mlir/IR/Value.h"
6+
#include "mlir/Interfaces/LoopLikeInterface.h"
57
#include "llvm/ADT/SetVector.h"
68
#include <unordered_set>
79

@@ -97,6 +99,72 @@ class DefUseChainManager {
9799
DefUseChains chains;
98100
};
99101

102+
/// \class Fuser
103+
/// Abstract base class providing functionality to fuse operations within a
104+
/// set of def-use chains.
105+
class Fuser {
106+
protected:
107+
SmallPtrSet<Operation *, 8> cleanUp;
108+
109+
virtual ~Fuser() {
110+
if (!cleanUp.empty())
111+
eraseOperations(cleanUp);
112+
}
113+
114+
using DefUseChain = intel::DefUseChain;
115+
using DefUseChainManager = intel::DefUseChainManager;
116+
using DefUseChains = DefUseChainManager::DefUseChains;
117+
118+
// Delegate to derived classes details on which operations within a
119+
// DefUseChain to fuse.
120+
virtual void fuse(const DefUseChain &) = 0;
121+
122+
// Fuse operations in the given \p chains.
123+
void fuse(const DefUseChains &chains);
124+
125+
// Duplicate the root operation of the given \p chains.
126+
void duplicateRoot(DefUseChains &chains) const;
127+
128+
// Duplicate the root operation of \p sameRootChains and update \p chains.
129+
void duplicateRoot(DefUseChains &sameRootChains, DefUseChains &chains) const;
130+
131+
// Prune \p chains that cannot be handled during fusion. For example,
132+
// operations in the def-use chain should have a single user, except in
133+
// special circumstances (e.g. the root operation of a chain might have more
134+
// than one user).
135+
void pruneInvalid(DefUseChains &chains) const;
136+
137+
// Determine whether all operations in the given def-use \p chain have a
138+
// single user. Note: we allow an operation in the def-use chain to have an
139+
// additional user if the operation is in a for loop, and the additional user
140+
// is the loop yield operation, provided that the result yielded is not used
141+
// after the loop. Example:
142+
// make_tensor_ptr -> advance -> load (OK)
143+
// make_tensor_ptr -> for init_arg -> advance -> load (OK)
144+
// -> yield (OK)
145+
// make_tensor_ptr -> for init_arg -> advance -> load (OK)
146+
// -> yield -> load (NOT OK)
147+
//
148+
bool validateChain(const DefUseChain &chain) const;
149+
150+
// Propagate \p newVal to operations in the given def-use \p chain.
151+
void propagateToUsers(Value newVal, const DefUseChain &chain,
152+
IRMapping &mapping);
153+
154+
// Propagate \p newVal to users of \p origOp.
155+
void propagateToUsers(Value newVal, Value origVal, Operation *origOp,
156+
Operation *sentinel, IRMapping &mapping);
157+
158+
// If \p user is not \p sentinel, propagate \p newVal to \p user. Otherwise
159+
// terminate the propagation.
160+
virtual void propagateToUser(Value newVal, Value origVal, Operation *user,
161+
Operation *sentinel, IRMapping &mapping) = 0;
162+
163+
// Propagate \p newVal to users of \p origOp in the given \p loop.
164+
void propagateToLoop(Value newVal, Value origVal, LoopLikeOpInterface loopOp,
165+
Operation *sentinel, IRMapping &mapping);
166+
};
167+
100168
} // namespace mlir::triton::intel
101169

102170
#endif // TRITON_INTEL_UTILS_DEFUSECHAIN_H

0 commit comments

Comments
 (0)