Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions llvm/lib/Target/X86/X86TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2296,7 +2296,10 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
{ ISD::FP_EXTEND, MVT::v8f64, MVT::v8f32, { 1, 1, 1, 1 } },
{ ISD::FP_EXTEND, MVT::v8f64, MVT::v16f32, { 3, 1, 1, 1 } },
{ ISD::FP_EXTEND, MVT::v16f64, MVT::v16f32, { 4, 1, 1, 1 } }, // 2*vcvtps2pd+vextractf64x4
{ ISD::FP_EXTEND, MVT::v16f32, MVT::v16f16, { 1, 1, 1, 1 } }, // vcvtph2ps
{ ISD::FP_EXTEND, MVT::v8f64, MVT::v8f16, { 2, 1, 1, 1 } }, // vcvtph2ps+vcvtps2pd
{ ISD::FP_ROUND, MVT::v8f32, MVT::v8f64, { 1, 1, 1, 1 } },
{ ISD::FP_ROUND, MVT::v16f16, MVT::v16f32, { 1, 1, 1, 1 } }, // vcvtps2ph

{ ISD::TRUNCATE, MVT::v2i1, MVT::v2i8, { 3, 1, 1, 1 } }, // sext+vpslld+vptestmd
{ ISD::TRUNCATE, MVT::v4i1, MVT::v4i8, { 3, 1, 1, 1 } }, // sext+vpslld+vptestmd
Expand Down Expand Up @@ -2973,6 +2976,17 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
{ ISD::TRUNCATE, MVT::v4i32, MVT::v2i64, { 1, 1, 1, 1 } }, // PSHUFD
};

static const TypeConversionCostKindTblEntry F16ConversionTbl[] = {
{ ISD::FP_ROUND, MVT::f16, MVT::f32, { 1, 1, 1, 1 } },
{ ISD::FP_ROUND, MVT::v8f16, MVT::v8f32, { 1, 1, 1, 1 } },
{ ISD::FP_ROUND, MVT::v4f16, MVT::v4f32, { 1, 1, 1, 1 } },
{ ISD::FP_EXTEND, MVT::f32, MVT::f16, { 1, 1, 1, 1 } },
{ ISD::FP_EXTEND, MVT::f64, MVT::f16, { 2, 1, 1, 1 } }, // vcvtph2ps+vcvtps2pd
{ ISD::FP_EXTEND, MVT::v8f32, MVT::v8f16, { 1, 1, 1, 1 } },
{ ISD::FP_EXTEND, MVT::v4f32, MVT::v4f16, { 1, 1, 1, 1 } },
{ ISD::FP_EXTEND, MVT::v4f64, MVT::v4f16, { 2, 1, 1, 1 } }, // vcvtph2ps+vcvtps2pd
};

// Attempt to map directly to (simple) MVT types to let us match custom entries.
EVT SrcTy = TLI->getValueType(DL, Src);
EVT DstTy = TLI->getValueType(DL, Dst);
Expand Down Expand Up @@ -3034,6 +3048,13 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
return *KindCost;
}

if (ST->hasF16C()) {
if (const auto *Entry = ConvertCostTableLookup(F16ConversionTbl, ISD,
SimpleDstTy, SimpleSrcTy))
if (auto KindCost = Entry->Cost[CostKind])
return *KindCost;
}

if (ST->hasSSE41()) {
if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
SimpleDstTy, SimpleSrcTy))
Expand Down Expand Up @@ -3107,6 +3128,13 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
if (auto KindCost = Entry->Cost[CostKind])
return std::max(LTSrc.first, LTDest.first) * *KindCost;

if (ST->hasF16C()) {
if (const auto *Entry = ConvertCostTableLookup(F16ConversionTbl, ISD,
LTDest.second, LTSrc.second))
if (auto KindCost = Entry->Cost[CostKind])
return std::max(LTSrc.first, LTDest.first) * *KindCost;
}

