@@ -288,6 +288,9 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
288
288
llvm::Value *DestPtr,
289
289
clang::QualType DestTy) override ;
290
290
void AddHLSLFunctionInfo (llvm::Function *, const FunctionDecl *FD) override ;
291
+ bool FindDispatchGridSemantic (const CXXRecordDecl *RD,
292
+ hlsl::SVDispatchGrid &SDGRec,
293
+ CharUnits Offset = CharUnits());
291
294
void AddHLSLNodeRecordTypeInfo (const clang::ParmVarDecl *parmDecl,
292
295
hlsl::NodeIOProperties &node);
293
296
void EmitHLSLFunctionProlog (llvm::Function *,
@@ -2560,6 +2563,66 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
2560
2563
m_ScopeMap[F] = ScopeInfo (F, FD->getLocation ());
2561
2564
}
2562
2565
2566
+ // Find the input node record field with the SV_DispatchGrid semantic.
2567
+ // We have already diagnosed any error conditions in Sema, so we
2568
+ // expect valid size and types, and use the first occurance found.
2569
+ // We return true if we have populated the SV_DispatchGrid values.
2570
+ bool CGMSHLSLRuntime::FindDispatchGridSemantic (const CXXRecordDecl *RD,
2571
+ hlsl::SVDispatchGrid &SDGRec,
2572
+ CharUnits Offset) {
2573
+ const ASTRecordLayout &Layout = CGM.getContext ().getASTRecordLayout (RD);
2574
+
2575
+ // Check (non-virtual) bases
2576
+ for (const CXXBaseSpecifier &Base : RD->bases ()) {
2577
+ DXASSERT (!Base.getType ()->isDependentType (),
2578
+ " Node Record with dependent base class not caught by Sema" );
2579
+ if (Base.getType ()->isDependentType ())
2580
+ continue ;
2581
+ CXXRecordDecl *BaseDecl = Base.getType ()->getAsCXXRecordDecl ();
2582
+ CharUnits BaseOffset = Offset + Layout.getBaseClassOffset (BaseDecl);
2583
+ if (FindDispatchGridSemantic (BaseDecl, SDGRec, BaseOffset))
2584
+ return true ;
2585
+ }
2586
+
2587
+ // Check each field in this record.
2588
+ for (FieldDecl *Field : RD->fields ()) {
2589
+ uint64_t FieldNo = Field->getFieldIndex ();
2590
+ CharUnits FieldOffset = Offset + CGM.getContext ().toCharUnitsFromBits (
2591
+ Layout.getFieldOffset (FieldNo));
2592
+
2593
+ // If this field is a record check its fields
2594
+ if (const CXXRecordDecl *D = Field->getType ()->getAsCXXRecordDecl ()) {
2595
+ if (FindDispatchGridSemantic (D, SDGRec, FieldOffset))
2596
+ return true ;
2597
+ }
2598
+ // Otherwise check this field for the SV_DispatchGrid semantic annotation
2599
+ for (const hlsl::UnusualAnnotation *UA : Field->getUnusualAnnotations ()) {
2600
+ if (UA->getKind () == hlsl::UnusualAnnotation::UA_SemanticDecl) {
2601
+ const hlsl::SemanticDecl *SD = cast<hlsl::SemanticDecl>(UA);
2602
+ if (SD->SemanticName .equals (" SV_DispatchGrid" )) {
2603
+ const llvm::Type *FTy = CGM.getTypes ().ConvertType (Field->getType ());
2604
+ const llvm::Type *ElTy = FTy;
2605
+ SDGRec.NumComponents = 1 ;
2606
+ SDGRec.ByteOffset = (unsigned )FieldOffset.getQuantity ();
2607
+ if (const llvm::VectorType *VT = dyn_cast<llvm::VectorType>(FTy)) {
2608
+ SDGRec.NumComponents = VT->getNumElements ();
2609
+ ElTy = VT->getElementType ();
2610
+ } else if (const llvm::ArrayType *AT =
2611
+ dyn_cast<llvm::ArrayType>(FTy)) {
2612
+ SDGRec.NumComponents = AT->getNumElements ();
2613
+ ElTy = AT->getElementType ();
2614
+ }
2615
+ SDGRec.ComponentType = (ElTy->getIntegerBitWidth () == 16 )
2616
+ ? DXIL::ComponentType::U16
2617
+ : DXIL::ComponentType::U32;
2618
+ return true ;
2619
+ }
2620
+ }
2621
+ }
2622
+ }
2623
+ return false ;
2624
+ }
2625
+
2563
2626
void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo (
2564
2627
const clang::ParmVarDecl *parmDecl, hlsl::NodeIOProperties &node) {
2565
2628
clang::QualType paramTy = parmDecl->getType ().getCanonicalType ();
@@ -2577,7 +2640,6 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
2577
2640
DiagnosticsEngine &Diags = CGM.getDiags ();
2578
2641
auto &Rec = TemplateArgs.get (0 );
2579
2642
clang::QualType RecType = Rec.getAsType ();
2580
- llvm::Type *Type = CGM.getTypes ().ConvertType (RecType);
2581
2643
CXXRecordDecl *RD = RecType->getAsCXXRecordDecl ();
2582
2644
2583
2645
// Get the TrackRWInputSharing flag from the record attribute
@@ -2597,63 +2659,12 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
2597
2659
2598
2660
// Ex: For DispatchNodeInputRecord<MY_RECORD>, set size =
2599
2661
// size(MY_RECORD), alignment = alignof(MY_RECORD)
2662
+ llvm::Type *Type = CGM.getTypes ().ConvertType (RecType);
2600
2663
node.RecordType .size = CGM.getDataLayout ().getTypeAllocSize (Type);
2601
2664
node.RecordType .alignment =
2602
2665
CGM.getDataLayout ().getABITypeAlignment (Type);
2603
- // Iterate over fields of the MY_RECORD(example) struct
2604
- for (auto fieldDecl : RD->fields ()) {
2605
- // Check if any of the fields have a semantic annotation =
2606
- // SV_DispatchGrid
2607
- for (const hlsl::UnusualAnnotation *it :
2608
- fieldDecl->getUnusualAnnotations ()) {
2609
- if (it->getKind () == hlsl::UnusualAnnotation::UA_SemanticDecl) {
2610
- const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
2611
- // if we find a field with SV_DispatchGrid, fill out the
2612
- // SV_DispatchGrid member with byteoffset of the field,
2613
- // NumComponents (3 for uint3 etc) and U32 vs U16 types, which are
2614
- // the only types allowed
2615
- if (sd->SemanticName .equals (" SV_DispatchGrid" )) {
2616
- clang::QualType FT = fieldDecl->getType ();
2617
- auto &DL = CGM.getDataLayout ();
2618
- auto &SDGRec = node.RecordType .SV_DispatchGrid ;
2619
-
2620
- DXASSERT_NOMSG (SDGRec.NumComponents == 0 );
2621
-
2622
- unsigned fieldIdx = fieldDecl->getFieldIndex ();
2623
- if (StructType *ST = dyn_cast<StructType>(Type)) {
2624
- SDGRec.ByteOffset =
2625
- DL.getStructLayout (ST)->getElementOffset (fieldIdx);
2626
- }
2627
- const llvm::Type *lTy = CGM.getTypes ().ConvertType (FT);
2628
- if (const llvm::VectorType *VT =
2629
- dyn_cast<llvm::VectorType>(lTy)) {
2630
- DXASSERT (VT->getElementType ()->isIntegerTy (), " invalid type" );
2631
- SDGRec.NumComponents = VT->getNumElements ();
2632
- SDGRec.ComponentType =
2633
- (VT->getElementType ()->getIntegerBitWidth () == 16 )
2634
- ? DXIL::ComponentType::U16
2635
- : DXIL::ComponentType::U32;
2636
- } else if (const llvm::ArrayType *AT =
2637
- dyn_cast<llvm::ArrayType>(lTy)) {
2638
- DXASSERT (AT->getElementType ()->isIntegerTy (), " invalid type" );
2639
- DXASSERT_NOMSG (AT->getNumElements () <= 3 );
2640
- SDGRec.NumComponents = AT->getNumElements ();
2641
- SDGRec.ComponentType =
2642
- (AT->getElementType ()->getIntegerBitWidth () == 16 )
2643
- ? DXIL::ComponentType::U16
2644
- : DXIL::ComponentType::U32;
2645
- } else {
2646
- // Scalar U16 or U32
2647
- DXASSERT (lTy->isIntegerTy (), " invalid type" );
2648
- SDGRec.NumComponents = 1 ;
2649
- SDGRec.ComponentType = (lTy->getIntegerBitWidth () == 16 )
2650
- ? DXIL::ComponentType::U16
2651
- : DXIL::ComponentType::U32;
2652
- }
2653
- }
2654
- }
2655
- }
2656
- }
2666
+
2667
+ FindDispatchGridSemantic (RD, node.RecordType .SV_DispatchGrid );
2657
2668
}
2658
2669
}
2659
2670
}
0 commit comments