@@ -566,17 +566,16 @@ static llvm::Value *createSPIRVBuiltinLoad(IRBuilder<> &B, llvm::Module &M,
566
566
return B.CreateLoad (Ty, GV);
567
567
}
568
568
569
- llvm::Value *
570
- CGHLSLRuntime::emitSystemSemanticLoad (IRBuilder<> &B, llvm::Type *Type,
571
- const clang::DeclaratorDecl *Decl,
572
- SemanticInfo &ActiveSemantic) {
573
- if (isa<HLSLSV_GroupIndexAttr>(ActiveSemantic.Semantic )) {
569
+ llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad (
570
+ IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
571
+ Attr *Semantic, std::optional<unsigned > Index) {
572
+ if (isa<HLSLSV_GroupIndexAttr>(Semantic)) {
574
573
llvm::Function *GroupIndex =
575
574
CGM.getIntrinsic (getFlattenedThreadIdInGroupIntrinsic ());
576
575
return B.CreateCall (FunctionCallee (GroupIndex));
577
576
}
578
577
579
- if (isa<HLSLSV_DispatchThreadIDAttr>(ActiveSemantic. Semantic )) {
578
+ if (isa<HLSLSV_DispatchThreadIDAttr>(Semantic)) {
580
579
llvm::Intrinsic::ID IntrinID = getThreadIdIntrinsic ();
581
580
llvm::Function *ThreadIDIntrinsic =
582
581
llvm::Intrinsic::isOverloaded (IntrinID)
@@ -585,7 +584,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
585
584
return buildVectorInput (B, ThreadIDIntrinsic, Type);
586
585
}
587
586
588
- if (isa<HLSLSV_GroupThreadIDAttr>(ActiveSemantic. Semantic )) {
587
+ if (isa<HLSLSV_GroupThreadIDAttr>(Semantic)) {
589
588
llvm::Intrinsic::ID IntrinID = getGroupThreadIdIntrinsic ();
590
589
llvm::Function *GroupThreadIDIntrinsic =
591
590
llvm::Intrinsic::isOverloaded (IntrinID)
@@ -594,7 +593,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
594
593
return buildVectorInput (B, GroupThreadIDIntrinsic, Type);
595
594
}
596
595
597
- if (isa<HLSLSV_GroupIDAttr>(ActiveSemantic. Semantic )) {
596
+ if (isa<HLSLSV_GroupIDAttr>(Semantic)) {
598
597
llvm::Intrinsic::ID IntrinID = getGroupIdIntrinsic ();
599
598
llvm::Function *GroupIDIntrinsic =
600
599
llvm::Intrinsic::isOverloaded (IntrinID)
@@ -603,8 +602,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
603
602
return buildVectorInput (B, GroupIDIntrinsic, Type);
604
603
}
605
604
606
- if (HLSLSV_PositionAttr *S =
607
- dyn_cast<HLSLSV_PositionAttr>(ActiveSemantic.Semantic )) {
605
+ if (HLSLSV_PositionAttr *S = dyn_cast<HLSLSV_PositionAttr>(Semantic)) {
608
606
if (CGM.getTriple ().getEnvironment () == Triple::EnvironmentType::Pixel)
609
607
return createSPIRVBuiltinLoad (B, CGM.getModule (), Type,
610
608
S->getAttrName ()->getName (),
@@ -616,54 +614,32 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
616
614
617
615
llvm::Value *
618
616
CGHLSLRuntime::handleScalarSemanticLoad (IRBuilder<> &B, llvm::Type *Type,
619
- const clang::DeclaratorDecl *Decl,
620
- SemanticInfo &ActiveSemantic) {
621
-
622
- if (!ActiveSemantic.Semantic ) {
623
- ActiveSemantic.Semantic = Decl->getAttr <HLSLSemanticAttr>();
624
- if (!ActiveSemantic.Semantic ) {
625
- CGM.getDiags ().Report (Decl->getInnerLocStart (),
626
- diag::err_hlsl_semantic_missing);
627
- return nullptr ;
628
- }
629
- ActiveSemantic.Index = ActiveSemantic.Semantic ->getSemanticIndex ();
630
- }
631
-
632
- return emitSystemSemanticLoad (B, Type, Decl, ActiveSemantic);
617
+ const clang::DeclaratorDecl *Decl) {
618
+ HLSLSemanticAttr *Semantic = Decl->getAttr <HLSLSemanticAttr>();
619
+ // Sema either attached a semantic to each field/param, or raised an error.
620
+ assert (Semantic);
621
+
622
+ std::optional<unsigned > Index = std::nullopt ;
623
+ if (Semantic->isSemanticIndexExplicit ())
624
+ Index = Semantic->getSemanticIndex ();
625
+ return emitSystemSemanticLoad (B, Type, Decl, Semantic, Index);
633
626
}
634
627
635
628
llvm::Value *
636
629
CGHLSLRuntime::handleStructSemanticLoad (IRBuilder<> &B, llvm::Type *Type,
637
- const clang::DeclaratorDecl *Decl,
638
- SemanticInfo &ActiveSemantic) {
630
+ const clang::DeclaratorDecl *Decl) {
639
631
const llvm::StructType *ST = cast<StructType>(Type);
640
632
const clang::RecordDecl *RD = Decl->getType ()->getAsRecordDecl ();
641
633
642
634
assert (std::distance (RD->field_begin (), RD->field_end ()) ==
643
635
ST->getNumElements ());
644
636
645
- if (!ActiveSemantic.Semantic ) {
646
- ActiveSemantic.Semantic = Decl->getAttr <HLSLSemanticAttr>();
647
- ActiveSemantic.Index = ActiveSemantic.Semantic
648
- ? ActiveSemantic.Semantic ->getSemanticIndex ()
649
- : 0 ;
650
- }
651
-
652
637
llvm::Value *Aggregate = llvm::PoisonValue::get (Type);
653
638
auto FieldDecl = RD->field_begin ();
654
639
for (unsigned I = 0 ; I < ST->getNumElements (); ++I) {
655
- SemanticInfo Info = ActiveSemantic;
656
640
llvm::Value *ChildValue =
657
- handleSemanticLoad (B, ST->getElementType (I), *FieldDecl, Info);
658
- if (!ChildValue) {
659
- CGM.getDiags ().Report (Decl->getInnerLocStart (),
660
- diag::note_hlsl_semantic_used_here)
661
- << Decl;
662
- return nullptr ;
663
- }
664
- if (ActiveSemantic.Semantic )
665
- ActiveSemantic = Info;
666
-
641
+ handleSemanticLoad (B, ST->getElementType (I), *FieldDecl);
642
+ assert (ChildValue);
667
643
Aggregate = B.CreateInsertValue (Aggregate, ChildValue, I);
668
644
++FieldDecl;
669
645
}
@@ -673,11 +649,10 @@ CGHLSLRuntime::handleStructSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
673
649
674
650
llvm::Value *
675
651
CGHLSLRuntime::handleSemanticLoad (IRBuilder<> &B, llvm::Type *Type,
676
- const clang::DeclaratorDecl *Decl,
677
- SemanticInfo &ActiveSemantic) {
652
+ const clang::DeclaratorDecl *Decl) {
678
653
if (Type->isStructTy ())
679
- return handleStructSemanticLoad (B, Type, Decl, ActiveSemantic );
680
- return handleScalarSemanticLoad (B, Type, Decl, ActiveSemantic );
654
+ return handleStructSemanticLoad (B, Type, Decl);
655
+ return handleScalarSemanticLoad (B, Type, Decl);
681
656
}
682
657
683
658
void CGHLSLRuntime::emitEntryFunction (const FunctionDecl *FD,
@@ -731,8 +706,7 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
731
706
} else {
732
707
llvm::Type *ParamType =
733
708
Param.hasByValAttr () ? Param.getParamByValType () : Param.getType ();
734
- SemanticInfo ActiveSemantic = {nullptr , 0 };
735
- SemanticValue = handleSemanticLoad (B, ParamType, PD, ActiveSemantic);
709
+ SemanticValue = handleSemanticLoad (B, ParamType, PD);
736
710
if (!SemanticValue)
737
711
return ;
738
712
if (Param.hasByValAttr ()) {
0 commit comments