Skip to content

Commit 7ce800b

Browse files
authored
[MLIR][Presburger] Fix bug in PresburgerSpace::convertVarKind (#67267)
This patch fixes a bug in PresburgerSpace::convertVarKind where the identifiers were not moved properly due to offset being invalidated.
1 parent e5038f0 commit 7ce800b

File tree

2 files changed

+41
-16
lines changed

2 files changed

+41
-16
lines changed

mlir/lib/Analysis/Presburger/PresburgerSpace.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,26 @@ void PresburgerSpace::convertVarKind(VarKind srcKind, unsigned srcPos,
161161
assert(dstPos <= getNumVarKind(dstKind) &&
162162
"invalid position for destination variables");
163163

164+
// Move identifiers if `usingIds` and variables moved are not locals.
165+
unsigned srcOffset = getVarKindOffset(srcKind) + srcPos;
166+
unsigned dstOffset = getVarKindOffset(dstKind) + dstPos;
167+
if (isUsingIds() && srcKind != VarKind::Local && dstKind != VarKind::Local) {
168+
identifiers.insert(identifiers.begin() + dstOffset, num, Identifier());
169+
// Update srcOffset if insertion of new elements invalidates it.
170+
if (dstOffset < srcOffset)
171+
srcOffset += num;
172+
std::move(identifiers.begin() + srcOffset,
173+
identifiers.begin() + srcOffset + num,
174+
identifiers.begin() + dstOffset);
175+
identifiers.erase(identifiers.begin() + srcOffset,
176+
identifiers.begin() + srcOffset + num);
177+
} else if (isUsingIds() && srcKind != VarKind::Local) {
178+
identifiers.erase(identifiers.begin() + srcOffset,
179+
identifiers.begin() + srcOffset + num);
180+
} else if (isUsingIds() && dstKind != VarKind::Local) {
181+
identifiers.insert(identifiers.begin() + dstOffset, num, Identifier());
182+
}
183+
164184
auto addVars = [&](VarKind kind, int num) {
165185
switch (kind) {
166186
case VarKind::Domain:
@@ -180,22 +200,6 @@ void PresburgerSpace::convertVarKind(VarKind srcKind, unsigned srcPos,
180200

181201
addVars(srcKind, -(signed)num);
182202
addVars(dstKind, num);
183-
184-
// Move identifiers if `usingIds` and variables moved are not locals.
185-
unsigned srcOffset = getVarKindOffset(srcKind) + srcPos;
186-
unsigned dstOffset = getVarKindOffset(dstKind) + dstPos;
187-
if (isUsingIds() && srcKind != VarKind::Local && dstKind != VarKind::Local) {
188-
identifiers.insert(identifiers.begin() + dstOffset, num, Identifier());
189-
for (unsigned i = 0; i < num; ++i)
190-
identifiers[dstOffset + i] = identifiers[srcOffset + i];
191-
identifiers.erase(identifiers.begin() + srcOffset,
192-
identifiers.begin() + srcOffset + num);
193-
} else if (isUsingIds() && srcKind != VarKind::Local) {
194-
identifiers.erase(identifiers.begin() + srcOffset,
195-
identifiers.begin() + srcOffset + num);
196-
} else if (isUsingIds() && dstKind != VarKind::Local) {
197-
identifiers.insert(identifiers.begin() + dstOffset, num, Identifier());
198-
}
199203
}
200204

201205
void PresburgerSpace::swapVar(VarKind kindA, VarKind kindB, unsigned posA,

mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,24 @@ TEST(PresburgerSpaceTest, convertVarKindLocals) {
158158
EXPECT_FALSE(space.getId(VarKind::Range, 0).hasValue());
159159
EXPECT_FALSE(space.getId(VarKind::Range, 1).hasValue());
160160
}
161+
162+
TEST(PresburgerSpaceTest, convertVarKind2) {
163+
PresburgerSpace space = PresburgerSpace::getRelationSpace(0, 2, 2, 0);
164+
space.resetIds();
165+
166+
// Attach identifiers.
167+
int identifiers[4] = {0, 1, 2, 3};
168+
space.getId(VarKind::Range, 0) = Identifier(&identifiers[0]);
169+
space.getId(VarKind::Range, 1) = Identifier(&identifiers[1]);
170+
space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[2]);
171+
space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[3]);
172+
173+
// Convert Range variables to symbols.
174+
space.convertVarKind(VarKind::Range, 0, 2, VarKind::Symbol, 1);
175+
176+
// Check if the identifiers are moved to symbols.
177+
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[2]));
178+
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[0]));
179+
EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[1]));
180+
EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[3]));
181+
}

0 commit comments

Comments
 (0)