diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index a6e996f3fb810..5da58ad9c3a18 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1898,6 +1898,13 @@ static Type getInsertExtractValueElementType(Type llvmType, } OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) { + if (auto extractValueOp = getContainer().getDefiningOp()) { + SmallVector newPos(extractValueOp.getPosition()); + newPos.append(getPosition().begin(), getPosition().end()); + setPosition(newPos); + getContainerMutable().set(extractValueOp.getContainer()); + return getResult(); + } auto insertValueOp = getContainer().getDefiningOp(); OpFoldResult result = {}; while (insertValueOp) { diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir index 15f960167cb5f..c509cd82227c2 100644 --- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir +++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir @@ -84,6 +84,16 @@ llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 { llvm.return %3 : f32 } +// ----- +// CHECK-LABEL: fold_extract_extractvalue +llvm.func @fold_extract_extractvalue(%arr: !llvm.struct<(i64, array<1 x ptr<1>>)>) -> !llvm.ptr<1> { + // CHECK: llvm.extractvalue %{{.*}}[1, 0] + // CHECK-NOT: extractvalue + %a = llvm.extractvalue %arr[1] : !llvm.struct<(i64, array<1 x ptr<1>>)> + %b = llvm.extractvalue %a[0] : !llvm.array<1 x ptr<1>> + llvm.return %b : !llvm.ptr<1> +} + // ----- // CHECK-LABEL: fold_bitcast