@@ -568,51 +568,6 @@ Instruction *emitSpecConstantComposite(Type *Ty, ArrayRef<Value *> Elements,
568568 return emitCall (Ty, SPIRV_GET_SPEC_CONST_COMPOSITE, Elements, InsertBefore);
569569}
570570
571- // Select corresponding element of the default value. For a
572- // struct, we getting the corresponding default value is a little
573- // tricky. There are potentially distinct two types: the type of
574- // the default value, which comes from the initializer of the
575- // global spec constant value, and the return type of the call to
576- // getComposite2020SpecConstValue. The return type can be a
577- // version of the default value type, with padding fields
578- // potentially inserted at the top level and within nested
579- // structs.
580-
581- // Examples: (RT = Return Type, DVT = Default Value Type)
582- // RT: { i8, [3 x i8], i32 }, DVT = { i8, i32 }
583- // RT: { { i32, i8, [3 x i8] }, i32 } DVT = { { i32, i8 }, i32 }
584-
585- // For a given element of the default value type we are
586- // trying to initialize, we will initialize that element with
587- // the element of the default value type that has the same offset
588- // as the element we are trying to initialize. If no such element
589- // exists, we used undef as the initializer.
590- Constant *getElemDefaultValue (Type *Ty, Type *ElTy, Constant *DefaultValue,
591- size_t ElemIndex, const DataLayout &DL) {
592- if (auto *StructTy = dyn_cast<StructType>(Ty)) {
593- auto *DefaultValueType = cast<StructType>(DefaultValue->getType ());
594- const auto &DefaultValueTypeSL = DL.getStructLayout (DefaultValueType);
595- // The struct has padding, so we have to adjust ElemIndex
596- if (DefaultValueTypeSL->hasPadding ()) {
597- const auto &ReturnTypeSL = DL.getStructLayout (StructTy);
598- ArrayRef<TypeSize> DefaultValueOffsets =
599- DefaultValueTypeSL->getMemberOffsets ();
600- TypeSize CurrentIterationOffset =
601- ReturnTypeSL->getElementOffset (ElemIndex);
602- const auto It =
603- std::find (DefaultValueOffsets.begin (), DefaultValueOffsets.end (),
604- CurrentIterationOffset);
605-
606- // The element we are looking at is a padding field
607- if (It == DefaultValueOffsets.end ())
608- return UndefValue::get (ElTy);
609- // Select the index with the same offset
610- ElemIndex = It - DefaultValueOffsets.begin ();
611- }
612- }
613- return DefaultValue->getAggregateElement (ElemIndex);
614- }
615-
616571// / For specified specialization constant type emits LLVM IR which is required
617572// / in order to correctly handle it later during LLVM IR -> SPIR-V translation.
618573// /
@@ -636,19 +591,26 @@ Constant *getElemDefaultValue(Type *Ty, Type *ElTy, Constant *DefaultValue,
636591// / __spirvSpecConstantComposite calls for each composite member of the
637592// / composite (plus for the top-level composite). Also enumerates all
638593// / encountered scalars and assigns them IDs (or re-uses existing ones).
// NOTE: this span shows two versions of the function interleaved. Lines
// marked '-' belong to the old version, which threaded a per-element
// DefaultValue constant through the recursion (via getElemDefaultValue).
// Lines marked '+' belong to the new version, which instead carries a byte
// offset (CurrentOffset) and looks default values up in DefinedElements, a
// list of (offset, scalar constant) pairs.
//
// Index is taken by reference and advanced as IDs are consumed, so it is
// shared across the whole recursion.
//
// NOTE(review): CurrentOffset is declared `unsigned`, while the offsets
// stored in DefinedElements and compared by the lambdas below are uint64_t.
639- Instruction *emitSpecConstantRecursiveImpl (Type *Ty, Instruction *InsertBefore,
640- SmallVectorImpl<ID> &IDs,
641- unsigned &Index,
642- Constant *DefaultValue ) {
594+ Instruction *emitSpecConstantRecursiveImpl (
595+ Type *Ty, Instruction *InsertBefore, SmallVectorImpl<ID> &IDs,
596+ unsigned &Index, unsigned CurrentOffset ,
597+ const SmallVectorImpl<std::pair< uint64_t , Constant *>> &DefinedElements ) {
643598 const Module &M = *InsertBefore->getModule ();
644599 if (!Ty->isArrayTy () && !Ty->isStructTy () && !Ty->isVectorTy ()) { // Scalar
// A scalar must have a default value recorded at exactly CurrentOffset; the
// assert below enforces that. The comparator orders entries by offset only,
// so lower_bound presumably relies on DefinedElements being sorted by offset
// (TODO confirm against collectDefinedElements).
600+ auto It = llvm::lower_bound (DefinedElements, CurrentOffset,
601+ [](const std::pair<uint64_t , Constant *> &LHS,
602+ uint64_t RHS) { return LHS.first < RHS; });
603+ assert (It != DefinedElements.end () && It->first == CurrentOffset);
604+ Constant *DefaultValue = It->second ;
605+
645606 if (Index >= IDs.size ()) {
646607 // If it is a new specialization constant, we need to generate IDs for
647608 // scalar elements, starting with the second one.
648609 assert (!isa<UndefValue>(DefaultValue) &&
649610 " All scalar values should be defined" );
// The new ID is the previous last ID plus one.
650611 IDs.push_back ({IDs.back ().ID + 1 , false });
651612 }
613+
// Consume the ID at Index (post-incremented) for this scalar.
652614 return emitSpecConstant (IDs[Index++].ID , Ty, InsertBefore, DefaultValue);
653615 }
654616
@@ -662,44 +624,89 @@ Instruction *emitSpecConstantRecursiveImpl(Type *Ty, Instruction *InsertBefore,
// (The lines between the scalar branch above and this point are elided from
// this excerpt; what follows is the tail of a lambda defined there.)
662624 Elements.push_back (Def);
663625 Index++;
664626 };
665- auto LoopIteration = [&](Type *ElTy, unsigned LocalIndex) {
666- const auto ElemDefaultValue = getElemDefaultValue (
667- Ty, ElTy, DefaultValue, LocalIndex, M.getDataLayout ());
668-
// New version: LocalOffset is the element's byte offset within Ty, and
// ElOffset adds the offset accumulated for Ty itself (the wrapper starts the
// recursion at offset 0).
// NOTE(review): LocalOffset is `unsigned`, but the callers below pass
// `I * ElSize` (uint64_t) and struct member offsets — confirm the narrowing
// is acceptable.
627+ auto LoopIteration = [&](Type *ElTy, unsigned LocalOffset) {
628+ auto ElOffset = CurrentOffset + LocalOffset;
629+ auto It = llvm::lower_bound (DefinedElements, ElOffset,
630+ [](const std::pair<uint64_t , Constant *> &LHS,
631+ uint64_t RHS) { return LHS.first < RHS; });
669632 // If the default value is a composite and has the value 'undef', we should
670633 // not generate a bunch of __spirv_SpecConstant for its elements but
671634 // pass it into __spirv_SpecConstantComposite as is.
// In the new version "undef" means: no defined scalar is recorded at exactly
// ElOffset.
// NOTE(review): for an aggregate element this only looks at the aggregate's
// starting offset; if its first scalar were undef but later ones defined, the
// whole aggregate would appear to take the undef path — confirm intended.
672- if (isa<UndefValue>(ElemDefaultValue) )
673- HandleUndef (ElemDefaultValue );
635+ if (It == DefinedElements. end () || It-> first != ElOffset )
636+ HandleUndef (UndefValue::get (ElTy) );
674637 else
675638 Elements.push_back (emitSpecConstantRecursiveImpl (
676- ElTy, InsertBefore, IDs, Index, ElemDefaultValue ));
639+ ElTy, InsertBefore, IDs, Index, ElOffset, DefinedElements ));
677640 };
678641
// NOTE(review): `auto` deduces a by-value DataLayout here; if getDataLayout()
// returns a reference, `const auto &` would avoid a copy. The struct branch
// below calls M.getDataLayout() again instead of using DL.
642+ auto DL = M.getDataLayout ();
// Array and fixed-vector elements are visited at index * alloc size of the
// element type; struct members at the offsets reported by the StructLayout.
679643 if (auto *ArrTy = dyn_cast<ArrayType>(Ty)) {
644+ uint64_t ElSize = DL.getTypeAllocSize (ArrTy->getElementType ());
680645 for (size_t I = 0 ; I < ArrTy->getNumElements (); ++I)
681- LoopIteration (ArrTy->getElementType (), I);
646+ LoopIteration (ArrTy->getElementType (), I * ElSize );
682647 } else if (auto *StructTy = dyn_cast<StructType>(Ty)) {
683- size_t I = 0 ;
684- for (Type *ElTy : StructTy->elements ())
685- LoopIteration (ElTy, I++);
648+ const StructLayout *SL = M.getDataLayout ().getStructLayout (StructTy);
649+ for (auto [ElTy, Offset] :
650+ zip_equal (StructTy->elements (), SL->getMemberOffsets ()))
651+ LoopIteration (ElTy, Offset);
686652 } else if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
653+ uint64_t ElSize = DL.getTypeAllocSize (VecTy->getElementType ());
687654 for (size_t I = 0 ; I < VecTy->getNumElements (); ++I)
688- LoopIteration (VecTy->getElementType (), I);
655+ LoopIteration (VecTy->getElementType (), I * ElSize );
689656 } else {
690657 llvm_unreachable (" Unexpected spec constant type" );
691658 }
692659
// Combine the per-element results into one composite of type Ty.
693660 return emitSpecConstantComposite (Ty, Elements, InsertBefore);
694661}
695662
663+ // / Recursively iterates over a composite type in order to collect information
664+ // / about the offsets of its scalar elements.
665+ void collectDefinedElements (
666+ Constant *C, const DataLayout &DL,
667+ SmallVectorImpl<std::pair<uint64_t , Constant *>> &Result,
668+ uint64_t CurrentOffset) {
669+ if (isa<UndefValue>(C)) {
670+ return ;
671+ }
672+
673+ if (auto *StructTy = dyn_cast<StructType>(C->getType ())) {
674+ const StructLayout *SL = DL.getStructLayout (StructTy);
675+ for (auto [I, MemberOffset] : enumerate(SL->getMemberOffsets ()))
676+ collectDefinedElements (C->getAggregateElement (I), DL, Result,
677+ CurrentOffset + MemberOffset);
678+ }
679+
680+ else if (auto *ArrTy = dyn_cast<ArrayType>(C->getType ())) {
681+ uint64_t ElSize = DL.getTypeAllocSize (ArrTy->getElementType ());
682+ for (size_t I = 0 ; I < ArrTy->getNumElements (); ++I)
683+ collectDefinedElements (C->getAggregateElement (I), DL, Result,
684+ CurrentOffset + I * ElSize);
685+ }
686+
687+ else if (auto *VecTy = dyn_cast<FixedVectorType>(C->getType ())) {
688+ uint64_t ElSize = DL.getTypeAllocSize (VecTy->getElementType ());
689+ for (size_t I = 0 ; I < VecTy->getNumElements (); ++I)
690+ collectDefinedElements (C->getAggregateElement (I), DL, Result,
691+ CurrentOffset + I * ElSize);
692+ }
693+
694+ else {
695+ Result.push_back ({CurrentOffset, C});
696+ }
697+ }
698+
696699/// Wrapper that hides the recursion bookkeeping of
/// emitSpecConstantRecursiveImpl from the caller: the running element Index
/// (started at 0) and, in the new ('+') version, the starting offset (0) and
/// the list of defined scalar elements collected from DefaultValue.
/// Returns whatever emitSpecConstantRecursiveImpl returns for Ty.
697700Instruction *emitSpecConstantRecursive (Type *Ty, Instruction *InsertBefore,
698701 SmallVectorImpl<ID> &IDs,
699702 Constant *DefaultValue) {
700703 unsigned Index = 0 ;
// Old ('-') version: DefaultValue was handed to the Impl directly.
701- return emitSpecConstantRecursiveImpl (Ty, InsertBefore, IDs, Index,
702- DefaultValue);
// New ('+') version: first flatten DefaultValue into (offset, scalar) pairs,
// starting at offset 0 and using the module's data layout, then let the Impl
// look defaults up by offset.
// NOTE(review): the Impl searches this list with lower_bound, so it
// presumably must be in ascending offset order — confirm that
// collectDefinedElements guarantees this.
704+ SmallVector<std::pair<uint64_t , Constant *>, 32 > DefinedElements;
705+ collectDefinedElements (DefaultValue,
706+ InsertBefore->getModule ()->getDataLayout (),
707+ DefinedElements, 0 );
708+ return emitSpecConstantRecursiveImpl (Ty, InsertBefore, IDs, Index, 0 ,
709+ DefinedElements);
703710}
704711
705712// / Function creates load instruction from the given Buffer by the given Offset.
0 commit comments