if (ST->hasSSE41())
if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
LTDest.second, LTSrc.second))
Expand Down Expand Up @@ -3146,6 +3174,11 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
TTI::CastContextHint::None, CostKind);
}

if (ISD == ISD::FP_ROUND && LTDest.second.getScalarType() == MVT::f16) {
// Conversion requires a libcall.
return InstructionCost::getInvalid();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is breaking https://github.com/google/jax/blob/main/tests/lax_test.py#L3630 LazyConstantTest.testConvertElementTypeAvoidsCopies21 (dtype_in=<class 'numpy.float64'>, dtype_out=<class 'numpy.float16'>).

With

F1029 08:45:30.640847    4013 logging.cc:62] assert.h assertion failed at [third_party/llvm/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp:4569](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp?l=4569&ws=joelwee/4894&snapshot=26) in VectorizationFactor llvm::LoopVectorizationPlanner::selectVectorizationFactor(): ExpectedCost.isValid() && "Unexpected invalid cost for scalar loop"
*** Check failure stack trace: ***
    @     0x7ef66f09cf59  absl::log_internal::LogMessage::SendToLog()
    @     0x7ef66f09c4fe  absl::log_internal::LogMessage::Flush()
    @     0x7ef66f09d519  absl::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7ef67ade7314  __assert_fail
    @     0x7efa86da3f10  llvm::LoopVectorizationPlanner::selectVectorizationFactor()
    @     0x7efa86db95df  llvm::LoopVectorizationPlanner::computeBestVF()
    @     0x7efa86dcbdfd  llvm::LoopVectorizePass::processLoop()
    @     0x7efa86dd2c3d  llvm::LoopVectorizePass::runImpl()
    @     0x7efa86dd3875  llvm::LoopVectorizePass::run()
    @     0x7efa8ceb7332  llvm::detail::PassModel<>::run()
    @     0x7ef9520b9050  llvm::PassManager<>::run()
    @     0x7efaff179412  llvm::detail::PassModel<>::run()
    @     0x7ef9520be28a  llvm::ModuleToFunctionPassAdaptor::run()
    @     0x7efaff179192  llvm::detail::PassModel<>::run()
    @     0x7ef9520b7d7c  llvm::PassManager<>::run()
    @     0x7efaa12ec861  xla::cpu::CompilerFunctor::operator()()
    @     0x7efa913b0271  llvm::orc::ThreadSafeModule::withModuleDo<>()
    @     0x7efa913b000b  llvm::orc::IRCompileLayer::emit()
    @     0x7efa913e6d45  llvm::orc::BasicIRLayerMaterializationUnit::materialize()
    @     0x7efa91454337  llvm::orc::InPlaceTaskDispatcher::dispatch()
    @     0x7efa91349466  llvm::orc::ExecutionSession::dispatchOutstandingMUs()
    @     0x7efa9134e9e6  llvm::orc::ExecutionSession::OL_completeLookup()
    @     0x7efa91369a89  llvm::orc::InProgressFullLookupState::complete()
    @     0x7efa9133a0f0  llvm::orc::ExecutionSession::OL_applyQueryPhase1()
    @     0x7efa91337234  llvm::orc::ExecutionSession::lookup()
    @     0x7efa9134991e  llvm::orc::ExecutionSession::lookup()
    @     0x7efa91349de8  llvm::orc::ExecutionSession::lookup()
    @     0x7efa9134a30e  llvm::orc::ExecutionSession::lookup()
    @     0x7efa9134a459  llvm::orc::ExecutionSession::lookup()
    @     0x7efaa1719abf  xla::cpu::SimpleOrcJIT::FindCompiledSymbol()
    @     0x7efaddc247c0  absl::internal_any_invocable::RemoteInvoker<>()
    @     0x7efaddc0fb68  std::__u::__function::__policy_invoker<>::__call_impl<>()
    @     0x7ef89847e1b6  tsl::thread::EigenEnvironment::ExecuteTask()
    @     0x7ef89847dd10  Eigen::ThreadPoolTempl<>::WorkerLoop()
    @     0x7ef89847d940  std::__u::invoke<>()
    @     0x7ef6a5f9e25e  Thread::ThreadBody()
    @     0x7efafb6827db  start_thread
    @     0x7efabc18e05f  clone

