diff --git a/llvm/include/llvm/SandboxIR/Constant.h b/llvm/include/llvm/SandboxIR/Constant.h index fa00b29dbd803..2012cf8a8ed3e 100644 --- a/llvm/include/llvm/SandboxIR/Constant.h +++ b/llvm/include/llvm/SandboxIR/Constant.h @@ -670,7 +670,96 @@ class ConstantDataVector final : public ConstantDataSequential { friend class Context; public: - // TODO: Add missing functions. + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static bool classof(const Value *From) { + return From->getSubclassID() == ClassID::ConstantDataVector; + } + /// get() constructors - Return a constant with vector type with an element + /// count and element type matching the ArrayRef passed in. Note that this + /// can return a ConstantAggregateZero object. + static Constant *get(Context &Ctx, ArrayRef Elts) { + auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts); + return Ctx.getOrCreateConstant(NewLLVMC); + } + static Constant *get(Context &Ctx, ArrayRef Elts) { + auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts); + return Ctx.getOrCreateConstant(NewLLVMC); + } + static Constant *get(Context &Ctx, ArrayRef Elts) { + auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts); + return Ctx.getOrCreateConstant(NewLLVMC); + } + static Constant *get(Context &Ctx, ArrayRef Elts) { + auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts); + return Ctx.getOrCreateConstant(NewLLVMC); + } + static Constant *get(Context &Ctx, ArrayRef Elts) { + auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts); + return Ctx.getOrCreateConstant(NewLLVMC); + } + static Constant *get(Context &Ctx, ArrayRef Elts) { + auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts); + return Ctx.getOrCreateConstant(NewLLVMC); + } + + /// getRaw() constructor - Return a constant with vector type with an element + /// count and element type matching the NumElements and ElementTy parameters + /// passed in. Note that this can return a ConstantAggregateZero object. + /// ElementTy must be one of i8/i16/i32/i64/half/bfloat/float/double. Data is + /// the buffer containing the elements. Be careful to make sure Data uses the + /// right endianness, the buffer will be used as-is. + static Constant *getRaw(StringRef Data, uint64_t NumElements, + Type *ElementTy) { + auto *NewLLVMC = + llvm::ConstantDataVector::getRaw(Data, NumElements, ElementTy->LLVMTy); + return ElementTy->getContext().getOrCreateConstant(NewLLVMC); + } + /// getFP() constructors - Return a constant of vector type with a float + /// element type taken from argument `ElementType', and count taken from + /// argument `Elts'. The amount of bits of the contained type must match the + /// number of bits of the type contained in the passed in ArrayRef. + /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note + /// that this can return a ConstantAggregateZero object. + static Constant *getFP(Type *ElementType, ArrayRef Elts) { + auto *NewLLVMC = llvm::ConstantDataVector::getFP(ElementType->LLVMTy, Elts); + return ElementType->getContext().getOrCreateConstant(NewLLVMC); + } + static Constant *getFP(Type *ElementType, ArrayRef Elts) { + auto *NewLLVMC = llvm::ConstantDataVector::getFP(ElementType->LLVMTy, Elts); + return ElementType->getContext().getOrCreateConstant(NewLLVMC); + } + static Constant *getFP(Type *ElementType, ArrayRef Elts) { + auto *NewLLVMC = llvm::ConstantDataVector::getFP(ElementType->LLVMTy, Elts); + return ElementType->getContext().getOrCreateConstant(NewLLVMC); + } + + /// Return a ConstantVector with the specified constant in each element. + /// The specified constant has to be a of a compatible type (i8/i16/ + /// i32/i64/half/bfloat/float/double) and must be a ConstantFP or ConstantInt. + static Constant *getSplat(unsigned NumElts, Constant *Elt) { + auto *NewLLVMC = llvm::ConstantDataVector::getSplat( + NumElts, cast(Elt->Val)); + return Elt->getContext().getOrCreateConstant(NewLLVMC); + } + + /// Returns true if this is a splat constant, meaning that all elements have + /// the same value. + bool isSplat() const { + return cast(Val)->isSplat(); + } + + /// If this is a splat constant, meaning that all of the elements have the + /// same value, return that value. Otherwise return NULL. + Constant *getSplatValue() const { + return Ctx.getOrCreateConstant( + cast(Val)->getSplatValue()); + } + + /// Specialize the getType() method to always return a FixedVectorType, + /// which reduces the amount of casting needed in parts of the compiler. + inline FixedVectorType *getType() const { + return cast(Value::getType()); + } }; // TODO: Inherit from ConstantData. diff --git a/llvm/include/llvm/SandboxIR/Value.h b/llvm/include/llvm/SandboxIR/Value.h index d45aa4059de69..dbd0208b4f3f3 100644 --- a/llvm/include/llvm/SandboxIR/Value.h +++ b/llvm/include/llvm/SandboxIR/Value.h @@ -171,6 +171,7 @@ class Value { friend class Region; friend class ScoreBoard; // Needs access to `Val` for the instruction cost. friend class ConstantDataArray; // For `Val` + friend class ConstantDataVector; // For `Val` /// All values point to the context. Context &Ctx; diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 8cce659596a4d..18882add59941 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -622,6 +622,7 @@ define void @foo() { %fvector = extractelement <2 x double> , i32 0 %string = extractvalue [6 x i8] [i8 72, i8 69, i8 76, i8 76, i8 79, i8 0], 0 %stringNoNull = extractvalue [5 x i8] [i8 72, i8 69, i8 76, i8 76, i8 79], 0 + %splat = extractelement <4 x i8> , i32 0 ret void } )IR"); @@ -637,6 +638,7 @@ define void @foo() { auto *I3 = &*It++; auto *I4 = &*It++; auto *I5 = &*It++; + auto *I6 = &*It++; auto *Array = cast(I0->getOperand(0)); EXPECT_TRUE(isa(Array)); auto *Vector = cast(I1->getOperand(0)); @@ -649,6 +651,8 @@ define void @foo() { EXPECT_TRUE(isa(String)); auto *StringNoNull = cast(I5->getOperand(0)); EXPECT_TRUE(isa(StringNoNull)); + auto *Splat = cast(I6->getOperand(0)); + EXPECT_TRUE(isa(Splat)); auto *Zero8 = sandboxir::ConstantInt::get(sandboxir::Type::getInt8Ty(Ctx), 0); auto *One8 = sandboxir::ConstantInt::get(sandboxir::Type::getInt8Ty(Ctx), 1); @@ -750,9 +754,74 @@ define void @foo() { llvm::Type::getDoubleTy(C), Elts64)))); // Check getString(). EXPECT_EQ(sandboxir::ConstantDataArray::getString(Ctx, "HELLO"), String); + EXPECT_EQ(sandboxir::ConstantDataArray::getString(Ctx, "HELLO", /*AddNull=*/false), StringNoNull); + EXPECT_EQ( + sandboxir::ConstantDataArray::getString(Ctx, "HELLO", /*AddNull=*/false), + StringNoNull); + + { + // Check ConstantDataArray member functions + // ---------------------------------------- + // Check get(). + SmallVector Elts8({0u, 1u}); + SmallVector Elts16({0u, 1u}); + SmallVector Elts32({0u, 1u}); + SmallVector Elts64({0u, 1u}); + SmallVector EltsF32({0.0, 1.0}); + SmallVector EltsF64({0.0, 1.0}); + auto *CDV8 = sandboxir::ConstantDataVector::get(Ctx, Elts8); + EXPECT_EQ(CDV8, cast( + Ctx.getValue(llvm::ConstantDataVector::get(C, Elts8)))); + auto *CDV16 = sandboxir::ConstantDataVector::get(Ctx, Elts16); + EXPECT_EQ(CDV16, cast(Ctx.getValue( + llvm::ConstantDataVector::get(C, Elts16)))); + auto *CDV32 = sandboxir::ConstantDataVector::get(Ctx, Elts32); + EXPECT_EQ(CDV32, cast(Ctx.getValue( + llvm::ConstantDataVector::get(C, Elts32)))); + auto *CDVF32 = sandboxir::ConstantDataVector::get(Ctx, EltsF32); + EXPECT_EQ(CDVF32, cast(Ctx.getValue( + llvm::ConstantDataVector::get(C, EltsF32)))); + auto *CDVF64 = sandboxir::ConstantDataVector::get(Ctx, EltsF64); + EXPECT_EQ(CDVF64, cast(Ctx.getValue( + llvm::ConstantDataVector::get(C, EltsF64)))); + // Check getRaw(). + auto *CDVRaw = sandboxir::ConstantDataVector::getRaw( + StringRef("HELLO"), 5, sandboxir::Type::getInt8Ty(Ctx)); + EXPECT_EQ(CDVRaw, + cast( + Ctx.getValue(llvm::ConstantDataVector::getRaw( + StringRef("HELLO"), 5, llvm::Type::getInt8Ty(C))))); + // Check getFP(). + auto *CDVFP16 = sandboxir::ConstantDataVector::getFP(F16Ty, Elts16); + EXPECT_EQ(CDVFP16, cast( + Ctx.getValue(llvm::ConstantDataVector::getFP( + llvm::Type::getHalfTy(C), Elts16)))); + auto *CDVFP32 = sandboxir::ConstantDataVector::getFP(F32Ty, Elts32); + EXPECT_EQ(CDVFP32, cast( + Ctx.getValue(llvm::ConstantDataVector::getFP( + llvm::Type::getFloatTy(C), Elts32)))); + auto *CDVFP64 = sandboxir::ConstantDataVector::getFP(F64Ty, Elts64); + EXPECT_EQ(CDVFP64, cast( + Ctx.getValue(llvm::ConstantDataVector::getFP( + llvm::Type::getDoubleTy(C), Elts64)))); + // Check getSplat(). + auto *NewSplat = cast( + sandboxir::ConstantDataVector::getSplat(4, One8)); + EXPECT_EQ(NewSplat, Splat); + // Check isSplat(). + EXPECT_TRUE(NewSplat->isSplat()); + EXPECT_FALSE(Vector->isSplat()); + // Check getSplatValue(). + EXPECT_EQ(NewSplat->getSplatValue(), One8); + // Check getType(). + EXPECT_TRUE(isa(NewSplat->getType())); + EXPECT_EQ( + cast(NewSplat->getType())->getNumElements(), + 4u); + } } TEST_F(SandboxIRTest, ConstantPointerNull) {