Commit 424f98b

Make fuse reshape with load op deterministic (#5327)
Similar to #4323. The main idea is to get rid of data structures in which pointers are sorted: ordering by raw pointer value follows allocation addresses, so it can differ from run to run and makes the pass non-deterministic.

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent c437c95 commit 424f98b
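
As a rough illustration of the idea (not the actual patch), the container change can be sketched with the standard library alone. The names Op, Chain, and ChainHash below are made up for the sketch; the real code in DefUseChain.h keys on mlir::Operation* and hashes with llvm::hash_combine, as shown in the diff further down.

// Minimal sketch, standard library only.
// Before: a std::set whose operator< compared raw pointers, so iteration order
// followed heap addresses and could differ between runs.
// After: an unordered container with an explicit hash over (start, end), so no
// code path depends on the relative order of pointer values.
#include <cstddef>
#include <functional>
#include <unordered_set>

struct Op {}; // stand-in for mlir::Operation

struct Chain {
  const Op *start = nullptr;
  const Op *end = nullptr;
  bool operator==(const Chain &other) const {
    return start == other.start && end == other.end;
  }
};

struct ChainHash {
  size_t operator()(const Chain &c) const noexcept {
    size_t h1 = std::hash<const Op *>{}(c.start);
    size_t h2 = std::hash<const Op *>{}(c.end);
    return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2)); // boost-style combine
  }
};

using Chains = std::unordered_set<Chain, ChainHash>; // was: std::set<Chain>

Note that the real DefUseChain compares the full operation set in operator== (see the header diff below); the sketch only keeps the (start, end) pair that the new hash uses.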

3 files changed, 11 insertions(+), 10 deletions(-)

test/Triton/Intel/FuseReshape/fuse-reshape.mlir (2 additions, 2 deletions)

@@ -185,12 +185,12 @@ tt.func public @fuseLoadWithReshape4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt
 // CHECK: [[ADD22:%.*]] = arith.addi [[MUL22]], %c1_i32 : i32
 // CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr %arg2, [[[ADD12]], %c64_i64], [%c64_i64, %c1_i64], [[[ADD22]], %c2_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16>>
 // CHECK: scf.for
-// CHECK: [[ADV:%.*]] = tt.advance [[PTR2]], {{.*}} : <tensor<32x64xf16>>
+// CHECK: [[ADV:%.*]] = tt.advance [[PTR1]], {{.*}} : <tensor<32x64xf16>>
 // CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV]] : !tt.ptr<tensor<32x64xf16>>
 // CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
 // CHECK: scf.yield
 // CHECK: scf.for
-// CHECK: [[ADV:%.*]] = tt.advance [[PTR1]], {{.*}} : <tensor<32x64xf16>>
+// CHECK: [[ADV:%.*]] = tt.advance [[PTR2]], {{.*}} : <tensor<32x64xf16>>
 // CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV]] : !tt.ptr<tensor<32x64xf16>>
 // CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
 // CHECK: scf.yield

third_party/intel/include/Utils/DefUseChain.h (8 additions, 7 deletions)

@@ -3,6 +3,7 @@
 
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/SetVector.h"
+#include <unordered_set>
 
 namespace mlir::triton::intel {
 
@@ -20,12 +21,6 @@ class DefUseChain {
 
   DefUseChain() = delete;
 
-  bool operator<(const DefUseChain &other) const {
-    if (start == other.start)
-      return end < other.end;
-    return start < other.start;
-  }
-
   bool operator==(const DefUseChain &other) const { return ops == other.ops; }
 
   const Operations &getOps() const { return ops; }
@@ -61,13 +56,19 @@ class DefUseChain {
   Operation *end; //< last operation in the chain
 };
 
+struct DefUseChainHash {
+  size_t operator()(const mlir::triton::intel::DefUseChain &c) const noexcept {
+    return llvm::hash_combine(c.getStart(), c.getEnd());
+  }
+};
+
 /// \class DefUseChainManager
 /// Manages collection of one or more \class DefUseChain.
 class DefUseChainManager {
   friend raw_ostream &operator<<(raw_ostream &, const DefUseChainManager &);
 
 public:
-  using DefUseChains = std::set<DefUseChain>;
+  using DefUseChains = std::unordered_set<DefUseChain, DefUseChainHash>;
   using Operations = DefUseChain::Operations;
 
   /// Create all def-use chains rooted at \p start and terminated by \p end.

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp (1 addition, 1 deletion)

@@ -111,7 +111,7 @@ class FuseTransWithLoad {
 private:
   // Duplicate the root operation of the given chains.
   void duplicateRoot(DefUseChains &chains) const {
-    std::map<Operation *, DefUseChains> rootToChains;
+    std::unordered_map<Operation *, DefUseChains> rootToChains;
     for (const DefUseChain &chain : chains) {
       Operation *start = chain.getStart();
       if (!rootToChains[start].empty())
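
Continuing the illustrative Op/Chain/Chains types from the sketch near the top (still an assumption-laden sketch, not the project's code), the grouping step that duplicateRoot builds its rootToChains map for looks roughly like this; groupChainsByRoot is a hypothetical helper name.

// Sketch only: reuses the illustrative Op/Chain/Chains types defined earlier.
// Buckets chains by their root ("start") operation, mirroring the rootToChains
// map in duplicateRoot after the std::map -> std::unordered_map switch.
#include <unordered_map>

std::unordered_map<const Op *, Chains> groupChainsByRoot(const Chains &chains) {
  std::unordered_map<const Op *, Chains> rootToChains;
  for (const Chain &chain : chains)
    rootToChains[chain.start].insert(chain);
  return rootToChains;
}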
