@@ -1885,11 +1885,40 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
18851885
18861886 auto insertValueOp = getContainer ().getDefiningOp <InsertValueOp>();
18871887 OpFoldResult result = {};
1888+ ArrayRef<int64_t > extractPos = getPosition ();
1889+ bool switchedToInsertedValue = false ;
18881890 while (insertValueOp) {
1889- if (getPosition () == insertValueOp.getPosition ())
1891+ ArrayRef<int64_t > insertPos = insertValueOp.getPosition ();
1892+ auto extractPosSize = extractPos.size ();
1893+ auto insertPosSize = insertPos.size ();
1894+
1895+ // Case 1: Exact match of positions.
1896+ if (extractPos == insertPos)
18901897 return insertValueOp.getValue ();
1891- unsigned min =
1892- std::min (getPosition ().size (), insertValueOp.getPosition ().size ());
1898+
1899+ // Case 2: Insert position is a prefix of extract position. Continue
1900+ // traversal with the inserted value. Example:
1901+ // ```
1902+ // %0 = llvm.insertvalue %arg1, %undef[0] : !llvm.struct<(i32, i32, i32)>
1903+ // %1 = llvm.insertvalue %arg2, %0[1] : !llvm.struct<(i32, i32, i32)>
1904+ // %2 = llvm.insertvalue %arg3, %1[2] : !llvm.struct<(i32, i32, i32)>
1905+ // %3 = llvm.insertvalue %2, %foo[0]
1906+ // : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
1907+ // %4 = llvm.extractvalue %3[0, 0]
1908+ // : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
1909+ // ```
1910+ // In the above example, %4 is folded to %arg1.
1911+ if (extractPosSize > insertPosSize &&
1912+ extractPos.take_front (insertPosSize) == insertPos) {
1913+ insertValueOp = insertValueOp.getValue ().getDefiningOp <InsertValueOp>();
1914+ extractPos = extractPos.drop_front (insertPosSize);
1915+ switchedToInsertedValue = true ;
1916+ continue ;
1917+ }
1918+
1919+ // Case 3: Try to continue the traversal with the container value.
1920+ unsigned min = std::min (extractPosSize, insertPosSize);
1921+
18931922 // If one is fully prefix of the other, stop propagating back as it will
18941923 // miss dependencies. For instance, %3 should not fold to %f0 in the
18951924 // following example:
@@ -1900,15 +1929,17 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
19001929 // !llvm.array<4 x !llvm.array<4 x f32>>
19011930 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
19021931 // ```
1903- if (getPosition ().take_front (min) ==
1904- insertValueOp.getPosition ().take_front (min))
1932+ if (extractPos.take_front (min) == insertPos.take_front (min))
19051933 return result;
1906-
19071934 // If neither a prefix, nor the exact position, we can extract out of the
19081935 // value being inserted into. Moreover, we can try again if that operand
19091936 // is itself an insertvalue expression.
1910- getContainerMutable ().assign (insertValueOp.getContainer ());
1911- result = getResult ();
1937+ if (!switchedToInsertedValue) {
1938+ // Do not swap out the container operand if we decided earlier to
1939+ // continue the traversal with the inserted value (Case 2).
1940+ getContainerMutable ().assign (insertValueOp.getContainer ());
1941+ result = getResult ();
1942+ }
19121943 insertValueOp = insertValueOp.getContainer ().getDefiningOp <InsertValueOp>();
19131944 }
19141945 return result;
0 commit comments