Skip to content

Commit 18103ea

Browse files
mjulian31 authored and aadeshps-mcw committed
[InstCombine] Fix phi scalarization with binop (llvm#169120)
InstCombine's PHI scalarization always created the new scalar binary op with the PHI as the first operand, which is incorrect for non-commutative binary ops such as sub when the PHI is actually the second operand. This fix preserves the original operand order in the new binary op and adds a test for this behavior. Without the fix, the transformation can silently produce incorrect IR; in the case of the added test, the subtraction would be optimized out entirely.
1 parent 87cfdc9 commit 18103ea

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,18 +140,23 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
140140
Value *Elt = EI.getIndexOperand();
141141
// If the operand is the PHI induction variable:
142142
if (PHIInVal == PHIUser) {
143-
// Scalarize the binary operation. Its first operand is the
144-
// scalar PHI, and the second operand is extracted from the other
143+
// Scalarize the binary operation. One operand is the
144+
// scalar PHI, and the other is extracted from the other
145145
// vector operand.
146146
BinaryOperator *B0 = cast<BinaryOperator>(PHIUser);
147147
unsigned opId = (B0->getOperand(0) == PN) ? 1 : 0;
148148
Value *Op = InsertNewInstWith(
149149
ExtractElementInst::Create(B0->getOperand(opId), Elt,
150150
B0->getOperand(opId)->getName() + ".Elt"),
151151
B0->getIterator());
152-
Value *newPHIUser = InsertNewInstWith(
153-
BinaryOperator::CreateWithCopiedFlags(B0->getOpcode(),
154-
scalarPHI, Op, B0), B0->getIterator());
152+
// Preserve operand order for binary operation to preserve semantics of
153+
// non-commutative operations.
154+
Value *FirstOp = (B0->getOperand(0) == PN) ? scalarPHI : Op;
155+
Value *SecondOp = (B0->getOperand(0) == PN) ? Op : scalarPHI;
156+
Value *newPHIUser =
157+
InsertNewInstWith(BinaryOperator::CreateWithCopiedFlags(
158+
B0->getOpcode(), FirstOp, SecondOp, B0),
159+
B0->getIterator());
155160
scalarPHI->addIncoming(newPHIUser, inBB);
156161
} else {
157162
// Scalarize PHI input:

llvm/test/Transforms/InstCombine/scalarization.ll

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,50 @@ for.end:
108108
ret void
109109
}
110110

111+
define void @scalarize_phi_sub(ptr %n, ptr %inout) {
112+
; The vector PHI is the second operand of a non-commutative fsub; the scalarized fsub must keep it on the right-hand side.
113+
; CHECK-LABEL: @scalarize_phi_sub(
114+
; CHECK-NEXT: entry:
115+
; CHECK-NEXT: [[T0:%.*]] = load volatile float, ptr [[INOUT:%.*]], align 4
116+
; CHECK-NEXT: br label [[FOR_COND:%.*]]
117+
; CHECK: for.cond:
118+
; CHECK-NEXT: [[TMP0:%.*]] = phi float [ [[T0]], [[ENTRY:%.*]] ], [ [[TMP1:%.*]], [[FOR_BODY:%.*]] ]
119+
; CHECK-NEXT: [[I_0:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[INC:%.*]], [[FOR_BODY]] ]
120+
; CHECK-NEXT: [[T1:%.*]] = load i32, ptr [[N:%.*]], align 4
121+
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i32 [[I_0]], [[T1]]
122+
; CHECK-NEXT: br i1 [[CMP_NOT]], label [[FOR_END:%.*]], label [[FOR_BODY]]
123+
; CHECK: for.body:
124+
; CHECK-NEXT: store volatile float [[TMP0]], ptr [[INOUT]], align 4
125+
; CHECK-NEXT: [[TMP1]] = fsub float 0.000000e+00, [[TMP0]]
126+
; CHECK-NEXT: [[INC]] = add nuw nsw i32 [[I_0]], 1
127+
; CHECK-NEXT: br label [[FOR_COND]]
128+
; CHECK: for.end:
129+
; CHECK-NEXT: ret void
130+
;
131+
entry:
132+
%t0 = load volatile float, ptr %inout, align 4
133+
%insert = insertelement <4 x float> poison, float %t0, i32 0
134+
%splat = shufflevector <4 x float> %insert, <4 x float> poison, <4 x i32> zeroinitializer
135+
br label %for.cond
136+
137+
for.cond:
138+
%x.0 = phi <4 x float> [ %splat, %entry ], [ %sub, %for.body ]
139+
%i.0 = phi i32 [ 0, %entry ], [ %inc, %for.body ]
140+
%t1 = load i32, ptr %n, align 4
141+
%cmp = icmp ne i32 %i.0, %t1
142+
br i1 %cmp, label %for.body, label %for.end
143+
144+
for.body:
145+
%t2 = extractelement <4 x float> %x.0, i32 1
146+
store volatile float %t2, ptr %inout, align 4
147+
%sub = fsub <4 x float> zeroinitializer, %x.0
148+
%inc = add nsw i32 %i.0, 1
149+
br label %for.cond
150+
151+
for.end:
152+
ret void
153+
}
154+
111155
define float @extract_element_binop_splat_constant_index(<4 x float> %x) {
112156
;
113157
; CHECK-LABEL: @extract_element_binop_splat_constant_index(

0 commit comments

Comments
 (0)