From 0dc732da1505ef7c7ea2d6144954f721ebcd20d0 Mon Sep 17 00:00:00 2001 From: Paul Walker Date: Fri, 1 Aug 2025 10:17:40 +0000 Subject: [PATCH 1/2] [LLVM][InstCombine] Extend masked_gather's demanded elt analysis. Add support for other Constant types for the mask operand. --- .../InstCombineSimplifyDemanded.cpp | 22 +++++++++++++------ .../InstCombine/masked_intrinsics.ll | 1 + 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 0e3436d12702d..c82dae7ac6e65 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -1834,14 +1834,22 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // segfaults which didn't exist in the original program. APInt DemandedPtrs(APInt::getAllOnes(VWidth)), DemandedPassThrough(DemandedElts); - if (auto *CV = dyn_cast(II->getOperand(2))) - for (unsigned i = 0; i < VWidth; i++) { - Constant *CElt = CV->getAggregateElement(i); - if (CElt->isNullValue()) - DemandedPtrs.clearBit(i); - else if (CElt->isAllOnesValue()) - DemandedPassThrough.clearBit(i); + if (auto *CMask = dyn_cast(II->getOperand(2))) { + if (CMask->isNullValue()) + DemandedPtrs.clearAllBits(); + else if (CMask->isAllOnesValue()) + DemandedPassThrough.clearAllBits(); + else if (auto *CV = dyn_cast(CMask)) { + for (unsigned i = 0; i < VWidth; i++) { + Constant *CElt = CV->getAggregateElement(i); + if (CElt->isNullValue()) + DemandedPtrs.clearBit(i); + else if (CElt->isAllOnesValue()) + DemandedPassThrough.clearBit(i); + } } + } + if (II->getIntrinsicID() == Intrinsic::masked_gather) simplifyAndSetOp(II, 0, DemandedPtrs, PoisonElts2); simplifyAndSetOp(II, 3, DemandedPassThrough, PoisonElts3); diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll index d9f022442a02e..8f7683419a82a 100644 --- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll +++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll @@ -1,5 +1,6 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py ; RUN: opt -passes=instcombine -S < %s | FileCheck %s +; RUN: opt -passes=instcombine -use-constant-int-for-fixed-length-splat -S < %s | FileCheck %s declare <2 x double> @llvm.masked.load.v2f64.p0(ptr %ptrs, i32, <2 x i1> %mask, <2 x double> %src0) declare void @llvm.masked.store.v2f64.p0(<2 x double> %val, ptr %ptrs, i32, <2 x i1> %mask) From efd0af43f81b6997272b353339d07f06f5e541e1 Mon Sep 17 00:00:00 2001 From: Paul Walker Date: Sat, 2 Aug 2025 11:58:27 +0100 Subject: [PATCH 2/2] Simplify code by iterating across all Constant types. --- .../InstCombine/InstCombineSimplifyDemanded.cpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index c82dae7ac6e65..f17fecd430a6c 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -1835,13 +1835,8 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, APInt DemandedPtrs(APInt::getAllOnes(VWidth)), DemandedPassThrough(DemandedElts); if (auto *CMask = dyn_cast(II->getOperand(2))) { - if (CMask->isNullValue()) - DemandedPtrs.clearAllBits(); - else if (CMask->isAllOnesValue()) - DemandedPassThrough.clearAllBits(); - else if (auto *CV = dyn_cast(CMask)) { - for (unsigned i = 0; i < VWidth; i++) { - Constant *CElt = CV->getAggregateElement(i); + for (unsigned i = 0; i < VWidth; i++) { + if (Constant *CElt = CMask->getAggregateElement(i)) { if (CElt->isNullValue()) DemandedPtrs.clearBit(i); else if (CElt->isAllOnesValue())