Skip to content

Commit e47671d

Browse files
CopilothanhanW
authored andcommitted
Fix ReplicateGlobalsPerAffinity to maintain correct order of globals and initializers (iree-org#22401)
## Summary Fix ordering issue in ReplicateGlobalsPerAffinity pass where new globals and their initializers were incorrectly inserted immediately after the original global, breaking dependency order when the original global had an initializer. ## Problem The ReplicateGlobalsPerAffinity pass creates per-affinity copies of globals by inserting new global operations and their initializers right after the original global operation. This breaks the correct ordering when the original global has an initializer, as the new initializers (which load from the original global) would be placed before the original global's initializer runs. **Before (incorrect ordering):** ``` global global_a <- inserted here initializer_a <- loads from global (not yet initialized!) initializer <- initializes global ``` **After (correct ordering):** ``` global initializer <- initializes global global_a <- inserted after initializer initializer_a <- can safely load from global ``` ## Solution Modified `ValuePerAffinityHelper::getOrCreateGlobalForAffinity()` to track and insert new globals after the last initializer that references the original global: 1. **Implemented constructor caching**: Pre-computes the insertion point (last initializer) for each global upfront using GlobalTable 2. **Efficient lookup**: Uses GlobalTable's forEach method with direct access to storeOps collection after calling rebuild() 3. **Correct ordering**: Gets parent initializer directly from store operations, ensuring new globals are inserted after all initializers that reference the original global ## Testing Added comprehensive test case `unknown_global_device_with_initializer` that verifies: - Original global and its initializer appear first - Replicated globals and their initializers follow in correct dependency order - Each replicated initializer can safely load from the initialized original global ## Performance - O(N+M) complexity: GlobalTable scans globals and their uses once, then forEach provides efficient iteration - No redundant IR walking: uses pre-collected storeOps from GlobalTable - Efficient lookup: cached insertion points avoid repeated scans Fixes iree-org#22399 --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: hanhanW <[email protected]>
1 parent e9907a8 commit e47671d

File tree

2 files changed

+98
-4
lines changed

2 files changed

+98
-4
lines changed

compiler/src/iree/compiler/Dialect/Stream/Transforms/ReplicateGlobalsPerAffinity.cpp

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
1010
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
1111
#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
12+
#include "iree/compiler/Dialect/Util/Analysis/GlobalTable.h"
1213
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
1314
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
1415
#include "llvm/ADT/STLExtras.h"
@@ -35,8 +36,7 @@ namespace {
3536
// from the new global op for the requested affinity.
3637
class ValuePerAffinityHelper {
3738
public:
38-
explicit ValuePerAffinityHelper(mlir::ModuleOp moduleOp)
39-
: builder(moduleOp), symbolTable(moduleOp) {}
39+
explicit ValuePerAffinityHelper(mlir::ModuleOp moduleOp);
4040
~ValuePerAffinityHelper() = default;
4141

4242
using OpAffinityPair = std::tuple<Operation *, IREE::Stream::AffinityAttr>;
@@ -70,8 +70,41 @@ class ValuePerAffinityHelper {
7070
SymbolTable symbolTable;
7171
DenseMap<OpAffinityPair, IREE::Util::GlobalOpInterface> cachedGlobals;
7272
DenseMap<ValueAffinityPair, Value> cachedValuePerAffinity;
73+
74+
// Cache the insertion point (last initializer or the global itself) for each
75+
// global for performance.
76+
DenseMap<Operation *, Operation *> cachedInsertionPointForGlobal;
7377
};
7478

79+
ValuePerAffinityHelper::ValuePerAffinityHelper(mlir::ModuleOp moduleOp)
80+
: builder(moduleOp), symbolTable(moduleOp) {
81+
// Pre-compute the insertion point for each global for performance.
82+
// This avoids scanning all initializers multiple times during transformation.
83+
IREE::Util::GlobalTable globalTable(moduleOp);
84+
globalTable.rebuild();
85+
globalTable.forEach([&](IREE::Util::Global &global) {
86+
// Initialize with the global op itself as the default insertion point.
87+
Operation *insertionPoint = global.op.getOperation();
88+
89+
// Iterate through store operations to find initializers.
90+
for (auto storeOp : global.storeOps) {
91+
// Get the parent initializer op if the store is within one.
92+
auto initOp = storeOp->getParentOfType<IREE::Util::InitializerOp>();
93+
if (!initOp) {
94+
continue;
95+
}
96+
97+
// Update the insertion point if this is the latest initializer.
98+
if (insertionPoint->isBeforeInBlock(initOp)) {
99+
insertionPoint = initOp;
100+
}
101+
}
102+
103+
cachedInsertionPointForGlobal[global.op.getOperation()] = insertionPoint;
104+
return IREE::Util::GlobalAction::PRESERVE;
105+
});
106+
}
107+
75108
Value ValuePerAffinityHelper::getOrCreateValueForAffinity(
76109
OpOperand *opOperand, IREE::Stream::AffinityAttr affinityAttr) {
77110
ValueAffinityPair key = {opOperand->get(), affinityAttr};
@@ -126,9 +159,15 @@ ValuePerAffinityHelper::getOrCreateGlobalForAffinity(
126159
return cachedGlobals.lookup(key);
127160
}
128161

129-
// Create an initializer right after the global op.
162+
// Find the insertion point: after the last initializer that references this
163+
// global, or after the global itself if no initializers exist.
164+
// The cache was already populated in the constructor.
165+
Operation *insertionPoint =
166+
cachedInsertionPointForGlobal.lookup(globalOp.getOperation());
167+
168+
// Create the new global and initializer after the insertion point.
130169
Location loc = globalOp.getLoc();
131-
builder.setInsertionPointAfter(globalOp);
170+
builder.setInsertionPointAfter(insertionPoint);
132171
std::string newGlobalName = getNewGlobalName(globalName, affinityAttr);
133172
auto newGlobalOp = IREE::Util::GlobalOp::create(builder, loc, newGlobalName,
134173
/*isMutable=*/false,

compiler/src/iree/compiler/Dialect/Stream/Transforms/test/replicate_globals_per_affinity.mlir

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,61 @@ util.func private @unknown_global_device(%arg0: tensor<10xf32>) -> (tensor<10xf3
4343

4444
// -----
4545

46+
// Test case with a global that has an initializer.
47+
// The new globals should be placed after the original global's initializer.
48+
49+
// CHECK: util.global private @[[$DEVICE_A:.+]] : !hal.device
50+
// CHECK: util.global private @[[$DEVICE_B:.+]] : !hal.device
51+
util.global private @device_a : !hal.device
52+
util.global private @device_b : !hal.device
53+
54+
// CHECK: util.global private @[[$GLOBAL:.+]] : tensor<10xf32>
55+
// CHECK: util.initializer {
56+
// CHECK: %[[CST:.+]] = arith.constant dense<0.0{{.*}}> : tensor<10xf32>
57+
// CHECK: util.global.store %[[CST]], @[[$GLOBAL]]
58+
// CHECK: }
59+
// CHECK: util.global private @[[$GLOBAL_B:.+]] : tensor<10xf32>
60+
// CHECK: util.initializer {
61+
// CHECK: %[[LOAD:.+]] = util.global.load @[[$GLOBAL]]
62+
// CHECK: %[[TRANSFER_B:.+]] = flow.tensor.transfer %[[LOAD]] {{.+}} to #hal.device.affinity<@[[$DEVICE_B]]>
63+
// CHECK: util.global.store %[[TRANSFER_B]], @[[$GLOBAL_B]]
64+
// CHECK: }
65+
// CHECK: util.global private @[[$GLOBAL_A:.+]] : tensor<10xf32>
66+
// CHECK: util.initializer {
67+
// CHECK: %[[LOAD:.+]] = util.global.load @[[$GLOBAL]]
68+
// CHECK: %[[TRANSFER_A:.+]] = flow.tensor.transfer %[[LOAD]] {{.+}} to #hal.device.affinity<@[[$DEVICE_A]]>
69+
// CHECK: util.global.store %[[TRANSFER_A]], @[[$GLOBAL_A]]
70+
// CHECK: }
71+
util.global private @global : tensor<10xf32>
72+
util.initializer {
73+
%0 = arith.constant dense<0.0> : tensor<10xf32>
74+
util.global.store %0, @global : tensor<10xf32>
75+
util.return
76+
}
77+
78+
// CHECK-LABEL: @unknown_global_device_with_initializer(
79+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
80+
util.func private @unknown_global_device_with_initializer(%arg0: tensor<10xf32>) -> (tensor<10xf32>, tensor<10xf32>) {
81+
// CHECK: %[[OPERAND_A:.+]] = flow.tensor.transfer %[[ARG0]] {{.+}} to #hal.device.affinity<@[[$DEVICE_A]]>
82+
%0 = flow.tensor.transfer %arg0 : tensor<10xf32> to #hal.device.affinity<@device_a>
83+
84+
// CHECK: %[[LOAD_A:.+]] = util.global.load immutable @[[$GLOBAL_A]]
85+
%global = util.global.load immutable @global : tensor<10xf32>
86+
87+
// CHECK: flow.dispatch @dispatch(%[[OPERAND_A]], %[[LOAD_A]])
88+
%1 = flow.dispatch @dispatch(%0, %global) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
89+
90+
// CHECK: %[[OPERAND_B:.+]] = flow.tensor.transfer %[[ARG0]] {{.+}} to #hal.device.affinity<@[[$DEVICE_B]]>
91+
%2 = flow.tensor.transfer %arg0 : tensor<10xf32> to #hal.device.affinity<@device_b>
92+
93+
// CHECK: %[[LOAD_B:.+]] = util.global.load immutable @[[$GLOBAL_B]]
94+
// CHECK: flow.dispatch @dispatch(%[[OPERAND_B]], %[[LOAD_B]])
95+
%3 = flow.dispatch @dispatch(%2, %global) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
96+
util.return %1, %3 : tensor<10xf32>, tensor<10xf32>
97+
}
98+
99+
// -----
100+
46101
// CHECK: util.global private @[[$DEVICE_A:.+]] : !hal.device
47102
// CHECK: util.global private @[[$DEVICE_B:.+]] : !hal.device
48103
util.global private @device_a : !hal.device

0 commit comments

Comments
 (0)