Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1885,11 +1885,40 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {

auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
OpFoldResult result = {};
ArrayRef<int64_t> extractPos = getPosition();
bool switchedToInsertedValue = false;
while (insertValueOp) {
if (getPosition() == insertValueOp.getPosition())
ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
auto extractPosSize = extractPos.size();
auto insertPosSize = insertPos.size();

// Case 1: Exact match of positions.
if (extractPos == insertPos)
return insertValueOp.getValue();
unsigned min =
std::min(getPosition().size(), insertValueOp.getPosition().size());

// Case 2: Insert position is a prefix of extract position. Continue
// traversal with the inserted value. Example:
// ```
// %0 = llvm.insertvalue %arg1, %undef[0] : !llvm.struct<(i32, i32, i32)>
// %1 = llvm.insertvalue %arg2, %0[1] : !llvm.struct<(i32, i32, i32)>
// %2 = llvm.insertvalue %arg3, %1[2] : !llvm.struct<(i32, i32, i32)>
// %3 = llvm.insertvalue %2, %foo[0]
// : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
// %4 = llvm.extractvalue %3[0, 0]
// : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
// ```
// In the above example, %4 is folded to %arg1.
if (extractPosSize > insertPosSize &&
extractPos.take_front(insertPosSize) == insertPos) {
insertValueOp = insertValueOp.getValue().getDefiningOp<InsertValueOp>();
extractPos = extractPos.drop_front(insertPosSize);
switchedToInsertedValue = true;
continue;
}

// Case 3: Try to continue the traversal with the container value.
unsigned min = std::min(extractPosSize, insertPosSize);

// If one is fully prefix of the other, stop propagating back as it will
// miss dependencies. For instance, %3 should not fold to %f0 in the
// following example:
Expand All @@ -1900,15 +1929,17 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
// !llvm.array<4 x !llvm.array<4 x f32>>
// %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
// ```
if (getPosition().take_front(min) ==
insertValueOp.getPosition().take_front(min))
if (extractPos.take_front(min) == insertPos.take_front(min))
return result;

// If neither a prefix, nor the exact position, we can extract out of the
// value being inserted into. Moreover, we can try again if that operand
// is itself an insertvalue expression.
getContainerMutable().assign(insertValueOp.getContainer());
result = getResult();
if (!switchedToInsertedValue) {
// Do not swap out the container operand if we decided earlier to
// continue the traversal with the inserted value (Case 2).
getContainerMutable().assign(insertValueOp.getContainer());
result = getResult();
}
insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
}
return result;
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/LLVMIR/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ llvm.func @fold_extractvalue() -> i32 {

// -----

// CHECK-LABEL: fold_extractvalue(
// CHECK-SAME: %[[arg1:.*]]: i32, %[[arg2:.*]]: i32, %[[arg3:.*]]: i32)
// CHECK-NEXT: llvm.return %[[arg1]] : i32
llvm.func @fold_extractvalue(%arg1: i32, %arg2: i32, %arg3: i32) -> i32{
%3 = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>
%5 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32)>
%6 = llvm.insertvalue %arg1, %5[0] : !llvm.struct<(i32, i32, i32)>
%7 = llvm.insertvalue %arg1, %6[1] : !llvm.struct<(i32, i32, i32)>
%8 = llvm.insertvalue %arg1, %7[2] : !llvm.struct<(i32, i32, i32)>
%11 = llvm.insertvalue %8, %3[0] : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>
%13 = llvm.extractvalue %11[0, 0] : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>
llvm.return %13 : i32
}

// -----

// CHECK-LABEL: no_fold_extractvalue
llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
%f0 = arith.constant 0.0 : f32
Expand Down
Loading