Skip to content

Commit dfc8330

Browse files
Merge branch 'main' into modify-tutorial-typo
2 parents (1f1672c + d162c91), commit dfc8330

File tree

4 files changed

+69
-8
lines changed

4 files changed

+69
-8
lines changed

llvm/include/llvm/ExecutionEngine/Orc/WaitingOnGraph.h

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -500,7 +500,7 @@ template <typename ContainerIdT, typename ElementIdT> class WaitingOnGraph {
500500
if (I == SN->Deps.end())
501501
continue;
502502
for (auto &[DefElem, DefSN] : DefElems)
503-
if (I->second.erase(DefElem))
503+
if (I->second.erase(DefElem) && DefSN != SN.get())
504504
SNDeps.insert(DefSN);
505505
if (I->second.empty())
506506
SN->Deps.erase(I);
@@ -511,11 +511,13 @@ template <typename ContainerIdT, typename ElementIdT> class WaitingOnGraph {
511511
// Compute transitive closure of deps for each node.
512512
static void propagateSuperNodeDeps(SuperNodeDepsMap &SuperNodeDeps) {
513513
for (auto &[SN, Deps] : SuperNodeDeps) {
514-
DenseSet<SuperNode *> Reachable({SN});
514+
DenseSet<SuperNode *> Reachable;
515515
SmallVector<SuperNode *> Worklist(Deps.begin(), Deps.end());
516516

517517
while (!Worklist.empty()) {
518518
auto *DepSN = Worklist.pop_back_val();
519+
if (DepSN == SN)
520+
continue;
519521
if (!Reachable.insert(DepSN).second)
520522
continue;
521523
auto I = SuperNodeDeps.find(DepSN);
@@ -537,9 +539,11 @@ template <typename ContainerIdT, typename ElementIdT> class WaitingOnGraph {
537539
if (I == SuperNodeDeps.end())
538540
continue;
539541

540-
for (auto *DepSN : I->second)
542+
for (auto *DepSN : I->second) {
543+
assert(DepSN != SN.get() && "Unexpected self-dependence for SN");
541544
for (auto &[Container, Elems] : DepSN->Deps)
542545
SN->Deps[Container].insert(Elems.begin(), Elems.end());
546+
}
543547
}
544548
}
545549

llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -140,18 +140,23 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
140140
Value *Elt = EI.getIndexOperand();
141141
// If the operand is the PHI induction variable:
142142
if (PHIInVal == PHIUser) {
143-
// Scalarize the binary operation. Its first operand is the
144-
// scalar PHI, and the second operand is extracted from the other
143+
// Scalarize the binary operation. One operand is the
144+
// scalar PHI, and the other is extracted from the other
145145
// vector operand.
146146
BinaryOperator *B0 = cast<BinaryOperator>(PHIUser);
147147
unsigned opId = (B0->getOperand(0) == PN) ? 1 : 0;
148148
Value *Op = InsertNewInstWith(
149149
ExtractElementInst::Create(B0->getOperand(opId), Elt,
150150
B0->getOperand(opId)->getName() + ".Elt"),
151151
B0->getIterator());
152-
Value *newPHIUser = InsertNewInstWith(
153-
BinaryOperator::CreateWithCopiedFlags(B0->getOpcode(),
154-
scalarPHI, Op, B0), B0->getIterator());
152+
// Preserve operand order for binary operation to preserve semantics of
153+
// non-commutative operations.
154+
Value *FirstOp = (B0->getOperand(0) == PN) ? scalarPHI : Op;
155+
Value *SecondOp = (B0->getOperand(0) == PN) ? Op : scalarPHI;
156+
Value *newPHIUser =
157+
InsertNewInstWith(BinaryOperator::CreateWithCopiedFlags(
158+
B0->getOpcode(), FirstOp, SecondOp, B0),
159+
B0->getIterator());
155160
scalarPHI->addIncoming(newPHIUser, inBB);
156161
} else {
157162
// Scalarize PHI input:

llvm/test/Transforms/InstCombine/scalarization.ll

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,50 @@ for.end:
108108
ret void
109109
}
110110

111+
define void @scalarize_phi_sub(ptr %n, ptr %inout) {
112+
;
113+
; CHECK-LABEL: @scalarize_phi_sub(
114+
; CHECK-NEXT: entry:
115+
; CHECK-NEXT: [[T0:%.*]] = load volatile float, ptr [[INOUT:%.*]], align 4
116+
; CHECK-NEXT: br label [[FOR_COND:%.*]]
117+
; CHECK: for.cond:
118+
; CHECK-NEXT: [[TMP0:%.*]] = phi float [ [[T0]], [[ENTRY:%.*]] ], [ [[TMP1:%.*]], [[FOR_BODY:%.*]] ]
119+
; CHECK-NEXT: [[I_0:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[INC:%.*]], [[FOR_BODY]] ]
120+
; CHECK-NEXT: [[T1:%.*]] = load i32, ptr [[N:%.*]], align 4
121+
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i32 [[I_0]], [[T1]]
122+
; CHECK-NEXT: br i1 [[CMP_NOT]], label [[FOR_END:%.*]], label [[FOR_BODY]]
123+
; CHECK: for.body:
124+
; CHECK-NEXT: store volatile float [[TMP0]], ptr [[INOUT]], align 4
125+
; CHECK-NEXT: [[TMP1]] = fsub float 0.000000e+00, [[TMP0]]
126+
; CHECK-NEXT: [[INC]] = add nuw nsw i32 [[I_0]], 1
127+
; CHECK-NEXT: br label [[FOR_COND]]
128+
; CHECK: for.end:
129+
; CHECK-NEXT: ret void
130+
;
131+
entry:
132+
%t0 = load volatile float, ptr %inout, align 4
133+
%insert = insertelement <4 x float> poison, float %t0, i32 0
134+
%splat = shufflevector <4 x float> %insert, <4 x float> poison, <4 x i32> zeroinitializer
135+
br label %for.cond
136+
137+
for.cond:
138+
%x.0 = phi <4 x float> [ %splat, %entry ], [ %sub, %for.body ]
139+
%i.0 = phi i32 [ 0, %entry ], [ %inc, %for.body ]
140+
%t1 = load i32, ptr %n, align 4
141+
%cmp = icmp ne i32 %i.0, %t1
142+
br i1 %cmp, label %for.body, label %for.end
143+
144+
for.body:
145+
%t2 = extractelement <4 x float> %x.0, i32 1
146+
store volatile float %t2, ptr %inout, align 4
147+
%sub = fsub <4 x float> zeroinitializer, %x.0
148+
%inc = add nsw i32 %i.0, 1
149+
br label %for.cond
150+
151+
for.end:
152+
ret void
153+
}
154+
111155
define float @extract_element_binop_splat_constant_index(<4 x float> %x) {
112156
;
113157
; CHECK-LABEL: @extract_element_binop_splat_constant_index(

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2610,8 +2610,13 @@ cc_library(
26102610
hdrs = glob(["include/mlir/Dialect/X86Vector/TransformOps/*.h"]),
26112611
includes = ["include"],
26122612
deps = [
2613+
":IR",
2614+
":LLVMCommonConversion",
2615+
":LLVMDialect",
26132616
":TransformDialect",
2617+
":TransformDialectInterfaces",
26142618
":VectorDialect",
2619+
":X86VectorDialect",
26152620
":X86VectorTransformOpsIncGen",
26162621
":X86VectorTransforms",
26172622
],
@@ -2628,6 +2633,9 @@ cc_library(
26282633
":LLVMCommonConversion",
26292634
":LLVMDialect",
26302635
":LinalgDialect",
2636+
":LinalgInterfaces",
2637+
":Pass",
2638+
":TransformUtils",
26312639
":VectorDialect",
26322640
":VectorUtils",
26332641
":X86VectorDialect",

0 commit comments

Comments
 (0)