I dumped the LLVM IR before:

; Function Attrs: nofree norecurse nosync nounwind memory(readwrite, inaccessiblemem: none) uwtable
define noalias noundef ptr @convert.2(ptr nocapture readonly %0) local_unnamed_addr #0 {
  %args_gep = getelementptr inbounds nuw i8, ptr %0, i64 24
  %args = load ptr, ptr %args_gep, align 8
  %arg0 = load ptr, ptr %args, align 8, !invariant.load !0, !dereferenceable !1, !align !2
  %arg1_gep = getelementptr i8, ptr %args, i64 16
  %arg1 = load ptr, ptr %arg1_gep, align 8, !invariant.load !0, !dereferenceable !3, !align !2
  br label %convert.2.loop_body.dim.0

convert.2.loop_body.dim.0:                        ; preds = %1, %convert.2.loop_body.dim.0
  %convert.2.invar_address.dim.0.03 = phi i64 [ 0, %1 ], [ %invar.inc, %convert.2.loop_body.dim.0 ]
  %2 = getelementptr inbounds [5 x double], ptr %arg0, i64 0, i64 %convert.2.invar_address.dim.0.03
  %3 = load double, ptr %2, align 8, !invariant.load !0, !noalias !4
  %4 = fptrunc double %3 to half
  %5 = getelementptr inbounds [5 x half], ptr %arg1, i64 0, i64 %convert.2.invar_address.dim.0.03
  store half %4, ptr %5, align 2, !alias.scope !4
  %invar.inc = add nuw nsw i64 %convert.2.invar_address.dim.0.03, 1
  %exitcond = icmp eq i64 %invar.inc, 5
  br i1 %exitcond, label %return, label %convert.2.loop_body.dim.0

return:                                           ; preds = %convert.2.loop_body.dim.0
  ret ptr null
}

Could we fix this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not able to reproduce so far. Unfortunately the dump does not contain some of the referenced metadata so I have to make guesses for that. Then trying to run opt -S -o - -passes=loop-vectorize /tmp/x.ll works just fine and I guess I need some target setup (I played with some -mtriple=x86_64 -mattr=+avx512f,+f16c but that doesn't repro either).

That said, could you try if replacing the InstructionCost::getInvalid(); with InstructionCost::getMax() or if that doesn't work with a big number like 128 helps?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope #114128 fixes this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it looks like it does. Thanks! (And apologies about the bad dump)

}

// TODO: Allow non-throughput costs that aren't binary.
auto AdjustCost = [&CostKind](InstructionCost Cost,
InstructionCost N = 1) -> InstructionCost {
Expand Down Expand Up @@ -6923,6 +6956,14 @@ bool X86TTIImpl::isVectorShiftByScalarCheap(Type *Ty) const {
return true;
}

unsigned X86TTIImpl::getStoreMinimumVF(unsigned VF, Type *ScalarMemTy,
Type *ScalarValTy) const {
if (ST->hasF16C() && ScalarMemTy->isHalfTy()) {
return 4;
}
return BaseT::getStoreMinimumVF(VF, ScalarMemTy, ScalarValTy);
}

bool X86TTIImpl::isProfitableToSinkOperands(Instruction *I,
SmallVectorImpl<Use *> &Ops) const {
using namespace llvm::PatternMatch;
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/X86/X86TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {

bool isVectorShiftByScalarCheap(Type *Ty) const;

unsigned getStoreMinimumVF(unsigned VF, Type *ScalarMemTy,
Type *ScalarValTy) const;

private:
bool supportsGather() const;
InstructionCost getGSVectorCost(unsigned Opcode, TTI::TargetCostKind CostKind,
Expand Down
Loading
Loading