From 2652be64c59511c18b245eae9ab7b60ddadbbb47 Mon Sep 17 00:00:00 2001 From: Greg Roth Date: Thu, 13 Feb 2025 09:54:28 -0700 Subject: [PATCH 01/31] NFC: Infrastructure changes for DXIL op vector and multi-dim-overloads This change adds vector and multi-dimensional overload support for DXIL operations. Incorporates change to add vector overloads from @pow2clk. This includes changes to hctdb*.py, DxilOperations.*, and DxilValidation.cpp. hctdb.py: Updates vector character from 't' to '<'. Allows legal vector element overloads to be specified after vector overload character. Defaults legal vector element overloads to the scalar element overloads, if any. Adds 'x' for extended overload mechanism, which supports up to 2 overload dimensions at this point, but is easily expandable if necessary. Processes syntax using ',' to separate multiple overloads and defaulting vector element overloads into new breakdown of main overload string, vector overload list of strings, and list of extended overloads (used if main overload set to 'x'). DxilOperations: Extend OpCodeProperty with ExtendedOverloads and AllowedVectorElements arrays using new OverloadMask. When extended bit set in main overloads, ExtendedOverloads array is used for each overload dimension. When vector bit set in main overloads or each ExtendedOverloads dimension, AllowedVectorElements are set for corresponding dimension index for allowed element types. Updated generated DXIL op table. Remove unused static methods in hlsl::OP. I think these were leftover from an attempt to work around a name collision ultimately caused by dx.types.CBufRet.f16|i16 between min-precision (with 4 elements) and native low precision (with 8 elements), caused by failing to initialize the min precision mode correctly for linking. That issue was fixed, and the names made unique by adding ".8" to the end of the 8-element native low precision cbuffer return type. 
eliminate use of std::string, std::vector - Use Twine, raw_svector_ostream, and SmallVector storage to replace uses of std::string - Use SmallVector instead of std::vector for ArgTypes in GetOpFunc Rework DXIL op overload system Add comments explaining the new system. Eliminate bool array in favor of array of masks for up to N dimensions. Add NumOverloadDims instead of two-mode system. Rework TypeSlots: - use enum, categorize basic, limit masks to used bits - void doesn't need a type slot (NumOverloadDims == 0 instead) - m_OverloadTypeName only contains basic type names Handle multi-overload in FixOverloadNames; new MayHaveNonCanonicalOverload is used to determine whether the overload name could need fixing. Extended overload is still a distinction because of the way the overloads must be wrapped in an unnamed StructType. However, it does not need a bit in the overload mask. Renamed GetVectorType to GetStructVectorType, since it's just used to get a struct for a particular vector type, not a vector type itself. In hctdb.py, no longer separate extended and vector overloads, just verify correctness of the incoming string, and add default vector overloads if necessary. In hctdb_instrhelp.py, update according to changes in hctdb.py, and eliminate needless, problematic, outdated comment printing. 
--- include/dxc/DXIL/DxilConstants.h | 5 + include/dxc/DXIL/DxilOperations.h | 103 +- lib/DXIL/DxilOperations.cpp | 5882 ++++++++++++------------- lib/DxilValidation/DxilValidation.cpp | 17 +- utils/hct/hctdb.py | 170 +- utils/hct/hctdb_instrhelp.py | 139 +- 6 files changed, 3158 insertions(+), 3158 deletions(-) diff --git a/include/dxc/DXIL/DxilConstants.h b/include/dxc/DXIL/DxilConstants.h index 0a9c6a4ffd..447728300b 100644 --- a/include/dxc/DXIL/DxilConstants.h +++ b/include/dxc/DXIL/DxilConstants.h @@ -155,6 +155,11 @@ const float kMinMipLodBias = -16.0f; const unsigned kResRetStatusIndex = 4; +/* hctdb_instrhelp.get_max_oload_dims()*/ +// OLOAD_DIMS-TEXT:BEGIN +const unsigned kDxilMaxOloadDims = 2; +// OLOAD_DIMS-TEXT:END + enum class ComponentType : uint32_t { Invalid = 0, I1, diff --git a/include/dxc/DXIL/DxilOperations.h b/include/dxc/DXIL/DxilOperations.h index e522e06204..0bd855ae58 100644 --- a/include/dxc/DXIL/DxilOperations.h +++ b/include/dxc/DXIL/DxilOperations.h @@ -57,12 +57,31 @@ class OP { // caches. void RefreshCache(); + // The single llvm::Type * "OverloadType" has one of these forms: + // No overloads (NumOverloadDims == 0): + // - TS_Void: VoidTy + // For single overload dimension (NumOverloadDims == 1): + // - TS_F*, TS_I*: a scalar numeric type (half, float, i1, i64, etc.), + // - TS_UDT: a pointer to a StructType representing a User Defined Type, + // - TS_Object: a named StructType representing a built-in object, or + // - TS_Vector: a vector type (<4 x float>, <16 x i16>, etc.) + // For multiple overload dimensions (TS_Extended, NumOverloadDims > 1): + // - an unnamed StructType containing each type for the corresponding + // dimension, such as: type { i32, <2 x float> } + // - contained type options are the same as for single dimension. 
+ llvm::Function *GetOpFunc(OpCode OpCode, llvm::Type *pOverloadType); + + // N-dimension convenience version of GetOpFunc: + llvm::Function *GetOpFunc(OpCode OpCode, + llvm::ArrayRef OverloadTypes); + const llvm::SmallMapVector & GetOpFuncList(OpCode OpCode) const; bool IsDxilOpUsed(OpCode opcode) const; void RemoveFunction(llvm::Function *F); llvm::LLVMContext &GetCtx() { return m_Ctx; } + llvm::Module *GetModule() { return m_pModule; } llvm::Type *GetHandleType() const; llvm::Type *GetHitObjectType() const; llvm::Type *GetNodeHandleType() const; @@ -81,9 +100,14 @@ class OP { llvm::Type *GetResRetType(llvm::Type *pOverloadType); llvm::Type *GetCBufferRetType(llvm::Type *pOverloadType); - llvm::Type *GetVectorType(unsigned numElements, llvm::Type *pOverloadType); + llvm::Type *GetStructVectorType(unsigned numElements, + llvm::Type *pOverloadType); bool IsResRetType(llvm::Type *Ty); + // Construct an unnamed struct type containing the set of member types. + llvm::StructType * + GetExtendedOverloadType(llvm::ArrayRef OverloadTypes); + // Try to get the opcode class for a function. // Return true and set `opClass` if the given function is a dxil function. // Return false if the given function is not a dxil function. 
@@ -128,11 +152,6 @@ class OP { static bool BarrierRequiresGroup(const llvm::CallInst *CI); static bool BarrierRequiresNode(const llvm::CallInst *CI); static DXIL::BarrierMode TranslateToBarrierMode(const llvm::CallInst *CI); - static bool IsDxilOpTypeName(llvm::StringRef name); - static bool IsDxilOpType(llvm::StructType *ST); - static bool IsDupDxilOpType(llvm::StructType *ST); - static llvm::StructType *GetOriginalDxilOpType(llvm::StructType *ST, - llvm::Module &M); static void GetMinShaderModelAndMask(OpCode C, bool bWithTranslation, unsigned &major, unsigned &minor, unsigned &mask); @@ -141,6 +160,13 @@ class OP { unsigned valMinor, unsigned &major, unsigned &minor, unsigned &mask); + static bool IsDxilOpExtendedOverload(OpCode C); + + // Return true if the overload name for this operation may be constructed + // based on a type name that may not represent the same type in different + // modules. + static bool MayHaveNonCanonicalOverload(OpCode OC); + private: // Per-module properties. 
llvm::LLVMContext &m_Ctx; @@ -164,13 +190,33 @@ class OP { DXIL::LowPrecisionMode m_LowPrecisionMode; - static const unsigned kUserDefineTypeSlot = 9; - static const unsigned kObjectTypeSlot = 10; - static const unsigned kNumTypeOverloads = - 11; // void, h,f,d, i1, i8,i16,i32,i64, udt, obj + // Overload types are split into "basic" overload types and special types + // Basic: void, half, float, double, i1, i8, i16, i32, i64 + // - These have one canonical overload per TypeSlot + // Special: udt, obj, vec, extended + // - These may have many overloads per type slot + enum TypeSlot : unsigned { + TS_F16 = 0, + TS_F32 = 1, + TS_F64 = 2, + TS_I1 = 3, + TS_I8 = 4, + TS_I16 = 5, + TS_I32 = 6, + TS_I64 = 7, + TS_BasicCount, + TS_UDT = 8, // Ex: %"struct.MyStruct" * + TS_Object = 9, // Ex: %"class.StructuredBuffer" + TS_Vector = 10, // Ex: <8 x i16> + TS_MaskBitCount, // Types used in Mask end here + // TS_Extended is only used to identify the unnamed struct type used to wrap + // multiple overloads when using GetTypeSlot. + TS_Extended, // Ex: type { float, <16 x i32> } + TS_Invalid = UINT_MAX, + }; - llvm::Type *m_pResRetType[kNumTypeOverloads]; - llvm::Type *m_pCBufferRetType[kNumTypeOverloads]; + llvm::Type *m_pResRetType[TS_BasicCount]; + llvm::Type *m_pCBufferRetType[TS_BasicCount]; struct OpCodeCacheItem { llvm::SmallMapVector pOverloads; @@ -181,27 +227,46 @@ class OP { private: // Static properties. + struct OverloadMask { + // mask of type slot bits as (1 << TypeSlot) + uint16_t SlotMask; + static_assert(TS_MaskBitCount <= (sizeof(SlotMask) * 8)); + bool operator[](unsigned TypeSlot) const { + return (TypeSlot < TS_MaskBitCount) ? 
+ (bool)(SlotMask & (1 << TypeSlot)) + : 0; + } + operator bool() const { return SlotMask != 0; } + }; struct OpCodeProperty { OpCode opCode; const char *pOpCodeName; OpCodeClass opCodeClass; const char *pOpCodeClassName; - bool bAllowOverload[kNumTypeOverloads]; // void, h,f,d, i1, i8,i16,i32,i64, - // udt llvm::Attribute::AttrKind FuncAttr; + + // Number of overload dimensions used by the operation. + unsigned int NumOverloadDims; + + // Mask of supported overload types for each overload dimension. + OverloadMask AllowedOverloads[DXIL::kDxilMaxOloadDims]; + + // Mask of scalar components allowed for each dimension where + // AllowedOverloads[n][TS_Vector] is true. + OverloadMask AllowedVectorElements[DXIL::kDxilMaxOloadDims]; }; static const OpCodeProperty m_OpCodeProps[(unsigned)OpCode::NumOpCodes]; - static const char *m_OverloadTypeName[kNumTypeOverloads]; + static const char *m_OverloadTypeName[TS_BasicCount]; static const char *m_NamePrefix; static const char *m_TypePrefix; static const char *m_MatrixTypePrefix; static unsigned GetTypeSlot(llvm::Type *pType); static const char *GetOverloadTypeName(unsigned TypeSlot); - static llvm::StringRef GetTypeName(llvm::Type *Ty, std::string &str); - static llvm::StringRef ConstructOverloadName(llvm::Type *Ty, - DXIL::OpCode opCode, - std::string &funcNameStorage); + static llvm::StringRef GetTypeName(llvm::Type *Ty, + llvm::SmallVectorImpl &Storage); + static llvm::StringRef + ConstructOverloadName(llvm::Type *Ty, DXIL::OpCode opCode, + llvm::SmallVectorImpl &Storage); }; } // namespace hlsl diff --git a/lib/DXIL/DxilOperations.cpp b/lib/DXIL/DxilOperations.cpp index 86049fee9c..a2b0432ce3 100644 --- a/lib/DXIL/DxilOperations.cpp +++ b/lib/DXIL/DxilOperations.cpp @@ -23,8 +23,6 @@ #include "llvm/Support/raw_ostream.h" using namespace llvm; -using std::string; -using std::vector; namespace hlsl { @@ -41,2989 +39,2605 @@ import hctdb_instrhelp /* hctdb_instrhelp.get_oloads_props()*/ // OPCODE-OLOADS:BEGIN const 
OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = { - // OpCode OpCode name, OpCodeClass - // OpCodeClass name, void, h, f, d, i1, i8, - // i16, i32, i64, udt, obj, function attribute - // Temporary, indexable, input, output registers void, h, f, d, - // i1, i8, i16, i32, i64, udt, obj , function attribute - { - OC::TempRegLoad, - "TempRegLoad", - OCC::TempRegLoad, - "tempRegLoad", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::TempRegStore, - "TempRegStore", - OCC::TempRegStore, - "tempRegStore", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::None, - }, - { - OC::MinPrecXRegLoad, - "MinPrecXRegLoad", - OCC::MinPrecXRegLoad, - "minPrecXRegLoad", - {false, true, false, false, false, false, true, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::MinPrecXRegStore, - "MinPrecXRegStore", - OCC::MinPrecXRegStore, - "minPrecXRegStore", - {false, true, false, false, false, false, true, false, false, false, - false}, - Attribute::None, - }, - { - OC::LoadInput, - "LoadInput", - OCC::LoadInput, - "loadInput", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::StoreOutput, - "StoreOutput", - OCC::StoreOutput, - "storeOutput", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::None, - }, - - // Unary float void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::FAbs, - "FAbs", - OCC::Unary, - "unary", - {false, true, true, true, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Saturate, - "Saturate", - OCC::Unary, - "unary", - {false, true, true, true, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::IsNaN, - "IsNaN", - OCC::IsSpecialFloat, - "isSpecialFloat", - {false, true, true, false, false, false, false, 
false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::IsInf, - "IsInf", - OCC::IsSpecialFloat, - "isSpecialFloat", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::IsFinite, - "IsFinite", - OCC::IsSpecialFloat, - "isSpecialFloat", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::IsNormal, - "IsNormal", - OCC::IsSpecialFloat, - "isSpecialFloat", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Cos, - "Cos", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Sin, - "Sin", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Tan, - "Tan", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Acos, - "Acos", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Asin, - "Asin", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Atan, - "Atan", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Hcos, - "Hcos", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Hsin, - "Hsin", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Htan, - "Htan", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, 
false, - false}, - Attribute::ReadNone, - }, - { - OC::Exp, - "Exp", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Frc, - "Frc", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Log, - "Log", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Sqrt, - "Sqrt", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Rsqrt, - "Rsqrt", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Unary float - rounding void, h, f, d, i1, i8, i16, - // i32, i64, udt, obj , function attribute - { - OC::Round_ne, - "Round_ne", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Round_ni, - "Round_ni", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Round_pi, - "Round_pi", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Round_z, - "Round_z", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Unary int void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::Bfrev, - "Bfrev", - OCC::Unary, - "unary", - {false, false, false, false, false, false, true, true, true, false, - false}, - Attribute::ReadNone, - }, - { - OC::Countbits, - "Countbits", - OCC::UnaryBits, - "unaryBits", - {false, false, false, false, false, false, true, true, true, 
false, - false}, - Attribute::ReadNone, - }, - { - OC::FirstbitLo, - "FirstbitLo", - OCC::UnaryBits, - "unaryBits", - {false, false, false, false, false, false, true, true, true, false, - false}, - Attribute::ReadNone, - }, - - // Unary uint void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::FirstbitHi, - "FirstbitHi", - OCC::UnaryBits, - "unaryBits", - {false, false, false, false, false, false, true, true, true, false, - false}, - Attribute::ReadNone, - }, - - // Unary int void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::FirstbitSHi, - "FirstbitSHi", - OCC::UnaryBits, - "unaryBits", - {false, false, false, false, false, false, true, true, true, false, - false}, - Attribute::ReadNone, - }, - - // Binary float void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::FMax, - "FMax", - OCC::Binary, - "binary", - {false, true, true, true, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::FMin, - "FMin", - OCC::Binary, - "binary", - {false, true, true, true, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Binary int void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::IMax, - "IMax", - OCC::Binary, - "binary", - {false, false, false, false, false, false, true, true, true, false, - false}, - Attribute::ReadNone, - }, - { - OC::IMin, - "IMin", - OCC::Binary, - "binary", - {false, false, false, false, false, false, true, true, true, false, - false}, - Attribute::ReadNone, - }, - - // Binary uint void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::UMax, - "UMax", - OCC::Binary, - "binary", - {false, false, false, false, false, false, true, true, true, false, - false}, - Attribute::ReadNone, - }, - { - OC::UMin, - "UMin", - OCC::Binary, - "binary", - {false, false, false, false, false, false, true, true, true, false, - false}, - 
Attribute::ReadNone, - }, - - // Binary int with two outputs void, h, f, d, i1, i8, i16, - // i32, i64, udt, obj , function attribute - { - OC::IMul, - "IMul", - OCC::BinaryWithTwoOuts, - "binaryWithTwoOuts", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Binary uint with two outputs void, h, f, d, i1, i8, - // i16, i32, i64, udt, obj , function attribute - { - OC::UMul, - "UMul", - OCC::BinaryWithTwoOuts, - "binaryWithTwoOuts", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::UDiv, - "UDiv", - OCC::BinaryWithTwoOuts, - "binaryWithTwoOuts", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Binary uint with carry or borrow void, h, f, d, i1, i8, - // i16, i32, i64, udt, obj , function attribute - { - OC::UAddc, - "UAddc", - OCC::BinaryWithCarryOrBorrow, - "binaryWithCarryOrBorrow", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::USubb, - "USubb", - OCC::BinaryWithCarryOrBorrow, - "binaryWithCarryOrBorrow", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Tertiary float void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::FMad, - "FMad", - OCC::Tertiary, - "tertiary", - {false, true, true, true, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Fma, - "Fma", - OCC::Tertiary, - "tertiary", - {false, false, false, true, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Tertiary int void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::IMad, - "IMad", - OCC::Tertiary, - "tertiary", - {false, false, false, false, false, false, true, true, true, false, - false}, - Attribute::ReadNone, - }, - - // 
Tertiary uint void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::UMad, - "UMad", - OCC::Tertiary, - "tertiary", - {false, false, false, false, false, false, true, true, true, false, - false}, - Attribute::ReadNone, - }, - - // Tertiary int void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::Msad, - "Msad", - OCC::Tertiary, - "tertiary", - {false, false, false, false, false, false, false, true, true, false, - false}, - Attribute::ReadNone, - }, - { - OC::Ibfe, - "Ibfe", - OCC::Tertiary, - "tertiary", - {false, false, false, false, false, false, false, true, true, false, - false}, - Attribute::ReadNone, - }, - - // Tertiary uint void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::Ubfe, - "Ubfe", - OCC::Tertiary, - "tertiary", - {false, false, false, false, false, false, false, true, true, false, - false}, - Attribute::ReadNone, - }, - - // Quaternary void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::Bfi, - "Bfi", - OCC::Quaternary, - "quaternary", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Dot void, h, f, d, i1, i8, i16, i32, i64, udt, - // obj , function attribute - { - OC::Dot2, - "Dot2", - OCC::Dot2, - "dot2", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Dot3, - "Dot3", - OCC::Dot3, - "dot3", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Dot4, - "Dot4", - OCC::Dot4, - "dot4", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Resources void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::CreateHandle, - "CreateHandle", - OCC::CreateHandle, - "createHandle", - {true, false, false, false, false, false, false, false, false, false, - 
false}, - Attribute::ReadOnly, - }, - { - OC::CBufferLoad, - "CBufferLoad", - OCC::CBufferLoad, - "cbufferLoad", - {false, true, true, true, false, true, true, true, true, false, false}, - Attribute::ReadOnly, - }, - { - OC::CBufferLoadLegacy, - "CBufferLoadLegacy", - OCC::CBufferLoadLegacy, - "cbufferLoadLegacy", - {false, true, true, true, false, false, true, true, true, false, false}, - Attribute::ReadOnly, - }, - - // Resources - sample void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::Sample, - "Sample", - OCC::Sample, - "sample", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::SampleBias, - "SampleBias", - OCC::SampleBias, - "sampleBias", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::SampleLevel, - "SampleLevel", - OCC::SampleLevel, - "sampleLevel", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::SampleGrad, - "SampleGrad", - OCC::SampleGrad, - "sampleGrad", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::SampleCmp, - "SampleCmp", - OCC::SampleCmp, - "sampleCmp", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::SampleCmpLevelZero, - "SampleCmpLevelZero", - OCC::SampleCmpLevelZero, - "sampleCmpLevelZero", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - - // Resources void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::TextureLoad, - "TextureLoad", - OCC::TextureLoad, - "textureLoad", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::TextureStore, - "TextureStore", - OCC::TextureStore, - "textureStore", - {false, true, true, 
false, false, false, true, true, false, false, - false}, - Attribute::None, - }, - { - OC::BufferLoad, - "BufferLoad", - OCC::BufferLoad, - "bufferLoad", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::BufferStore, - "BufferStore", - OCC::BufferStore, - "bufferStore", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::None, - }, - { - OC::BufferUpdateCounter, - "BufferUpdateCounter", - OCC::BufferUpdateCounter, - "bufferUpdateCounter", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::CheckAccessFullyMapped, - "CheckAccessFullyMapped", - OCC::CheckAccessFullyMapped, - "checkAccessFullyMapped", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::GetDimensions, - "GetDimensions", - OCC::GetDimensions, - "getDimensions", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - - // Resources - gather void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::TextureGather, - "TextureGather", - OCC::TextureGather, - "textureGather", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::TextureGatherCmp, - "TextureGatherCmp", - OCC::TextureGatherCmp, - "textureGatherCmp", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::ReadOnly, - }, - - // Resources - sample void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::Texture2DMSGetSamplePosition, - "Texture2DMSGetSamplePosition", - OCC::Texture2DMSGetSamplePosition, - "texture2DMSGetSamplePosition", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RenderTargetGetSamplePosition, - 
"RenderTargetGetSamplePosition", - OCC::RenderTargetGetSamplePosition, - "renderTargetGetSamplePosition", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RenderTargetGetSampleCount, - "RenderTargetGetSampleCount", - OCC::RenderTargetGetSampleCount, - "renderTargetGetSampleCount", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - - // Synchronization void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::AtomicBinOp, - "AtomicBinOp", - OCC::AtomicBinOp, - "atomicBinOp", - {false, false, false, false, false, false, false, true, true, false, - false}, - Attribute::None, - }, - { - OC::AtomicCompareExchange, - "AtomicCompareExchange", - OCC::AtomicCompareExchange, - "atomicCompareExchange", - {false, false, false, false, false, false, false, true, true, false, - false}, - Attribute::None, - }, - { - OC::Barrier, - "Barrier", - OCC::Barrier, - "barrier", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::NoDuplicate, - }, - - // Derivatives void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::CalculateLOD, - "CalculateLOD", - OCC::CalculateLOD, - "calculateLOD", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - - // Pixel shader void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::Discard, - "Discard", - OCC::Discard, - "discard", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - - // Derivatives void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::DerivCoarseX, - "DerivCoarseX", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::DerivCoarseY, - "DerivCoarseY", - 
OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::DerivFineX, - "DerivFineX", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::DerivFineY, - "DerivFineY", - OCC::Unary, - "unary", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Pixel shader void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::EvalSnapped, - "EvalSnapped", - OCC::EvalSnapped, - "evalSnapped", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::EvalSampleIndex, - "EvalSampleIndex", - OCC::EvalSampleIndex, - "evalSampleIndex", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::EvalCentroid, - "EvalCentroid", - OCC::EvalCentroid, - "evalCentroid", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::SampleIndex, - "SampleIndex", - OCC::SampleIndex, - "sampleIndex", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Coverage, - "Coverage", - OCC::Coverage, - "coverage", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::InnerCoverage, - "InnerCoverage", - OCC::InnerCoverage, - "innerCoverage", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Compute/Mesh/Amplification/Node shader void, h, f, d, i1, - // i8, i16, i32, i64, udt, obj , function attribute - { - OC::ThreadId, - "ThreadId", - OCC::ThreadId, - "threadId", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - 
{ - OC::GroupId, - "GroupId", - OCC::GroupId, - "groupId", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::ThreadIdInGroup, - "ThreadIdInGroup", - OCC::ThreadIdInGroup, - "threadIdInGroup", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::FlattenedThreadIdInGroup, - "FlattenedThreadIdInGroup", - OCC::FlattenedThreadIdInGroup, - "flattenedThreadIdInGroup", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Geometry shader void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::EmitStream, - "EmitStream", - OCC::EmitStream, - "emitStream", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::CutStream, - "CutStream", - OCC::CutStream, - "cutStream", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::EmitThenCutStream, - "EmitThenCutStream", - OCC::EmitThenCutStream, - "emitThenCutStream", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::GSInstanceID, - "GSInstanceID", - OCC::GSInstanceID, - "gsInstanceID", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Double precision void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::MakeDouble, - "MakeDouble", - OCC::MakeDouble, - "makeDouble", - {false, false, false, true, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::SplitDouble, - "SplitDouble", - OCC::SplitDouble, - "splitDouble", - {false, false, false, true, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Domain and hull shader void, h, f, d, i1, i8, i16, - // i32, i64, 
udt, obj , function attribute - { - OC::LoadOutputControlPoint, - "LoadOutputControlPoint", - OCC::LoadOutputControlPoint, - "loadOutputControlPoint", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::LoadPatchConstant, - "LoadPatchConstant", - OCC::LoadPatchConstant, - "loadPatchConstant", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Domain shader void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::DomainLocation, - "DomainLocation", - OCC::DomainLocation, - "domainLocation", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Hull shader void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::StorePatchConstant, - "StorePatchConstant", - OCC::StorePatchConstant, - "storePatchConstant", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::None, - }, - { - OC::OutputControlPointID, - "OutputControlPointID", - OCC::OutputControlPointID, - "outputControlPointID", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Hull, Domain and Geometry shaders void, h, f, d, i1, i8, - // i16, i32, i64, udt, obj , function attribute - { - OC::PrimitiveID, - "PrimitiveID", - OCC::PrimitiveID, - "primitiveID", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Other void, h, f, d, i1, i8, i16, i32, i64, udt, - // obj , function attribute - { - OC::CycleCounterLegacy, - "CycleCounterLegacy", - OCC::CycleCounterLegacy, - "cycleCounterLegacy", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - - // Wave void, h, f, d, i1, i8, i16, i32, i64, udt, - // obj , function attribute - { - OC::WaveIsFirstLane, - 
"WaveIsFirstLane", - OCC::WaveIsFirstLane, - "waveIsFirstLane", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::WaveGetLaneIndex, - "WaveGetLaneIndex", - OCC::WaveGetLaneIndex, - "waveGetLaneIndex", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::WaveGetLaneCount, - "WaveGetLaneCount", - OCC::WaveGetLaneCount, - "waveGetLaneCount", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::WaveAnyTrue, - "WaveAnyTrue", - OCC::WaveAnyTrue, - "waveAnyTrue", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::WaveAllTrue, - "WaveAllTrue", - OCC::WaveAllTrue, - "waveAllTrue", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::WaveActiveAllEqual, - "WaveActiveAllEqual", - OCC::WaveActiveAllEqual, - "waveActiveAllEqual", - {false, true, true, true, true, true, true, true, true, false, false}, - Attribute::None, - }, - { - OC::WaveActiveBallot, - "WaveActiveBallot", - OCC::WaveActiveBallot, - "waveActiveBallot", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::WaveReadLaneAt, - "WaveReadLaneAt", - OCC::WaveReadLaneAt, - "waveReadLaneAt", - {false, true, true, true, true, true, true, true, true, false, false}, - Attribute::None, - }, - { - OC::WaveReadLaneFirst, - "WaveReadLaneFirst", - OCC::WaveReadLaneFirst, - "waveReadLaneFirst", - {false, true, true, true, true, true, true, true, true, false, false}, - Attribute::None, - }, - { - OC::WaveActiveOp, - "WaveActiveOp", - OCC::WaveActiveOp, - "waveActiveOp", - {false, true, true, true, true, true, true, true, true, false, false}, - Attribute::None, - }, - { - OC::WaveActiveBit, - "WaveActiveBit", - OCC::WaveActiveBit, - 
"waveActiveBit", - {false, false, false, false, false, true, true, true, true, false, - false}, - Attribute::None, - }, - { - OC::WavePrefixOp, - "WavePrefixOp", - OCC::WavePrefixOp, - "wavePrefixOp", - {false, true, true, true, false, true, true, true, true, false, false}, - Attribute::None, - }, - - // Quad Wave Ops void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::QuadReadLaneAt, - "QuadReadLaneAt", - OCC::QuadReadLaneAt, - "quadReadLaneAt", - {false, true, true, true, true, true, true, true, true, false, false}, - Attribute::None, - }, - { - OC::QuadOp, - "QuadOp", - OCC::QuadOp, - "quadOp", - {false, true, true, true, false, true, true, true, true, false, false}, - Attribute::None, - }, - - // Bitcasts with different sizes void, h, f, d, i1, i8, - // i16, i32, i64, udt, obj , function attribute - { - OC::BitcastI16toF16, - "BitcastI16toF16", - OCC::BitcastI16toF16, - "bitcastI16toF16", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::BitcastF16toI16, - "BitcastF16toI16", - OCC::BitcastF16toI16, - "bitcastF16toI16", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::BitcastI32toF32, - "BitcastI32toF32", - OCC::BitcastI32toF32, - "bitcastI32toF32", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::BitcastF32toI32, - "BitcastF32toI32", - OCC::BitcastF32toI32, - "bitcastF32toI32", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::BitcastI64toF64, - "BitcastI64toF64", - OCC::BitcastI64toF64, - "bitcastI64toF64", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::BitcastF64toI64, - "BitcastF64toI64", - OCC::BitcastF64toI64, - "bitcastF64toI64", - {true, false, false, false, false, false, 
false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Legacy floating-point void, h, f, d, i1, i8, i16, - // i32, i64, udt, obj , function attribute - { - OC::LegacyF32ToF16, - "LegacyF32ToF16", - OCC::LegacyF32ToF16, - "legacyF32ToF16", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::LegacyF16ToF32, - "LegacyF16ToF32", - OCC::LegacyF16ToF32, - "legacyF16ToF32", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Double precision void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::LegacyDoubleToFloat, - "LegacyDoubleToFloat", - OCC::LegacyDoubleToFloat, - "legacyDoubleToFloat", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::LegacyDoubleToSInt32, - "LegacyDoubleToSInt32", - OCC::LegacyDoubleToSInt32, - "legacyDoubleToSInt32", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::LegacyDoubleToUInt32, - "LegacyDoubleToUInt32", - OCC::LegacyDoubleToUInt32, - "legacyDoubleToUInt32", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Wave void, h, f, d, i1, i8, i16, i32, i64, udt, - // obj , function attribute - { - OC::WaveAllBitCount, - "WaveAllBitCount", - OCC::WaveAllOp, - "waveAllOp", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::WavePrefixBitCount, - "WavePrefixBitCount", - OCC::WavePrefixOp, - "wavePrefixOp", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - - // Pixel shader void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::AttributeAtVertex, - "AttributeAtVertex", - OCC::AttributeAtVertex, - "attributeAtVertex", - 
{false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Graphics shader void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::ViewID, - "ViewID", - OCC::ViewID, - "viewID", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Resources void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::RawBufferLoad, - "RawBufferLoad", - OCC::RawBufferLoad, - "rawBufferLoad", - {false, true, true, true, false, false, true, true, true, false, false}, - Attribute::ReadOnly, - }, - { - OC::RawBufferStore, - "RawBufferStore", - OCC::RawBufferStore, - "rawBufferStore", - {false, true, true, true, false, false, true, true, true, false, false}, - Attribute::None, - }, - - // Raytracing object space uint System Values void, h, f, d, i1, - // i8, i16, i32, i64, udt, obj , function attribute - { - OC::InstanceID, - "InstanceID", - OCC::InstanceID, - "instanceID", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::InstanceIndex, - "InstanceIndex", - OCC::InstanceIndex, - "instanceIndex", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Raytracing hit uint System Values void, h, f, d, i1, i8, - // i16, i32, i64, udt, obj , function attribute - { - OC::HitKind, - "HitKind", - OCC::HitKind, - "hitKind", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Raytracing uint System Values void, h, f, d, i1, i8, - // i16, i32, i64, udt, obj , function attribute - { - OC::RayFlags, - "RayFlags", - OCC::RayFlags, - "rayFlags", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Ray Dispatch Arguments void, h, f, d, i1, i8, i16, - // i32, i64, udt, obj , function 
attribute - { - OC::DispatchRaysIndex, - "DispatchRaysIndex", - OCC::DispatchRaysIndex, - "dispatchRaysIndex", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::DispatchRaysDimensions, - "DispatchRaysDimensions", - OCC::DispatchRaysDimensions, - "dispatchRaysDimensions", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Ray Vectors void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::WorldRayOrigin, - "WorldRayOrigin", - OCC::WorldRayOrigin, - "worldRayOrigin", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::WorldRayDirection, - "WorldRayDirection", - OCC::WorldRayDirection, - "worldRayDirection", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Ray object space Vectors void, h, f, d, i1, i8, i16, - // i32, i64, udt, obj , function attribute - { - OC::ObjectRayOrigin, - "ObjectRayOrigin", - OCC::ObjectRayOrigin, - "objectRayOrigin", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::ObjectRayDirection, - "ObjectRayDirection", - OCC::ObjectRayDirection, - "objectRayDirection", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Ray Transforms void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::ObjectToWorld, - "ObjectToWorld", - OCC::ObjectToWorld, - "objectToWorld", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::WorldToObject, - "WorldToObject", - OCC::WorldToObject, - "worldToObject", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // RayT void, h, f, d, i1, i8, 
i16, i32, i64, udt, - // obj , function attribute - { - OC::RayTMin, - "RayTMin", - OCC::RayTMin, - "rayTMin", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::RayTCurrent, - "RayTCurrent", - OCC::RayTCurrent, - "rayTCurrent", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - - // AnyHit Terminals void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::IgnoreHit, - "IgnoreHit", - OCC::IgnoreHit, - "ignoreHit", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::NoReturn, - }, - { - OC::AcceptHitAndEndSearch, - "AcceptHitAndEndSearch", - OCC::AcceptHitAndEndSearch, - "acceptHitAndEndSearch", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::NoReturn, - }, - - // Indirect Shader Invocation void, h, f, d, i1, i8, i16, - // i32, i64, udt, obj , function attribute - { - OC::TraceRay, - "TraceRay", - OCC::TraceRay, - "traceRay", - {false, false, false, false, false, false, false, false, false, true, - false}, - Attribute::None, - }, - { - OC::ReportHit, - "ReportHit", - OCC::ReportHit, - "reportHit", - {false, false, false, false, false, false, false, false, false, true, - false}, - Attribute::None, - }, - { - OC::CallShader, - "CallShader", - OCC::CallShader, - "callShader", - {false, false, false, false, false, false, false, false, false, true, - false}, - Attribute::None, - }, - - // Library create handle from resource struct (like HL intrinsic) void, h, - // f, d, i1, i8, i16, i32, i64, udt, obj , function - // attribute - { - OC::CreateHandleForLib, - "CreateHandleForLib", - OCC::CreateHandleForLib, - "createHandleForLib", - {false, false, false, false, false, false, false, false, false, false, - true}, - Attribute::ReadOnly, - }, - - // Raytracing object space uint System Values void, h, f, d, i1, - // i8, i16, i32, i64, 
udt, obj , function attribute - { - OC::PrimitiveIndex, - "PrimitiveIndex", - OCC::PrimitiveIndex, - "primitiveIndex", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Dot product with accumulate void, h, f, d, i1, i8, i16, - // i32, i64, udt, obj , function attribute - { - OC::Dot2AddHalf, - "Dot2AddHalf", - OCC::Dot2AddHalf, - "dot2AddHalf", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Dot4AddI8Packed, - "Dot4AddI8Packed", - OCC::Dot4AddPacked, - "dot4AddPacked", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::Dot4AddU8Packed, - "Dot4AddU8Packed", - OCC::Dot4AddPacked, - "dot4AddPacked", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Wave void, h, f, d, i1, i8, i16, i32, i64, udt, - // obj , function attribute - { - OC::WaveMatch, - "WaveMatch", - OCC::WaveMatch, - "waveMatch", - {false, true, true, true, false, true, true, true, true, false, false}, - Attribute::None, - }, - { - OC::WaveMultiPrefixOp, - "WaveMultiPrefixOp", - OCC::WaveMultiPrefixOp, - "waveMultiPrefixOp", - {false, true, true, true, false, true, true, true, true, false, false}, - Attribute::None, - }, - { - OC::WaveMultiPrefixBitCount, - "WaveMultiPrefixBitCount", - OCC::WaveMultiPrefixBitCount, - "waveMultiPrefixBitCount", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - - // Mesh shader instructions void, h, f, d, i1, i8, i16, - // i32, i64, udt, obj , function attribute - { - OC::SetMeshOutputCounts, - "SetMeshOutputCounts", - OCC::SetMeshOutputCounts, - "setMeshOutputCounts", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::EmitIndices, - "EmitIndices", - OCC::EmitIndices, - 
"emitIndices", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::GetMeshPayload, - "GetMeshPayload", - OCC::GetMeshPayload, - "getMeshPayload", - {false, false, false, false, false, false, false, false, false, true, - false}, - Attribute::ReadOnly, - }, - { - OC::StoreVertexOutput, - "StoreVertexOutput", - OCC::StoreVertexOutput, - "storeVertexOutput", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::None, - }, - { - OC::StorePrimitiveOutput, - "StorePrimitiveOutput", - OCC::StorePrimitiveOutput, - "storePrimitiveOutput", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::None, - }, - - // Amplification shader instructions void, h, f, d, i1, i8, - // i16, i32, i64, udt, obj , function attribute - { - OC::DispatchMesh, - "DispatchMesh", - OCC::DispatchMesh, - "dispatchMesh", - {false, false, false, false, false, false, false, false, false, true, - false}, - Attribute::None, - }, - - // Sampler Feedback void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::WriteSamplerFeedback, - "WriteSamplerFeedback", - OCC::WriteSamplerFeedback, - "writeSamplerFeedback", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::WriteSamplerFeedbackBias, - "WriteSamplerFeedbackBias", - OCC::WriteSamplerFeedbackBias, - "writeSamplerFeedbackBias", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::WriteSamplerFeedbackLevel, - "WriteSamplerFeedbackLevel", - OCC::WriteSamplerFeedbackLevel, - "writeSamplerFeedbackLevel", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::WriteSamplerFeedbackGrad, - "WriteSamplerFeedbackGrad", - OCC::WriteSamplerFeedbackGrad, - "writeSamplerFeedbackGrad", - {true, false, false, false, false, 
false, false, false, false, false, - false}, - Attribute::None, - }, - - // Inline Ray Query void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::AllocateRayQuery, - "AllocateRayQuery", - OCC::AllocateRayQuery, - "allocateRayQuery", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::RayQuery_TraceRayInline, - "RayQuery_TraceRayInline", - OCC::RayQuery_TraceRayInline, - "rayQuery_TraceRayInline", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::RayQuery_Proceed, - "RayQuery_Proceed", - OCC::RayQuery_Proceed, - "rayQuery_Proceed", - {false, false, false, false, true, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::RayQuery_Abort, - "RayQuery_Abort", - OCC::RayQuery_Abort, - "rayQuery_Abort", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::RayQuery_CommitNonOpaqueTriangleHit, - "RayQuery_CommitNonOpaqueTriangleHit", - OCC::RayQuery_CommitNonOpaqueTriangleHit, - "rayQuery_CommitNonOpaqueTriangleHit", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::RayQuery_CommitProceduralPrimitiveHit, - "RayQuery_CommitProceduralPrimitiveHit", - OCC::RayQuery_CommitProceduralPrimitiveHit, - "rayQuery_CommitProceduralPrimitiveHit", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::RayQuery_CommittedStatus, - "RayQuery_CommittedStatus", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CandidateType, - "RayQuery_CandidateType", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, false, false, false, true, false, false, - false}, - 
Attribute::ReadOnly, - }, - { - OC::RayQuery_CandidateObjectToWorld3x4, - "RayQuery_CandidateObjectToWorld3x4", - OCC::RayQuery_StateMatrix, - "rayQuery_StateMatrix", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CandidateWorldToObject3x4, - "RayQuery_CandidateWorldToObject3x4", - OCC::RayQuery_StateMatrix, - "rayQuery_StateMatrix", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CommittedObjectToWorld3x4, - "RayQuery_CommittedObjectToWorld3x4", - OCC::RayQuery_StateMatrix, - "rayQuery_StateMatrix", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CommittedWorldToObject3x4, - "RayQuery_CommittedWorldToObject3x4", - OCC::RayQuery_StateMatrix, - "rayQuery_StateMatrix", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CandidateProceduralPrimitiveNonOpaque, - "RayQuery_CandidateProceduralPrimitiveNonOpaque", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, true, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CandidateTriangleFrontFace, - "RayQuery_CandidateTriangleFrontFace", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, true, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CommittedTriangleFrontFace, - "RayQuery_CommittedTriangleFrontFace", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, true, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CandidateTriangleBarycentrics, - "RayQuery_CandidateTriangleBarycentrics", - OCC::RayQuery_StateVector, - "rayQuery_StateVector", - {false, false, true, false, false, 
false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CommittedTriangleBarycentrics, - "RayQuery_CommittedTriangleBarycentrics", - OCC::RayQuery_StateVector, - "rayQuery_StateVector", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_RayFlags, - "RayQuery_RayFlags", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_WorldRayOrigin, - "RayQuery_WorldRayOrigin", - OCC::RayQuery_StateVector, - "rayQuery_StateVector", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_WorldRayDirection, - "RayQuery_WorldRayDirection", - OCC::RayQuery_StateVector, - "rayQuery_StateVector", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_RayTMin, - "RayQuery_RayTMin", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CandidateTriangleRayT, - "RayQuery_CandidateTriangleRayT", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CommittedRayT, - "RayQuery_CommittedRayT", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CandidateInstanceIndex, - "RayQuery_CandidateInstanceIndex", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CandidateInstanceID, - 
"RayQuery_CandidateInstanceID", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CandidateGeometryIndex, - "RayQuery_CandidateGeometryIndex", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CandidatePrimitiveIndex, - "RayQuery_CandidatePrimitiveIndex", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CandidateObjectRayOrigin, - "RayQuery_CandidateObjectRayOrigin", - OCC::RayQuery_StateVector, - "rayQuery_StateVector", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CandidateObjectRayDirection, - "RayQuery_CandidateObjectRayDirection", - OCC::RayQuery_StateVector, - "rayQuery_StateVector", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CommittedInstanceIndex, - "RayQuery_CommittedInstanceIndex", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CommittedInstanceID, - "RayQuery_CommittedInstanceID", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CommittedGeometryIndex, - "RayQuery_CommittedGeometryIndex", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CommittedPrimitiveIndex, - 
"RayQuery_CommittedPrimitiveIndex", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CommittedObjectRayOrigin, - "RayQuery_CommittedObjectRayOrigin", - OCC::RayQuery_StateVector, - "rayQuery_StateVector", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CommittedObjectRayDirection, - "RayQuery_CommittedObjectRayDirection", - OCC::RayQuery_StateVector, - "rayQuery_StateVector", - {false, false, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - - // Raytracing object space uint System Values, raytracing tier 1.1 void, h, - // f, d, i1, i8, i16, i32, i64, udt, obj , function - // attribute - { - OC::GeometryIndex, - "GeometryIndex", - OCC::GeometryIndex, - "geometryIndex", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Inline Ray Query void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::RayQuery_CandidateInstanceContributionToHitGroupIndex, - "RayQuery_CandidateInstanceContributionToHitGroupIndex", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::RayQuery_CommittedInstanceContributionToHitGroupIndex, - "RayQuery_CommittedInstanceContributionToHitGroupIndex", - OCC::RayQuery_StateScalar, - "rayQuery_StateScalar", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadOnly, - }, - - // Get handle from heap void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::AnnotateHandle, - "AnnotateHandle", - OCC::AnnotateHandle, - "annotateHandle", - {true, false, false, false, false, false, false, false, false, false, - false}, - 
Attribute::ReadNone, - }, - { - OC::CreateHandleFromBinding, - "CreateHandleFromBinding", - OCC::CreateHandleFromBinding, - "createHandleFromBinding", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::CreateHandleFromHeap, - "CreateHandleFromHeap", - OCC::CreateHandleFromHeap, - "createHandleFromHeap", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Unpacking intrinsics void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::Unpack4x8, - "Unpack4x8", - OCC::Unpack4x8, - "unpack4x8", - {false, false, false, false, false, false, true, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Packing intrinsics void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::Pack4x8, - "Pack4x8", - OCC::Pack4x8, - "pack4x8", - {false, false, false, false, false, false, true, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Helper Lanes void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::IsHelperLane, - "IsHelperLane", - OCC::IsHelperLane, - "isHelperLane", - {false, false, false, false, true, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - - // Quad Wave Ops void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::QuadVote, - "QuadVote", - OCC::QuadVote, - "quadVote", - {false, false, false, false, true, false, false, false, false, false, - false}, - Attribute::None, - }, - - // Resources - gather void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::TextureGatherRaw, - "TextureGatherRaw", - OCC::TextureGatherRaw, - "textureGatherRaw", - {false, false, false, false, false, false, true, true, true, false, - false}, - Attribute::ReadOnly, - }, - - // Resources - sample void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - 
OC::SampleCmpLevel, - "SampleCmpLevel", - OCC::SampleCmpLevel, - "sampleCmpLevel", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - - // Resources void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute - { - OC::TextureStoreSample, - "TextureStoreSample", - OCC::TextureStoreSample, - "textureStoreSample", - {false, true, true, false, false, false, true, true, false, false, - false}, - Attribute::None, - }, - - // void, h, f, d, i1, i8, i16, i32, i64, udt, obj , function attribute - { - OC::Reserved0, - "Reserved0", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::Reserved1, - "Reserved1", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::Reserved2, - "Reserved2", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::Reserved3, - "Reserved3", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::Reserved4, - "Reserved4", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::Reserved5, - "Reserved5", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::Reserved6, - "Reserved6", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::Reserved7, - "Reserved7", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::Reserved8, - "Reserved8", - 
OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::Reserved9, - "Reserved9", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::Reserved10, - "Reserved10", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::Reserved11, - "Reserved11", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - - // Create/Annotate Node Handles void, h, f, d, i1, i8, - // i16, i32, i64, udt, obj , function attribute - { - OC::AllocateNodeOutputRecords, - "AllocateNodeOutputRecords", - OCC::AllocateNodeOutputRecords, - "allocateNodeOutputRecords", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - - // Get Pointer to Node Record in Address Space 6 void, h, f, d, - // i1, i8, i16, i32, i64, udt, obj , function attribute - { - OC::GetNodeRecordPtr, - "GetNodeRecordPtr", - OCC::GetNodeRecordPtr, - "getNodeRecordPtr", - {false, false, false, false, false, false, false, false, false, true, - false}, - Attribute::ReadNone, - }, - - // Work Graph intrinsics void, h, f, d, i1, i8, i16, - // i32, i64, udt, obj , function attribute - { - OC::IncrementOutputCount, - "IncrementOutputCount", - OCC::IncrementOutputCount, - "incrementOutputCount", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::OutputComplete, - "OutputComplete", - OCC::OutputComplete, - "outputComplete", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::GetInputRecordCount, - "GetInputRecordCount", - OCC::GetInputRecordCount, - "getInputRecordCount", - {true, false, false, false, 
false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::FinishedCrossGroupSharing, - "FinishedCrossGroupSharing", - OCC::FinishedCrossGroupSharing, - "finishedCrossGroupSharing", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - - // Synchronization void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::BarrierByMemoryType, - "BarrierByMemoryType", - OCC::BarrierByMemoryType, - "barrierByMemoryType", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::NoDuplicate, - }, - { - OC::BarrierByMemoryHandle, - "BarrierByMemoryHandle", - OCC::BarrierByMemoryHandle, - "barrierByMemoryHandle", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::NoDuplicate, - }, - { - OC::BarrierByNodeRecordHandle, - "BarrierByNodeRecordHandle", - OCC::BarrierByNodeRecordHandle, - "barrierByNodeRecordHandle", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::NoDuplicate, - }, - - // Create/Annotate Node Handles void, h, f, d, i1, i8, - // i16, i32, i64, udt, obj , function attribute - { - OC::CreateNodeOutputHandle, - "CreateNodeOutputHandle", - OCC::createNodeOutputHandle, - "createNodeOutputHandle", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::IndexNodeHandle, - "IndexNodeHandle", - OCC::IndexNodeHandle, - "indexNodeHandle", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::AnnotateNodeHandle, - "AnnotateNodeHandle", - OCC::AnnotateNodeHandle, - "annotateNodeHandle", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::CreateNodeInputRecordHandle, - "CreateNodeInputRecordHandle", - OCC::CreateNodeInputRecordHandle, - 
"createNodeInputRecordHandle", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::AnnotateNodeRecordHandle, - "AnnotateNodeRecordHandle", - OCC::AnnotateNodeRecordHandle, - "annotateNodeRecordHandle", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // Work Graph intrinsics void, h, f, d, i1, i8, i16, - // i32, i64, udt, obj , function attribute - { - OC::NodeOutputIsValid, - "NodeOutputIsValid", - OCC::NodeOutputIsValid, - "nodeOutputIsValid", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::GetRemainingRecursionLevels, - "GetRemainingRecursionLevels", - OCC::GetRemainingRecursionLevels, - "getRemainingRecursionLevels", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - - // Comparison Samples void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::SampleCmpGrad, - "SampleCmpGrad", - OCC::SampleCmpGrad, - "sampleCmpGrad", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - { - OC::SampleCmpBias, - "SampleCmpBias", - OCC::SampleCmpBias, - "sampleCmpBias", - {false, true, true, false, false, false, false, false, false, false, - false}, - Attribute::ReadOnly, - }, - - // Extended Command Information void, h, f, d, i1, i8, - // i16, i32, i64, udt, obj , function attribute - { - OC::StartVertexLocation, - "StartVertexLocation", - OCC::StartVertexLocation, - "startVertexLocation", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::StartInstanceLocation, - "StartInstanceLocation", - OCC::StartInstanceLocation, - "startInstanceLocation", - {false, false, false, false, false, false, false, true, false, false, - false}, - Attribute::ReadNone, - }, - - // Inline 
Ray Query void, h, f, d, i1, i8, i16, i32, - // i64, udt, obj , function attribute - { - OC::AllocateRayQuery2, - "AllocateRayQuery2", - OCC::AllocateRayQuery2, - "allocateRayQuery2", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - - // void, h, f, d, i1, i8, i16, i32, i64, udt, obj , function attribute - { - OC::ReservedA0, - "ReservedA0", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedA1, - "ReservedA1", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedA2, - "ReservedA2", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB0, - "ReservedB0", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB1, - "ReservedB1", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB2, - "ReservedB2", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - - // Shader Execution Reordering void, h, f, d, i1, i8, i16, - // i32, i64, udt, obj , function attribute - { - OC::HitObject_MakeMiss, - "HitObject_MakeMiss", - OCC::HitObject_MakeMiss, - "hitObject_MakeMiss", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - { - OC::HitObject_MakeNop, - "HitObject_MakeNop", - OCC::HitObject_MakeNop, - "hitObject_MakeNop", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::ReadNone, - }, - - // void, h, f, d, i1, i8, i16, i32, i64, udt, 
obj , function attribute - { - OC::ReservedB5, - "ReservedB5", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB6, - "ReservedB6", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB7, - "ReservedB7", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB8, - "ReservedB8", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB9, - "ReservedB9", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB10, - "ReservedB10", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB11, - "ReservedB11", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB12, - "ReservedB12", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB13, - "ReservedB13", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB14, - "ReservedB14", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB15, - "ReservedB15", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB16, - 
"ReservedB16", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB17, - "ReservedB17", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB18, - "ReservedB18", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB19, - "ReservedB19", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB20, - "ReservedB20", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB21, - "ReservedB21", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB22, - "ReservedB22", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB23, - "ReservedB23", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB24, - "ReservedB24", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB25, - "ReservedB25", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB26, - "ReservedB26", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB27, - "ReservedB27", - OCC::Reserved, - "reserved", - 
{true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB28, - "ReservedB28", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB29, - "ReservedB29", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedB30, - "ReservedB30", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedC0, - "ReservedC0", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedC1, - "ReservedC1", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedC2, - "ReservedC2", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedC3, - "ReservedC3", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedC4, - "ReservedC4", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedC5, - "ReservedC5", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedC6, - "ReservedC6", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedC7, - "ReservedC7", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, 
false, - false}, - Attribute::None, - }, - { - OC::ReservedC8, - "ReservedC8", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, - { - OC::ReservedC9, - "ReservedC9", - OCC::Reserved, - "reserved", - {true, false, false, false, false, false, false, false, false, false, - false}, - Attribute::None, - }, + // Temporary, indexable, input, output registers + {OC::TempRegLoad, + "TempRegLoad", + OCC::TempRegLoad, + "tempRegLoad", + Attribute::ReadOnly, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::TempRegStore, + "TempRegStore", + OCC::TempRegStore, + "tempRegStore", + Attribute::None, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::MinPrecXRegLoad, + "MinPrecXRegLoad", + OCC::MinPrecXRegLoad, + "minPrecXRegLoad", + Attribute::ReadOnly, + 1, + {{0x21}}, + {{0x0}}}, // Overloads: hw + {OC::MinPrecXRegStore, + "MinPrecXRegStore", + OCC::MinPrecXRegStore, + "minPrecXRegStore", + Attribute::None, + 1, + {{0x21}}, + {{0x0}}}, // Overloads: hw + {OC::LoadInput, + "LoadInput", + OCC::LoadInput, + "loadInput", + Attribute::ReadNone, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::StoreOutput, + "StoreOutput", + OCC::StoreOutput, + "storeOutput", + Attribute::None, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + + // Unary float + {OC::FAbs, + "FAbs", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x7}}, + {{0x0}}}, // Overloads: hfd + {OC::Saturate, + "Saturate", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x7}}, + {{0x0}}}, // Overloads: hfd + {OC::IsNaN, + "IsNaN", + OCC::IsSpecialFloat, + "isSpecialFloat", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::IsInf, + "IsInf", + OCC::IsSpecialFloat, + "isSpecialFloat", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::IsFinite, + "IsFinite", + OCC::IsSpecialFloat, + "isSpecialFloat", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: 
hf + {OC::IsNormal, + "IsNormal", + OCC::IsSpecialFloat, + "isSpecialFloat", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Cos, + "Cos", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Sin, + "Sin", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Tan, + "Tan", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Acos, + "Acos", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Asin, + "Asin", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Atan, + "Atan", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Hcos, + "Hcos", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Hsin, + "Hsin", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Htan, + "Htan", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Exp, + "Exp", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Frc, + "Frc", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Log, + "Log", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Sqrt, + "Sqrt", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Rsqrt, + "Rsqrt", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + + // Unary float - rounding + {OC::Round_ne, + "Round_ne", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Round_ni, + "Round_ni", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + 
{{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Round_pi, + "Round_pi", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Round_z, + "Round_z", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + + // Unary int + {OC::Bfrev, + "Bfrev", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0xe0}}, + {{0x0}}}, // Overloads: wil + {OC::Countbits, + "Countbits", + OCC::UnaryBits, + "unaryBits", + Attribute::ReadNone, + 1, + {{0xe0}}, + {{0x0}}}, // Overloads: wil + {OC::FirstbitLo, + "FirstbitLo", + OCC::UnaryBits, + "unaryBits", + Attribute::ReadNone, + 1, + {{0xe0}}, + {{0x0}}}, // Overloads: wil + + // Unary uint + {OC::FirstbitHi, + "FirstbitHi", + OCC::UnaryBits, + "unaryBits", + Attribute::ReadNone, + 1, + {{0xe0}}, + {{0x0}}}, // Overloads: wil + + // Unary int + {OC::FirstbitSHi, + "FirstbitSHi", + OCC::UnaryBits, + "unaryBits", + Attribute::ReadNone, + 1, + {{0xe0}}, + {{0x0}}}, // Overloads: wil + + // Binary float + {OC::FMax, + "FMax", + OCC::Binary, + "binary", + Attribute::ReadNone, + 1, + {{0x7}}, + {{0x0}}}, // Overloads: hfd + {OC::FMin, + "FMin", + OCC::Binary, + "binary", + Attribute::ReadNone, + 1, + {{0x7}}, + {{0x0}}}, // Overloads: hfd + + // Binary int + {OC::IMax, + "IMax", + OCC::Binary, + "binary", + Attribute::ReadNone, + 1, + {{0xe0}}, + {{0x0}}}, // Overloads: wil + {OC::IMin, + "IMin", + OCC::Binary, + "binary", + Attribute::ReadNone, + 1, + {{0xe0}}, + {{0x0}}}, // Overloads: wil + + // Binary uint + {OC::UMax, + "UMax", + OCC::Binary, + "binary", + Attribute::ReadNone, + 1, + {{0xe0}}, + {{0x0}}}, // Overloads: wil + {OC::UMin, + "UMin", + OCC::Binary, + "binary", + Attribute::ReadNone, + 1, + {{0xe0}}, + {{0x0}}}, // Overloads: wil + + // Binary int with two outputs + {OC::IMul, + "IMul", + OCC::BinaryWithTwoOuts, + "binaryWithTwoOuts", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Binary uint with two outputs + 
{OC::UMul, + "UMul", + OCC::BinaryWithTwoOuts, + "binaryWithTwoOuts", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::UDiv, + "UDiv", + OCC::BinaryWithTwoOuts, + "binaryWithTwoOuts", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Binary uint with carry or borrow + {OC::UAddc, + "UAddc", + OCC::BinaryWithCarryOrBorrow, + "binaryWithCarryOrBorrow", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::USubb, + "USubb", + OCC::BinaryWithCarryOrBorrow, + "binaryWithCarryOrBorrow", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Tertiary float + {OC::FMad, + "FMad", + OCC::Tertiary, + "tertiary", + Attribute::ReadNone, + 1, + {{0x7}}, + {{0x0}}}, // Overloads: hfd + {OC::Fma, + "Fma", + OCC::Tertiary, + "tertiary", + Attribute::ReadNone, + 1, + {{0x4}}, + {{0x0}}}, // Overloads: d + + // Tertiary int + {OC::IMad, + "IMad", + OCC::Tertiary, + "tertiary", + Attribute::ReadNone, + 1, + {{0xe0}}, + {{0x0}}}, // Overloads: wil + + // Tertiary uint + {OC::UMad, + "UMad", + OCC::Tertiary, + "tertiary", + Attribute::ReadNone, + 1, + {{0xe0}}, + {{0x0}}}, // Overloads: wil + + // Tertiary int + {OC::Msad, + "Msad", + OCC::Tertiary, + "tertiary", + Attribute::ReadNone, + 1, + {{0xc0}}, + {{0x0}}}, // Overloads: il + {OC::Ibfe, + "Ibfe", + OCC::Tertiary, + "tertiary", + Attribute::ReadNone, + 1, + {{0xc0}}, + {{0x0}}}, // Overloads: il + + // Tertiary uint + {OC::Ubfe, + "Ubfe", + OCC::Tertiary, + "tertiary", + Attribute::ReadNone, + 1, + {{0xc0}}, + {{0x0}}}, // Overloads: il + + // Quaternary + {OC::Bfi, + "Bfi", + OCC::Quaternary, + "quaternary", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Dot + {OC::Dot2, + "Dot2", + OCC::Dot2, + "dot2", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Dot3, + "Dot3", + OCC::Dot3, + "dot3", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::Dot4, + "Dot4", + 
OCC::Dot4, + "dot4", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + + // Resources + {OC::CreateHandle, + "CreateHandle", + OCC::CreateHandle, + "createHandle", + Attribute::ReadOnly, + 0, + {}, + {}}, // Overloads: v + {OC::CBufferLoad, + "CBufferLoad", + OCC::CBufferLoad, + "cbufferLoad", + Attribute::ReadOnly, + 1, + {{0xf7}}, + {{0x0}}}, // Overloads: hfd8wil + {OC::CBufferLoadLegacy, + "CBufferLoadLegacy", + OCC::CBufferLoadLegacy, + "cbufferLoadLegacy", + Attribute::ReadOnly, + 1, + {{0xe7}}, + {{0x0}}}, // Overloads: hfdwil + + // Resources - sample + {OC::Sample, + "Sample", + OCC::Sample, + "sample", + Attribute::ReadOnly, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::SampleBias, + "SampleBias", + OCC::SampleBias, + "sampleBias", + Attribute::ReadOnly, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::SampleLevel, + "SampleLevel", + OCC::SampleLevel, + "sampleLevel", + Attribute::ReadOnly, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::SampleGrad, + "SampleGrad", + OCC::SampleGrad, + "sampleGrad", + Attribute::ReadOnly, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::SampleCmp, + "SampleCmp", + OCC::SampleCmp, + "sampleCmp", + Attribute::ReadOnly, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::SampleCmpLevelZero, + "SampleCmpLevelZero", + OCC::SampleCmpLevelZero, + "sampleCmpLevelZero", + Attribute::ReadOnly, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + + // Resources + {OC::TextureLoad, + "TextureLoad", + OCC::TextureLoad, + "textureLoad", + Attribute::ReadOnly, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::TextureStore, + "TextureStore", + OCC::TextureStore, + "textureStore", + Attribute::None, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::BufferLoad, + "BufferLoad", + OCC::BufferLoad, + "bufferLoad", + Attribute::ReadOnly, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::BufferStore, + "BufferStore", + OCC::BufferStore, + "bufferStore", + Attribute::None, + 1, + {{0x63}}, 
+ {{0x0}}}, // Overloads: hfwi + {OC::BufferUpdateCounter, + "BufferUpdateCounter", + OCC::BufferUpdateCounter, + "bufferUpdateCounter", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::CheckAccessFullyMapped, + "CheckAccessFullyMapped", + OCC::CheckAccessFullyMapped, + "checkAccessFullyMapped", + Attribute::ReadOnly, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::GetDimensions, + "GetDimensions", + OCC::GetDimensions, + "getDimensions", + Attribute::ReadOnly, + 0, + {}, + {}}, // Overloads: v + + // Resources - gather + {OC::TextureGather, + "TextureGather", + OCC::TextureGather, + "textureGather", + Attribute::ReadOnly, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::TextureGatherCmp, + "TextureGatherCmp", + OCC::TextureGatherCmp, + "textureGatherCmp", + Attribute::ReadOnly, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + + // Resources - sample + {OC::Texture2DMSGetSamplePosition, + "Texture2DMSGetSamplePosition", + OCC::Texture2DMSGetSamplePosition, + "texture2DMSGetSamplePosition", + Attribute::ReadOnly, + 0, + {}, + {}}, // Overloads: v + {OC::RenderTargetGetSamplePosition, + "RenderTargetGetSamplePosition", + OCC::RenderTargetGetSamplePosition, + "renderTargetGetSamplePosition", + Attribute::ReadOnly, + 0, + {}, + {}}, // Overloads: v + {OC::RenderTargetGetSampleCount, + "RenderTargetGetSampleCount", + OCC::RenderTargetGetSampleCount, + "renderTargetGetSampleCount", + Attribute::ReadOnly, + 0, + {}, + {}}, // Overloads: v + + // Synchronization + {OC::AtomicBinOp, + "AtomicBinOp", + OCC::AtomicBinOp, + "atomicBinOp", + Attribute::None, + 1, + {{0xc0}}, + {{0x0}}}, // Overloads: li + {OC::AtomicCompareExchange, + "AtomicCompareExchange", + OCC::AtomicCompareExchange, + "atomicCompareExchange", + Attribute::None, + 1, + {{0xc0}}, + {{0x0}}}, // Overloads: li + {OC::Barrier, + "Barrier", + OCC::Barrier, + "barrier", + Attribute::NoDuplicate, + 0, + {}, + {}}, // Overloads: v + + // Derivatives + {OC::CalculateLOD, + "CalculateLOD", + 
OCC::CalculateLOD, + "calculateLOD", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + + // Pixel shader + {OC::Discard, + "Discard", + OCC::Discard, + "discard", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + + // Derivatives + {OC::DerivCoarseX, + "DerivCoarseX", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::DerivCoarseY, + "DerivCoarseY", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::DerivFineX, + "DerivFineX", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::DerivFineY, + "DerivFineY", + OCC::Unary, + "unary", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + + // Pixel shader + {OC::EvalSnapped, + "EvalSnapped", + OCC::EvalSnapped, + "evalSnapped", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::EvalSampleIndex, + "EvalSampleIndex", + OCC::EvalSampleIndex, + "evalSampleIndex", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::EvalCentroid, + "EvalCentroid", + OCC::EvalCentroid, + "evalCentroid", + Attribute::ReadNone, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::SampleIndex, + "SampleIndex", + OCC::SampleIndex, + "sampleIndex", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::Coverage, + "Coverage", + OCC::Coverage, + "coverage", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::InnerCoverage, + "InnerCoverage", + OCC::InnerCoverage, + "innerCoverage", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Compute/Mesh/Amplification/Node shader + {OC::ThreadId, + "ThreadId", + OCC::ThreadId, + "threadId", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::GroupId, + "GroupId", + OCC::GroupId, + "groupId", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::ThreadIdInGroup, + 
"ThreadIdInGroup", + OCC::ThreadIdInGroup, + "threadIdInGroup", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::FlattenedThreadIdInGroup, + "FlattenedThreadIdInGroup", + OCC::FlattenedThreadIdInGroup, + "flattenedThreadIdInGroup", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Geometry shader + {OC::EmitStream, + "EmitStream", + OCC::EmitStream, + "emitStream", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::CutStream, + "CutStream", + OCC::CutStream, + "cutStream", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::EmitThenCutStream, + "EmitThenCutStream", + OCC::EmitThenCutStream, + "emitThenCutStream", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::GSInstanceID, + "GSInstanceID", + OCC::GSInstanceID, + "gsInstanceID", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Double precision + {OC::MakeDouble, + "MakeDouble", + OCC::MakeDouble, + "makeDouble", + Attribute::ReadNone, + 1, + {{0x4}}, + {{0x0}}}, // Overloads: d + {OC::SplitDouble, + "SplitDouble", + OCC::SplitDouble, + "splitDouble", + Attribute::ReadNone, + 1, + {{0x4}}, + {{0x0}}}, // Overloads: d + + // Domain and hull shader + {OC::LoadOutputControlPoint, + "LoadOutputControlPoint", + OCC::LoadOutputControlPoint, + "loadOutputControlPoint", + Attribute::ReadNone, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::LoadPatchConstant, + "LoadPatchConstant", + OCC::LoadPatchConstant, + "loadPatchConstant", + Attribute::ReadNone, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + + // Domain shader + {OC::DomainLocation, + "DomainLocation", + OCC::DomainLocation, + "domainLocation", + Attribute::ReadNone, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + + // Hull shader + {OC::StorePatchConstant, + "StorePatchConstant", + OCC::StorePatchConstant, + "storePatchConstant", + Attribute::None, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::OutputControlPointID, + "OutputControlPointID", + 
OCC::OutputControlPointID, + "outputControlPointID", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Hull, Domain and Geometry shaders + {OC::PrimitiveID, + "PrimitiveID", + OCC::PrimitiveID, + "primitiveID", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Other + {OC::CycleCounterLegacy, + "CycleCounterLegacy", + OCC::CycleCounterLegacy, + "cycleCounterLegacy", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + + // Wave + {OC::WaveIsFirstLane, + "WaveIsFirstLane", + OCC::WaveIsFirstLane, + "waveIsFirstLane", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::WaveGetLaneIndex, + "WaveGetLaneIndex", + OCC::WaveGetLaneIndex, + "waveGetLaneIndex", + Attribute::ReadOnly, + 0, + {}, + {}}, // Overloads: v + {OC::WaveGetLaneCount, + "WaveGetLaneCount", + OCC::WaveGetLaneCount, + "waveGetLaneCount", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::WaveAnyTrue, + "WaveAnyTrue", + OCC::WaveAnyTrue, + "waveAnyTrue", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::WaveAllTrue, + "WaveAllTrue", + OCC::WaveAllTrue, + "waveAllTrue", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::WaveActiveAllEqual, + "WaveActiveAllEqual", + OCC::WaveActiveAllEqual, + "waveActiveAllEqual", + Attribute::None, + 1, + {{0xff}}, + {{0x0}}}, // Overloads: hfd18wil + {OC::WaveActiveBallot, + "WaveActiveBallot", + OCC::WaveActiveBallot, + "waveActiveBallot", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::WaveReadLaneAt, + "WaveReadLaneAt", + OCC::WaveReadLaneAt, + "waveReadLaneAt", + Attribute::None, + 1, + {{0xff}}, + {{0x0}}}, // Overloads: hfd18wil + {OC::WaveReadLaneFirst, + "WaveReadLaneFirst", + OCC::WaveReadLaneFirst, + "waveReadLaneFirst", + Attribute::None, + 1, + {{0xff}}, + {{0x0}}}, // Overloads: hfd18wil + {OC::WaveActiveOp, + "WaveActiveOp", + OCC::WaveActiveOp, + "waveActiveOp", + Attribute::None, + 1, + {{0xff}}, + {{0x0}}}, // Overloads: hfd18wil + {OC::WaveActiveBit, 
+ "WaveActiveBit", + OCC::WaveActiveBit, + "waveActiveBit", + Attribute::None, + 1, + {{0xf0}}, + {{0x0}}}, // Overloads: 8wil + {OC::WavePrefixOp, + "WavePrefixOp", + OCC::WavePrefixOp, + "wavePrefixOp", + Attribute::None, + 1, + {{0xf7}}, + {{0x0}}}, // Overloads: hfd8wil + + // Quad Wave Ops + {OC::QuadReadLaneAt, + "QuadReadLaneAt", + OCC::QuadReadLaneAt, + "quadReadLaneAt", + Attribute::None, + 1, + {{0xff}}, + {{0x0}}}, // Overloads: hfd18wil + {OC::QuadOp, + "QuadOp", + OCC::QuadOp, + "quadOp", + Attribute::None, + 1, + {{0xf7}}, + {{0x0}}}, // Overloads: hfd8wil + + // Bitcasts with different sizes + {OC::BitcastI16toF16, + "BitcastI16toF16", + OCC::BitcastI16toF16, + "bitcastI16toF16", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::BitcastF16toI16, + "BitcastF16toI16", + OCC::BitcastF16toI16, + "bitcastF16toI16", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::BitcastI32toF32, + "BitcastI32toF32", + OCC::BitcastI32toF32, + "bitcastI32toF32", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::BitcastF32toI32, + "BitcastF32toI32", + OCC::BitcastF32toI32, + "bitcastF32toI32", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::BitcastI64toF64, + "BitcastI64toF64", + OCC::BitcastI64toF64, + "bitcastI64toF64", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::BitcastF64toI64, + "BitcastF64toI64", + OCC::BitcastF64toI64, + "bitcastF64toI64", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + + // Legacy floating-point + {OC::LegacyF32ToF16, + "LegacyF32ToF16", + OCC::LegacyF32ToF16, + "legacyF32ToF16", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::LegacyF16ToF32, + "LegacyF16ToF32", + OCC::LegacyF16ToF32, + "legacyF16ToF32", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + + // Double precision + {OC::LegacyDoubleToFloat, + "LegacyDoubleToFloat", + OCC::LegacyDoubleToFloat, + "legacyDoubleToFloat", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + 
{OC::LegacyDoubleToSInt32, + "LegacyDoubleToSInt32", + OCC::LegacyDoubleToSInt32, + "legacyDoubleToSInt32", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::LegacyDoubleToUInt32, + "LegacyDoubleToUInt32", + OCC::LegacyDoubleToUInt32, + "legacyDoubleToUInt32", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + + // Wave + {OC::WaveAllBitCount, + "WaveAllBitCount", + OCC::WaveAllOp, + "waveAllOp", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::WavePrefixBitCount, + "WavePrefixBitCount", + OCC::WavePrefixOp, + "wavePrefixOp", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + + // Pixel shader + {OC::AttributeAtVertex, + "AttributeAtVertex", + OCC::AttributeAtVertex, + "attributeAtVertex", + Attribute::ReadNone, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfiw + + // Graphics shader + {OC::ViewID, + "ViewID", + OCC::ViewID, + "viewID", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Resources + {OC::RawBufferLoad, + "RawBufferLoad", + OCC::RawBufferLoad, + "rawBufferLoad", + Attribute::ReadOnly, + 1, + {{0xe7}}, + {{0x0}}}, // Overloads: hfwidl + {OC::RawBufferStore, + "RawBufferStore", + OCC::RawBufferStore, + "rawBufferStore", + Attribute::None, + 1, + {{0xe7}}, + {{0x0}}}, // Overloads: hfwidl + + // Raytracing object space uint System Values + {OC::InstanceID, + "InstanceID", + OCC::InstanceID, + "instanceID", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::InstanceIndex, + "InstanceIndex", + OCC::InstanceIndex, + "instanceIndex", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Raytracing hit uint System Values + {OC::HitKind, + "HitKind", + OCC::HitKind, + "hitKind", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Raytracing uint System Values + {OC::RayFlags, + "RayFlags", + OCC::RayFlags, + "rayFlags", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Ray Dispatch Arguments + 
{OC::DispatchRaysIndex, + "DispatchRaysIndex", + OCC::DispatchRaysIndex, + "dispatchRaysIndex", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::DispatchRaysDimensions, + "DispatchRaysDimensions", + OCC::DispatchRaysDimensions, + "dispatchRaysDimensions", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Ray Vectors + {OC::WorldRayOrigin, + "WorldRayOrigin", + OCC::WorldRayOrigin, + "worldRayOrigin", + Attribute::ReadNone, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::WorldRayDirection, + "WorldRayDirection", + OCC::WorldRayDirection, + "worldRayDirection", + Attribute::ReadNone, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + + // Ray object space Vectors + {OC::ObjectRayOrigin, + "ObjectRayOrigin", + OCC::ObjectRayOrigin, + "objectRayOrigin", + Attribute::ReadNone, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::ObjectRayDirection, + "ObjectRayDirection", + OCC::ObjectRayDirection, + "objectRayDirection", + Attribute::ReadNone, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + + // Ray Transforms + {OC::ObjectToWorld, + "ObjectToWorld", + OCC::ObjectToWorld, + "objectToWorld", + Attribute::ReadNone, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::WorldToObject, + "WorldToObject", + OCC::WorldToObject, + "worldToObject", + Attribute::ReadNone, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + + // RayT + {OC::RayTMin, + "RayTMin", + OCC::RayTMin, + "rayTMin", + Attribute::ReadNone, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::RayTCurrent, + "RayTCurrent", + OCC::RayTCurrent, + "rayTCurrent", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + + // AnyHit Terminals + {OC::IgnoreHit, + "IgnoreHit", + OCC::IgnoreHit, + "ignoreHit", + Attribute::NoReturn, + 0, + {}, + {}}, // Overloads: v + {OC::AcceptHitAndEndSearch, + "AcceptHitAndEndSearch", + OCC::AcceptHitAndEndSearch, + "acceptHitAndEndSearch", + Attribute::NoReturn, + 0, + {}, + {}}, // Overloads: v + + // Indirect Shader Invocation + 
{OC::TraceRay, + "TraceRay", + OCC::TraceRay, + "traceRay", + Attribute::None, + 1, + {{0x100}}, + {{0x0}}}, // Overloads: u + {OC::ReportHit, + "ReportHit", + OCC::ReportHit, + "reportHit", + Attribute::None, + 1, + {{0x100}}, + {{0x0}}}, // Overloads: u + {OC::CallShader, + "CallShader", + OCC::CallShader, + "callShader", + Attribute::None, + 1, + {{0x100}}, + {{0x0}}}, // Overloads: u + + // Library create handle from resource struct (like HL intrinsic) + {OC::CreateHandleForLib, + "CreateHandleForLib", + OCC::CreateHandleForLib, + "createHandleForLib", + Attribute::ReadOnly, + 1, + {{0x200}}, + {{0x0}}}, // Overloads: o + + // Raytracing object space uint System Values + {OC::PrimitiveIndex, + "PrimitiveIndex", + OCC::PrimitiveIndex, + "primitiveIndex", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Dot product with accumulate + {OC::Dot2AddHalf, + "Dot2AddHalf", + OCC::Dot2AddHalf, + "dot2AddHalf", + Attribute::ReadNone, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::Dot4AddI8Packed, + "Dot4AddI8Packed", + OCC::Dot4AddPacked, + "dot4AddPacked", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::Dot4AddU8Packed, + "Dot4AddU8Packed", + OCC::Dot4AddPacked, + "dot4AddPacked", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Wave + {OC::WaveMatch, + "WaveMatch", + OCC::WaveMatch, + "waveMatch", + Attribute::None, + 1, + {{0xf7}}, + {{0x0}}}, // Overloads: hfd8wil + {OC::WaveMultiPrefixOp, + "WaveMultiPrefixOp", + OCC::WaveMultiPrefixOp, + "waveMultiPrefixOp", + Attribute::None, + 1, + {{0xf7}}, + {{0x0}}}, // Overloads: hfd8wil + {OC::WaveMultiPrefixBitCount, + "WaveMultiPrefixBitCount", + OCC::WaveMultiPrefixBitCount, + "waveMultiPrefixBitCount", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + + // Mesh shader instructions + {OC::SetMeshOutputCounts, + "SetMeshOutputCounts", + OCC::SetMeshOutputCounts, + "setMeshOutputCounts", + Attribute::None, + 0, + {}, + {}}, // Overloads: 
v + {OC::EmitIndices, + "EmitIndices", + OCC::EmitIndices, + "emitIndices", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::GetMeshPayload, + "GetMeshPayload", + OCC::GetMeshPayload, + "getMeshPayload", + Attribute::ReadOnly, + 1, + {{0x100}}, + {{0x0}}}, // Overloads: u + {OC::StoreVertexOutput, + "StoreVertexOutput", + OCC::StoreVertexOutput, + "storeVertexOutput", + Attribute::None, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::StorePrimitiveOutput, + "StorePrimitiveOutput", + OCC::StorePrimitiveOutput, + "storePrimitiveOutput", + Attribute::None, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + + // Amplification shader instructions + {OC::DispatchMesh, + "DispatchMesh", + OCC::DispatchMesh, + "dispatchMesh", + Attribute::None, + 1, + {{0x100}}, + {{0x0}}}, // Overloads: u + + // Sampler Feedback + {OC::WriteSamplerFeedback, + "WriteSamplerFeedback", + OCC::WriteSamplerFeedback, + "writeSamplerFeedback", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::WriteSamplerFeedbackBias, + "WriteSamplerFeedbackBias", + OCC::WriteSamplerFeedbackBias, + "writeSamplerFeedbackBias", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::WriteSamplerFeedbackLevel, + "WriteSamplerFeedbackLevel", + OCC::WriteSamplerFeedbackLevel, + "writeSamplerFeedbackLevel", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::WriteSamplerFeedbackGrad, + "WriteSamplerFeedbackGrad", + OCC::WriteSamplerFeedbackGrad, + "writeSamplerFeedbackGrad", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + + // Inline Ray Query + {OC::AllocateRayQuery, + "AllocateRayQuery", + OCC::AllocateRayQuery, + "allocateRayQuery", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::RayQuery_TraceRayInline, + "RayQuery_TraceRayInline", + OCC::RayQuery_TraceRayInline, + "rayQuery_TraceRayInline", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::RayQuery_Proceed, + "RayQuery_Proceed", + OCC::RayQuery_Proceed, + "rayQuery_Proceed", + Attribute::None, + 
1, + {{0x8}}, + {{0x0}}}, // Overloads: 1 + {OC::RayQuery_Abort, + "RayQuery_Abort", + OCC::RayQuery_Abort, + "rayQuery_Abort", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::RayQuery_CommitNonOpaqueTriangleHit, + "RayQuery_CommitNonOpaqueTriangleHit", + OCC::RayQuery_CommitNonOpaqueTriangleHit, + "rayQuery_CommitNonOpaqueTriangleHit", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::RayQuery_CommitProceduralPrimitiveHit, + "RayQuery_CommitProceduralPrimitiveHit", + OCC::RayQuery_CommitProceduralPrimitiveHit, + "rayQuery_CommitProceduralPrimitiveHit", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::RayQuery_CommittedStatus, + "RayQuery_CommittedStatus", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::RayQuery_CandidateType, + "RayQuery_CandidateType", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::RayQuery_CandidateObjectToWorld3x4, + "RayQuery_CandidateObjectToWorld3x4", + OCC::RayQuery_StateMatrix, + "rayQuery_StateMatrix", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::RayQuery_CandidateWorldToObject3x4, + "RayQuery_CandidateWorldToObject3x4", + OCC::RayQuery_StateMatrix, + "rayQuery_StateMatrix", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::RayQuery_CommittedObjectToWorld3x4, + "RayQuery_CommittedObjectToWorld3x4", + OCC::RayQuery_StateMatrix, + "rayQuery_StateMatrix", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::RayQuery_CommittedWorldToObject3x4, + "RayQuery_CommittedWorldToObject3x4", + OCC::RayQuery_StateMatrix, + "rayQuery_StateMatrix", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::RayQuery_CandidateProceduralPrimitiveNonOpaque, + "RayQuery_CandidateProceduralPrimitiveNonOpaque", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + 
Attribute::ReadOnly, + 1, + {{0x8}}, + {{0x0}}}, // Overloads: 1 + {OC::RayQuery_CandidateTriangleFrontFace, + "RayQuery_CandidateTriangleFrontFace", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x8}}, + {{0x0}}}, // Overloads: 1 + {OC::RayQuery_CommittedTriangleFrontFace, + "RayQuery_CommittedTriangleFrontFace", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x8}}, + {{0x0}}}, // Overloads: 1 + {OC::RayQuery_CandidateTriangleBarycentrics, + "RayQuery_CandidateTriangleBarycentrics", + OCC::RayQuery_StateVector, + "rayQuery_StateVector", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::RayQuery_CommittedTriangleBarycentrics, + "RayQuery_CommittedTriangleBarycentrics", + OCC::RayQuery_StateVector, + "rayQuery_StateVector", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::RayQuery_RayFlags, + "RayQuery_RayFlags", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::RayQuery_WorldRayOrigin, + "RayQuery_WorldRayOrigin", + OCC::RayQuery_StateVector, + "rayQuery_StateVector", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::RayQuery_WorldRayDirection, + "RayQuery_WorldRayDirection", + OCC::RayQuery_StateVector, + "rayQuery_StateVector", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::RayQuery_RayTMin, + "RayQuery_RayTMin", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::RayQuery_CandidateTriangleRayT, + "RayQuery_CandidateTriangleRayT", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::RayQuery_CommittedRayT, + "RayQuery_CommittedRayT", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + 
{OC::RayQuery_CandidateInstanceIndex, + "RayQuery_CandidateInstanceIndex", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::RayQuery_CandidateInstanceID, + "RayQuery_CandidateInstanceID", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::RayQuery_CandidateGeometryIndex, + "RayQuery_CandidateGeometryIndex", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::RayQuery_CandidatePrimitiveIndex, + "RayQuery_CandidatePrimitiveIndex", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::RayQuery_CandidateObjectRayOrigin, + "RayQuery_CandidateObjectRayOrigin", + OCC::RayQuery_StateVector, + "rayQuery_StateVector", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::RayQuery_CandidateObjectRayDirection, + "RayQuery_CandidateObjectRayDirection", + OCC::RayQuery_StateVector, + "rayQuery_StateVector", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::RayQuery_CommittedInstanceIndex, + "RayQuery_CommittedInstanceIndex", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::RayQuery_CommittedInstanceID, + "RayQuery_CommittedInstanceID", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::RayQuery_CommittedGeometryIndex, + "RayQuery_CommittedGeometryIndex", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::RayQuery_CommittedPrimitiveIndex, + "RayQuery_CommittedPrimitiveIndex", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: 
i + {OC::RayQuery_CommittedObjectRayOrigin, + "RayQuery_CommittedObjectRayOrigin", + OCC::RayQuery_StateVector, + "rayQuery_StateVector", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + {OC::RayQuery_CommittedObjectRayDirection, + "RayQuery_CommittedObjectRayDirection", + OCC::RayQuery_StateVector, + "rayQuery_StateVector", + Attribute::ReadOnly, + 1, + {{0x2}}, + {{0x0}}}, // Overloads: f + + // Raytracing object space uint System Values, raytracing tier 1.1 + {OC::GeometryIndex, + "GeometryIndex", + OCC::GeometryIndex, + "geometryIndex", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Inline Ray Query + {OC::RayQuery_CandidateInstanceContributionToHitGroupIndex, + "RayQuery_CandidateInstanceContributionToHitGroupIndex", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::RayQuery_CommittedInstanceContributionToHitGroupIndex, + "RayQuery_CommittedInstanceContributionToHitGroupIndex", + OCC::RayQuery_StateScalar, + "rayQuery_StateScalar", + Attribute::ReadOnly, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Get handle from heap + {OC::AnnotateHandle, + "AnnotateHandle", + OCC::AnnotateHandle, + "annotateHandle", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::CreateHandleFromBinding, + "CreateHandleFromBinding", + OCC::CreateHandleFromBinding, + "createHandleFromBinding", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::CreateHandleFromHeap, + "CreateHandleFromHeap", + OCC::CreateHandleFromHeap, + "createHandleFromHeap", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + + // Unpacking intrinsics + {OC::Unpack4x8, + "Unpack4x8", + OCC::Unpack4x8, + "unpack4x8", + Attribute::ReadNone, + 1, + {{0x60}}, + {{0x0}}}, // Overloads: iw + + // Packing intrinsics + {OC::Pack4x8, + "Pack4x8", + OCC::Pack4x8, + "pack4x8", + Attribute::ReadNone, + 1, + {{0x60}}, + {{0x0}}}, // Overloads: iw + + // Helper Lanes 
+ {OC::IsHelperLane, + "IsHelperLane", + OCC::IsHelperLane, + "isHelperLane", + Attribute::ReadOnly, + 1, + {{0x8}}, + {{0x0}}}, // Overloads: 1 + + // Quad Wave Ops + {OC::QuadVote, + "QuadVote", + OCC::QuadVote, + "quadVote", + Attribute::None, + 1, + {{0x8}}, + {{0x0}}}, // Overloads: 1 + + // Resources - gather + {OC::TextureGatherRaw, + "TextureGatherRaw", + OCC::TextureGatherRaw, + "textureGatherRaw", + Attribute::ReadOnly, + 1, + {{0xe0}}, + {{0x0}}}, // Overloads: wil + + // Resources - sample + {OC::SampleCmpLevel, + "SampleCmpLevel", + OCC::SampleCmpLevel, + "sampleCmpLevel", + Attribute::ReadOnly, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + + // Resources + {OC::TextureStoreSample, + "TextureStoreSample", + OCC::TextureStoreSample, + "textureStoreSample", + Attribute::None, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + + {OC::Reserved0, + "Reserved0", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::Reserved1, + "Reserved1", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::Reserved2, + "Reserved2", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::Reserved3, + "Reserved3", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::Reserved4, + "Reserved4", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::Reserved5, + "Reserved5", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::Reserved6, + "Reserved6", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::Reserved7, + "Reserved7", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::Reserved8, + "Reserved8", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::Reserved9, + "Reserved9", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + 
{}}, // Overloads: v + {OC::Reserved10, + "Reserved10", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::Reserved11, + "Reserved11", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + + // Create/Annotate Node Handles + {OC::AllocateNodeOutputRecords, + "AllocateNodeOutputRecords", + OCC::AllocateNodeOutputRecords, + "allocateNodeOutputRecords", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + + // Get Pointer to Node Record in Address Space 6 + {OC::GetNodeRecordPtr, + "GetNodeRecordPtr", + OCC::GetNodeRecordPtr, + "getNodeRecordPtr", + Attribute::ReadNone, + 1, + {{0x100}}, + {{0x0}}}, // Overloads: u + + // Work Graph intrinsics + {OC::IncrementOutputCount, + "IncrementOutputCount", + OCC::IncrementOutputCount, + "incrementOutputCount", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::OutputComplete, + "OutputComplete", + OCC::OutputComplete, + "outputComplete", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::GetInputRecordCount, + "GetInputRecordCount", + OCC::GetInputRecordCount, + "getInputRecordCount", + Attribute::ReadOnly, + 0, + {}, + {}}, // Overloads: v + {OC::FinishedCrossGroupSharing, + "FinishedCrossGroupSharing", + OCC::FinishedCrossGroupSharing, + "finishedCrossGroupSharing", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + + // Synchronization + {OC::BarrierByMemoryType, + "BarrierByMemoryType", + OCC::BarrierByMemoryType, + "barrierByMemoryType", + Attribute::NoDuplicate, + 0, + {}, + {}}, // Overloads: v + {OC::BarrierByMemoryHandle, + "BarrierByMemoryHandle", + OCC::BarrierByMemoryHandle, + "barrierByMemoryHandle", + Attribute::NoDuplicate, + 0, + {}, + {}}, // Overloads: v + {OC::BarrierByNodeRecordHandle, + "BarrierByNodeRecordHandle", + OCC::BarrierByNodeRecordHandle, + "barrierByNodeRecordHandle", + Attribute::NoDuplicate, + 0, + {}, + {}}, // Overloads: v + + // Create/Annotate Node Handles + {OC::CreateNodeOutputHandle, + 
"CreateNodeOutputHandle", + OCC::createNodeOutputHandle, + "createNodeOutputHandle", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::IndexNodeHandle, + "IndexNodeHandle", + OCC::IndexNodeHandle, + "indexNodeHandle", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::AnnotateNodeHandle, + "AnnotateNodeHandle", + OCC::AnnotateNodeHandle, + "annotateNodeHandle", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::CreateNodeInputRecordHandle, + "CreateNodeInputRecordHandle", + OCC::CreateNodeInputRecordHandle, + "createNodeInputRecordHandle", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::AnnotateNodeRecordHandle, + "AnnotateNodeRecordHandle", + OCC::AnnotateNodeRecordHandle, + "annotateNodeRecordHandle", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + + // Work Graph intrinsics + {OC::NodeOutputIsValid, + "NodeOutputIsValid", + OCC::NodeOutputIsValid, + "nodeOutputIsValid", + Attribute::ReadOnly, + 0, + {}, + {}}, // Overloads: v + {OC::GetRemainingRecursionLevels, + "GetRemainingRecursionLevels", + OCC::GetRemainingRecursionLevels, + "getRemainingRecursionLevels", + Attribute::ReadOnly, + 0, + {}, + {}}, // Overloads: v + + // Comparison Samples + {OC::SampleCmpGrad, + "SampleCmpGrad", + OCC::SampleCmpGrad, + "sampleCmpGrad", + Attribute::ReadOnly, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + {OC::SampleCmpBias, + "SampleCmpBias", + OCC::SampleCmpBias, + "sampleCmpBias", + Attribute::ReadOnly, + 1, + {{0x3}}, + {{0x0}}}, // Overloads: hf + + // Extended Command Information + {OC::StartVertexLocation, + "StartVertexLocation", + OCC::StartVertexLocation, + "startVertexLocation", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + {OC::StartInstanceLocation, + "StartInstanceLocation", + OCC::StartInstanceLocation, + "startInstanceLocation", + Attribute::ReadNone, + 1, + {{0x40}}, + {{0x0}}}, // Overloads: i + + // Inline Ray Query + {OC::AllocateRayQuery2, + "AllocateRayQuery2", + 
OCC::AllocateRayQuery2, + "allocateRayQuery2", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + + {OC::ReservedA0, + "ReservedA0", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedA1, + "ReservedA1", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedA2, + "ReservedA2", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB0, + "ReservedB0", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB1, + "ReservedB1", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB2, + "ReservedB2", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + + // Shader Execution Reordering + {OC::HitObject_MakeMiss, + "HitObject_MakeMiss", + OCC::HitObject_MakeMiss, + "hitObject_MakeMiss", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + {OC::HitObject_MakeNop, + "HitObject_MakeNop", + OCC::HitObject_MakeNop, + "hitObject_MakeNop", + Attribute::ReadNone, + 0, + {}, + {}}, // Overloads: v + + {OC::ReservedB5, + "ReservedB5", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB6, + "ReservedB6", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB7, + "ReservedB7", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB8, + "ReservedB8", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB9, + "ReservedB9", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB10, + "ReservedB10", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB11, + "ReservedB11", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // 
Overloads: v + {OC::ReservedB12, + "ReservedB12", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB13, + "ReservedB13", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB14, + "ReservedB14", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB15, + "ReservedB15", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB16, + "ReservedB16", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB17, + "ReservedB17", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB18, + "ReservedB18", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB19, + "ReservedB19", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB20, + "ReservedB20", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB21, + "ReservedB21", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB22, + "ReservedB22", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB23, + "ReservedB23", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB24, + "ReservedB24", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB25, + "ReservedB25", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB26, + "ReservedB26", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB27, + "ReservedB27", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB28, + "ReservedB28", + 
OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB29, + "ReservedB29", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedB30, + "ReservedB30", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedC0, + "ReservedC0", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedC1, + "ReservedC1", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedC2, + "ReservedC2", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedC3, + "ReservedC3", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedC4, + "ReservedC4", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedC5, + "ReservedC5", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedC6, + "ReservedC6", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedC7, + "ReservedC7", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedC8, + "ReservedC8", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::ReservedC9, + "ReservedC9", + OCC::Reserved, + "reserved", + Attribute::None, + 0, + {}, + {}}, // Overloads: v }; // OPCODE-OLOADS:END -const char *OP::m_OverloadTypeName[kNumTypeOverloads] = { - "void", "f16", "f32", "f64", "i1", "i8", - "i16", "i32", "i64", "udt", "obj", // These should not be used -}; +const char *OP::m_OverloadTypeName[TS_BasicCount] = { + "f16", "f32", "f64", "i1", "i8", "i16", "i32", "i64"}; const char *OP::m_NamePrefix = "dx.op."; const char *OP::m_TypePrefix = "dx.types."; @@ -3040,82 +2654,110 @@ unsigned OP::GetTypeSlot(Type *pType) { Type::TypeID T = 
pType->getTypeID(); switch (T) { case Type::VoidTyID: - return 0; + return TS_Invalid; case Type::HalfTyID: - return 1; + return TS_F16; case Type::FloatTyID: - return 2; + return TS_F32; case Type::DoubleTyID: - return 3; + return TS_F64; case Type::IntegerTyID: { IntegerType *pIT = dyn_cast(pType); unsigned Bits = pIT->getBitWidth(); switch (Bits) { case 1: - return 4; + return TS_I1; case 8: - return 5; + return TS_I8; case 16: - return 6; + return TS_I16; case 32: - return 7; + return TS_I32; case 64: - return 8; + return TS_I64; } llvm_unreachable("Invalid Bits size"); + return TS_Invalid; } case Type::PointerTyID: { pType = cast(pType)->getElementType(); if (pType->isStructTy()) - return kUserDefineTypeSlot; + return TS_UDT; DXASSERT(!pType->isPointerTy(), "pointer-to-pointer type unsupported"); return GetTypeSlot(pType); } case Type::StructTyID: - return kObjectTypeSlot; + // Named struct value (not pointer) indicates a built-in object type. + // Anonymous struct value is used to wrap multi-overload dimensions. 
+ if (cast(pType)->hasName()) + return TS_Object; + else + return TS_Extended; + case Type::VectorTyID: + return TS_Vector; default: break; } - return UINT_MAX; + return TS_Invalid; } const char *OP::GetOverloadTypeName(unsigned TypeSlot) { - DXASSERT(TypeSlot < kUserDefineTypeSlot, "otherwise caller passed OOB index"); + DXASSERT(TypeSlot < TS_BasicCount, "otherwise caller passed OOB index"); return m_OverloadTypeName[TypeSlot]; } -llvm::StringRef OP::GetTypeName(Type *Ty, std::string &str) { +StringRef OP::GetTypeName(Type *Ty, SmallVectorImpl &Storage) { + DXASSERT(!Ty->isVoidTy(), "must not pass void type here"); unsigned TypeSlot = OP::GetTypeSlot(Ty); - if (TypeSlot < kUserDefineTypeSlot) { + if (TypeSlot < TS_BasicCount) { return GetOverloadTypeName(TypeSlot); - } else if (TypeSlot == kUserDefineTypeSlot) { + } else if (TypeSlot == TS_UDT) { if (Ty->isPointerTy()) Ty = Ty->getPointerElementType(); StructType *ST = cast(Ty); return ST->getStructName(); - } else if (TypeSlot == kObjectTypeSlot) { + } else if (TypeSlot == TS_Object) { StructType *ST = cast(Ty); return ST->getStructName(); + } else if (TypeSlot == TS_Vector) { + VectorType *VecTy = cast(Ty); + return (Twine("v") + Twine(VecTy->getNumElements()) + + Twine( + GetOverloadTypeName(OP::GetTypeSlot(VecTy->getElementType())))) + .toStringRef(Storage); + } else if (TypeSlot == TS_Extended) { + DXASSERT(isa(Ty), + "otherwise, extended overload type not wrapped in struct type."); + StructType *ST = cast(Ty); + DXASSERT(ST->getNumElements() <= DXIL::kDxilMaxOloadDims, + "otherwise, extended overload has too many dimensions."); + // Iterate extended slots, recurse, separate with '.' 
+ raw_svector_ostream OS(Storage); + for (unsigned I = 0; I < ST->getNumElements(); ++I) { + if (I > 0) + OS << "."; + SmallVector TempStr; + OS << GetTypeName(ST->getElementType(I), TempStr); + } + return OS.str(); } else { - raw_string_ostream os(str); - Ty->print(os); - os.flush(); - return str; + raw_svector_ostream OS(Storage); + Ty->print(OS); + return OS.str(); } } -llvm::StringRef OP::ConstructOverloadName(Type *Ty, DXIL::OpCode opCode, - std::string &funcNameStorage) { +StringRef OP::ConstructOverloadName(Type *Ty, DXIL::OpCode opCode, + SmallVectorImpl &Storage) { if (Ty == Type::getVoidTy(Ty->getContext())) { - funcNameStorage = - (Twine(OP::m_NamePrefix) + Twine(GetOpCodeClassName(opCode))).str(); + return (Twine(OP::m_NamePrefix) + Twine(GetOpCodeClassName(opCode))) + .toStringRef(Storage); } else { - funcNameStorage = - (Twine(OP::m_NamePrefix) + Twine(GetOpCodeClassName(opCode)) + "." + - GetTypeName(Ty, funcNameStorage)) - .str(); + llvm::SmallVector TempStr; + return (Twine(OP::m_NamePrefix) + Twine(GetOpCodeClassName(opCode)) + "." + + GetTypeName(Ty, TempStr)) + .toStringRef(Storage); } - return funcNameStorage; } const char *OP::GetOpCodeName(OpCode opCode) { @@ -3143,13 +2785,42 @@ llvm::Attribute::AttrKind OP::GetMemAccessAttr(OpCode opCode) { } bool OP::IsOverloadLegal(OpCode opCode, Type *pType) { - if (!pType) + if (static_cast(opCode) >= + static_cast(OpCode::NumOpCodes)) return false; - if (opCode == OpCode::NumOpCodes) + if (!pType) return false; - unsigned TypeSlot = GetTypeSlot(pType); - return TypeSlot != UINT_MAX && - m_OpCodeProps[(unsigned)opCode].bAllowOverload[TypeSlot]; + auto &OpProps = m_OpCodeProps[static_cast(opCode)]; + + if (OpProps.NumOverloadDims == 0) + return pType->isVoidTy(); + + // Normalize 1+ overload dimensions into array. + Type *Types[DXIL::kDxilMaxOloadDims] = {pType}; + if (OpProps.NumOverloadDims > 1) { + StructType *ST = dyn_cast(pType); + // Make sure multi-overload is well-formed. 
+ if (!ST || ST->hasName() || + ST->getNumContainedTypes() != OpProps.NumOverloadDims) + return false; + for (unsigned I = 0; I < ST->getNumElements(); ++I) + Types[I] = ST->getElementType(I); + } + + for (unsigned I = 0; I < OpProps.NumOverloadDims; ++I) { + Type *Ty = Types[I]; + unsigned TypeSlot = GetTypeSlot(Ty); + if (!OpProps.AllowedOverloads[I][TypeSlot]) + return false; + if (TypeSlot == TS_Vector) { + unsigned EltTypeSlot = + GetTypeSlot(cast(Ty)->getElementType()); + if (!OpProps.AllowedVectorElements[I][EltTypeSlot]) + return false; + } + } + + return true; } bool OP::CheckOpCodeTable() { @@ -3173,41 +2844,6 @@ bool OP::IsDxilOpFunc(const llvm::Function *F) { return IsDxilOpFuncName(F->getName()); } -bool OP::IsDxilOpTypeName(StringRef name) { - return name.startswith(m_TypePrefix) || name.startswith(m_MatrixTypePrefix); -} - -bool OP::IsDxilOpType(llvm::StructType *ST) { - if (!ST->hasName()) - return false; - StringRef Name = ST->getName(); - return IsDxilOpTypeName(Name); -} - -bool OP::IsDupDxilOpType(llvm::StructType *ST) { - if (!ST->hasName()) - return false; - StringRef Name = ST->getName(); - if (!IsDxilOpTypeName(Name)) - return false; - size_t DotPos = Name.rfind('.'); - if (DotPos == 0 || DotPos == StringRef::npos || Name.back() == '.' 
|| - !isdigit(static_cast(Name[DotPos + 1]))) - return false; - return true; -} - -StructType *OP::GetOriginalDxilOpType(llvm::StructType *ST, llvm::Module &M) { - DXASSERT(IsDupDxilOpType(ST), "else should not call GetOriginalDxilOpType"); - StringRef Name = ST->getName(); - size_t DotPos = Name.rfind('.'); - StructType *OriginalST = M.getTypeByName(Name.substr(0, DotPos)); - DXASSERT(OriginalST, "else name collison without original type"); - DXASSERT(ST->isLayoutIdentical(OriginalST), - "else invalid layout for dxil types"); - return OriginalST; -} - bool OP::IsDxilOpFuncCallInst(const llvm::Instruction *I) { const CallInst *CI = dyn_cast(I); if (CI == nullptr) @@ -3297,6 +2933,12 @@ bool OP::IsDxilOpBarrier(OpCode C) { // OPCODE-BARRIER:END } +bool OP::IsDxilOpExtendedOverload(OpCode C) { + if (C >= OpCode::NumOpCodes) + return false; + return m_OpCodeProps[static_cast(C)].NumOverloadDims > 1; +} + static unsigned MaskMemoryTypeFlagsIfAllowed(unsigned memoryTypeFlags, unsigned allowedMask) { // If the memory type is AllMemory, masking inapplicable flags is allowed. 
@@ -3945,13 +3587,12 @@ void OP::FixOverloadNames() { if (F.isDeclaration() && OP::IsDxilOpFunc(&F) && !F.user_empty()) { CallInst *CI = cast(*F.user_begin()); DXIL::OpCode opCode = OP::GetDxilOpFuncCallInst(CI); + if (!MayHaveNonCanonicalOverload(opCode)) + continue; llvm::Type *Ty = OP::GetOverloadType(opCode, &F); if (!OP::IsOverloadLegal(opCode, Ty)) continue; - if (!isa(Ty) && !isa(Ty)) - continue; - - std::string funcName; + SmallVector funcName; if (OP::ConstructOverloadName(Ty, opCode, funcName) .compare(F.getName()) != 0) F.setName(funcName); @@ -3964,11 +3605,54 @@ void OP::UpdateCache(OpCodeClass opClass, Type *Ty, llvm::Function *F) { m_FunctionToOpClass[F] = opClass; } +bool OP::MayHaveNonCanonicalOverload(OpCode OC) { + if (OC >= OpCode::NumOpCodes) + return false; + const unsigned CheckMask = (1 << TS_UDT) | (1 << TS_Object); + auto &OpProps = m_OpCodeProps[static_cast(OC)]; + for (unsigned I = 0; I < OpProps.NumOverloadDims; ++I) + if ((CheckMask & OpProps.AllowedOverloads[I].SlotMask) != 0) + return true; + return false; +} + +Function *OP::GetOpFunc(OpCode OC, ArrayRef OverloadTypes) { + if (OC >= OpCode::NumOpCodes) + return nullptr; + if (OverloadTypes.size() != + m_OpCodeProps[static_cast(OC)].NumOverloadDims) { + llvm_unreachable("incorrect overload dimensions"); + return nullptr; + } + if (OverloadTypes.size() == 0) { + return GetOpFunc(OC, Type::getVoidTy(m_Ctx)); + } else if (OverloadTypes.size() == 1) { + return GetOpFunc(OC, OverloadTypes[0]); + } + return GetOpFunc(OC, GetExtendedOverloadType(OverloadTypes)); +} + Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { - if (opCode == OpCode::NumOpCodes) + if (opCode >= OpCode::NumOpCodes) return nullptr; if (!pOverloadType) return nullptr; + + auto &OpProps = m_OpCodeProps[static_cast(opCode)]; + if (IsDxilOpExtendedOverload(opCode)) { + // Make sure pOverloadType is well formed for an extended overload. 
+ StructType *ST = dyn_cast(pOverloadType); + DXASSERT(ST != nullptr, + "otherwise, extended overload type is not a struct"); + if (ST == nullptr) + return nullptr; + bool EltCountValid = ST->getNumElements() == OpProps.NumOverloadDims; + DXASSERT(EltCountValid, + "otherwise, incorrect type count for extended overload."); + if (!EltCountValid) + return nullptr; + } + // Illegal overloads are generated and eliminated by DXIL op constant // evaluation for a number of cases where a double overload of an HL intrinsic // that otherwise does not support double is used for literal values, when @@ -3976,7 +3660,7 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { // Illegal overloads of DXIL intrinsics may survive through to final DXIL, // but these will be caught by the validator, and this is not a regression. - OpCodeClass opClass = m_OpCodeProps[(unsigned)opCode].opCodeClass; + OpCodeClass opClass = OpProps.opCodeClass; Function *&F = m_OpCodeClassCache[(unsigned)opClass].pOverloads[pOverloadType]; if (F != nullptr) { @@ -3984,7 +3668,7 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { return F; } - vector ArgTypes; // RetType is ArgTypes[0] + SmallVector ArgTypes; // RetType is ArgTypes[0] Type *pETy = pOverloadType; Type *pRes = GetHandleType(); Type *pNodeHandle = GetNodeHandleType(); @@ -4020,7 +3704,10 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { #define A(_x) ArgTypes.emplace_back(_x) #define RRT(_y) A(GetResRetType(_y)) #define CBRT(_y) A(GetCBufferRetType(_y)) -#define VEC4(_y) A(GetVectorType(4, _y)) +#define VEC4(_y) A(GetStructVectorType(4, _y)) + +// Extended Overload types are wrapped in an anonymous struct +#define EXT(_y) A(cast(pOverloadType)->getElementType(_y)) /* hctdb_instrhelp.get_oloads_funcs()*/ switch (opCode) { // return opCode @@ -6066,14 +5753,15 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { pFT = FunctionType::get( ArgTypes[0], ArrayRef(&ArgTypes[1], ArgTypes.size() - 1), 
false); - std::string funcName; - ConstructOverloadName(pOverloadType, opCode, funcName); + SmallVector FuncStorage; + StringRef FuncName = + ConstructOverloadName(pOverloadType, opCode, FuncStorage); // Try to find existing function with the same name in the module. // This needs to happen after the switch statement that constructs arguments // and return values to ensure that ResRetType is constructed in the // RefreshCache case. - if (Function *existF = m_pModule->getFunction(funcName)) { + if (Function *existF = m_pModule->getFunction(FuncName)) { if (existF->getFunctionType() != pFT) return nullptr; F = existF; @@ -6081,13 +5769,13 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { return F; } - F = cast(m_pModule->getOrInsertFunction(funcName, pFT)); + F = cast(m_pModule->getOrInsertFunction(FuncName, pFT)); UpdateCache(opClass, pOverloadType, F); F->setCallingConv(CallingConv::C); F->addFnAttr(Attribute::NoUnwind); - if (m_OpCodeProps[(unsigned)opCode].FuncAttr != Attribute::None) - F->addFnAttr(m_OpCodeProps[(unsigned)opCode].FuncAttr); + if (OpProps.FuncAttr != Attribute::None) + F->addFnAttr(OpProps.FuncAttr); return F; } @@ -6494,62 +6182,90 @@ Type *OP::GetFourI32Type() const { return m_pFourI32Type; } Type *OP::GetFourI16Type() const { return m_pFourI16Type; } bool OP::IsResRetType(llvm::Type *Ty) { + if (!Ty->isStructTy()) + return false; for (Type *ResTy : m_pResRetType) { if (Ty == ResTy) return true; } + StructType *ST = cast(Ty); + if (!ST->hasName() || ST->getNumContainedTypes() < 2) + return false; + return Ty == GetResRetType(ST->getContainedType(0)); return false; } Type *OP::GetResRetType(Type *pOverloadType) { unsigned TypeSlot = GetTypeSlot(pOverloadType); - if (m_pResRetType[TypeSlot] == nullptr) { - string TypeName("dx.types.ResRet."); - TypeName += GetOverloadTypeName(TypeSlot); - Type *FieldTypes[5] = {pOverloadType, pOverloadType, pOverloadType, - pOverloadType, Type::getInt32Ty(m_Ctx)}; - m_pResRetType[TypeSlot] = - 
GetOrCreateStructType(m_Ctx, FieldTypes, TypeName, m_pModule); + if (TypeSlot < TS_BasicCount) { + if (m_pResRetType[TypeSlot] == nullptr) { + SmallVector Storage; + StringRef TypeName = + (Twine("dx.types.ResRet.") + Twine(GetOverloadTypeName(TypeSlot))) + .toStringRef(Storage); + Type *FieldTypes[5] = {pOverloadType, pOverloadType, pOverloadType, + pOverloadType, Type::getInt32Ty(m_Ctx)}; + m_pResRetType[TypeSlot] = + GetOrCreateStructType(m_Ctx, FieldTypes, TypeName, m_pModule); + } + return m_pResRetType[TypeSlot]; + } else if (TypeSlot == TS_Vector) { + SmallVector Storage; + VectorType *VecTy = cast(pOverloadType); + StringRef TypeName = + (Twine("dx.types.ResRet.v") + Twine(VecTy->getNumElements()) + + Twine(GetOverloadTypeName(OP::GetTypeSlot(VecTy->getElementType())))) + .toStringRef(Storage); + Type *FieldTypes[2] = {pOverloadType, Type::getInt32Ty(m_Ctx)}; + return GetOrCreateStructType(m_Ctx, FieldTypes, TypeName, m_pModule); } - return m_pResRetType[TypeSlot]; + llvm_unreachable("Invalid overload for GetResRetType"); + return nullptr; } Type *OP::GetCBufferRetType(Type *pOverloadType) { unsigned TypeSlot = GetTypeSlot(pOverloadType); + if (TypeSlot >= TS_BasicCount) { + llvm_unreachable("Invalid overload for GetResRetType"); + return nullptr; + } + if (m_pCBufferRetType[TypeSlot] == nullptr) { DXASSERT(m_LowPrecisionMode != DXIL::LowPrecisionMode::Undefined, "m_LowPrecisionMode must be set before constructing type."); - string TypeName("dx.types.CBufRet."); - TypeName += GetOverloadTypeName(TypeSlot); + SmallVector Storage; + raw_svector_ostream OS(Storage); + OS << "dx.types.CBufRet."; + OS << GetOverloadTypeName(TypeSlot); Type *i64Ty = Type::getInt64Ty(pOverloadType->getContext()); Type *i16Ty = Type::getInt16Ty(pOverloadType->getContext()); if (pOverloadType->isDoubleTy() || pOverloadType == i64Ty) { Type *FieldTypes[2] = {pOverloadType, pOverloadType}; m_pCBufferRetType[TypeSlot] = - GetOrCreateStructType(m_Ctx, FieldTypes, TypeName, m_pModule); + 
GetOrCreateStructType(m_Ctx, FieldTypes, OS.str(), m_pModule); } else if (!UseMinPrecision() && (pOverloadType->isHalfTy() || pOverloadType == i16Ty)) { - TypeName += ".8"; // dx.types.CBufRet.fp16.8 for buffer of 8 halves + OS << ".8"; // dx.types.CBufRet.f16.8 for buffer of 8 halves Type *FieldTypes[8] = { pOverloadType, pOverloadType, pOverloadType, pOverloadType, pOverloadType, pOverloadType, pOverloadType, pOverloadType, }; m_pCBufferRetType[TypeSlot] = - GetOrCreateStructType(m_Ctx, FieldTypes, TypeName, m_pModule); + GetOrCreateStructType(m_Ctx, FieldTypes, OS.str(), m_pModule); } else { Type *FieldTypes[4] = {pOverloadType, pOverloadType, pOverloadType, pOverloadType}; m_pCBufferRetType[TypeSlot] = - GetOrCreateStructType(m_Ctx, FieldTypes, TypeName, m_pModule); + GetOrCreateStructType(m_Ctx, FieldTypes, OS.str(), m_pModule); } } return m_pCBufferRetType[TypeSlot]; } -Type *OP::GetVectorType(unsigned numElements, Type *pOverloadType) { +Type *OP::GetStructVectorType(unsigned numElements, Type *pOverloadType) { if (numElements == 4) { if (pOverloadType == Type::getInt32Ty(pOverloadType->getContext())) { return m_pFourI32Type; @@ -6561,6 +6277,10 @@ Type *OP::GetVectorType(unsigned numElements, Type *pOverloadType) { return nullptr; } +StructType *OP::GetExtendedOverloadType(ArrayRef OverloadTypes) { + return StructType::get(m_Ctx, OverloadTypes); +} + //------------------------------------------------------------------------------ // // LLVM utility methods. 
diff --git a/lib/DxilValidation/DxilValidation.cpp b/lib/DxilValidation/DxilValidation.cpp index 4622256dfe..cac074adc3 100644 --- a/lib/DxilValidation/DxilValidation.cpp +++ b/lib/DxilValidation/DxilValidation.cpp @@ -2037,7 +2037,7 @@ static void ValidateExternalFunction(Function *F, ValidationContext &ValCtx) { ValCtx.EmitInstrError(CI, ValidationRule::InstrOload); continue; } - dxilFunc = hlslOP->GetOpFunc(dxilOpcode, Ty->getScalarType()); + dxilFunc = hlslOP->GetOpFunc(dxilOpcode, Ty); } if (!dxilFunc) { @@ -2109,17 +2109,20 @@ static bool IsDxilBuiltinStructType(StructType *ST, hlsl::OP *hlslOP) { return true; unsigned EltNum = ST->getNumElements(); + Type *EltTy = ST->getElementType(0); switch (EltNum) { case 2: + // Check if it's a native vector resret. + if (EltTy->isVectorTy()) + return ST == hlslOP->GetResRetType(EltTy); + LLVM_FALLTHROUGH; case 4: - case 8: { // 2 for doubles, 8 for halfs. - Type *EltTy = ST->getElementType(0); + case 8: // 2 for doubles, 8 for halfs. return ST == hlslOP->GetCBufferRetType(EltTy); - } break; - case 5: { - Type *EltTy = ST->getElementType(0); + break; + case 5: return ST == hlslOP->GetResRetType(EltTy); - } break; + break; default: return false; } diff --git a/utils/hct/hctdb.py b/utils/hct/hctdb.py index e32ab1915a..57faee2fb2 100644 --- a/utils/hct/hctdb.py +++ b/utils/hct/hctdb.py @@ -37,6 +37,30 @@ "array_local_ldst", ] +# These are the valid overload type characters for DXIL instructions. +# - "v" is for void, and can only be used alone. +# - "u" is for user defined type (UDT), and is mutually exclusive with the other +# types. +# - "o" is for an HLSL object type (e.g. Texture, Sampler, etc.), and is +# mutually exclusive with the other types. +# - "<" is for vector overloads, and may be followed by a set of supported +# component types. +# - If "<" is not followed by any component types, any preceding scalar types +# are used. +# - Vector component types are captured into a separate list during +# processing. 
+# - "," is used to separate multiple overload dimensions. +# - When used, only $x0, $x1, etc. are supported for overloaded parameter +# types. +# dxil_all_user_oload_chars must be kept in sync with the indices in +# hlsl::OP::TypeSlot in DxilOperations.h. +dxil_all_user_oload_chars = "hfd18wiluo<" +dxil_scalar_oload_chars = "hfd18wil" + +# Maximum number of overload dimensions supported through the extended overload +# in DXIL instructions. +dxil_max_overload_dims = 2 + class db_dxil_enum_value(object): "A representation for a value in an enumeration type" @@ -81,6 +105,7 @@ def __init__(self, name, **kwargs): self.ops = [] # the operands that this instruction takes self.is_allowed = True # whether this instruction is allowed in a DXIL program self.oload_types = "" # overload types if applicable + # Always call process_oload_types() after setting oload_types. self.fn_attr = "" # attribute shorthands: rn=does not access memory,ro=only reads from memory, self.is_deriv = False # whether this is some kind of derivative self.is_gradient = False # whether this requires a gradient calculation @@ -98,6 +123,9 @@ def __init__(self, name, **kwargs): self.is_reserved = self.dxil_class == "Reserved" self.shader_model_translated = () # minimum shader model required with translation by linker self.props = {} # extra properties + self.num_oloads = 0 # number of overloads for this instruction + if self.is_dxil_op: + self.process_oload_types() def __str__(self): return self.name @@ -105,6 +133,127 @@ def __str__(self): def fully_qualified_name(self): return "{}::{}".format(self.fully_qualified_name_prefix, self.name) + def process_oload_types(self): + if type(self.oload_types) is not str: + raise ValueError( + f"overload for '{self.name}' should be a string - use empty if n/a" + ) + # Early out for LLVM instructions + if not self.is_dxil_op: + return + + self.num_oloads = 0 + + # Early out for void overloads. 
+ if self.oload_types == "v": + return + + if self.oload_types == "": + raise ValueError( + f"overload for '{self.name}' should not be empty - use void if n/a" + ) + if "v" in self.oload_types: + raise ValueError( + f"void overload should be exclusive to other types for '({self.name})'" + ) + + # Process oload_types for extended and vector overloads. + # Contrived example: "hf<, dxil_max_overload_dims: + raise ValueError( + "Too many overload dimensions for DXIL op " + f"{self.name}: '{self.oload_types}'" + ) + + def check_duplicate_overloads(oloads): + if len(oloads) != len(set(oloads)): + raise ValueError( + "Duplicate overload types specified for DXIL op " + f"{self.name}: '{oloads}' in '{self.oload_types}'" + ) + + def check_overload_chars(oloads, valid_chars): + invalid_chars = set(oloads).difference(set(valid_chars)) + if invalid_chars: + raise ValueError( + "Invalid overload type character(s) used for DXIL op " + f"{self.name}: '{invalid_chars}' in '{oloads}' from " + f"'{self.oload_types}'" + ) + + for n, oloads in enumerate(oload_types): + if len(oloads) == 0: + raise ValueError( + f"Invalid empty overload type for DXIL op " + f"{self.name}: '{self.oload_types}'" + ) + check_overload_chars(oloads, dxil_all_user_oload_chars) + + # split at vector for component overloads, if vector specified + # without following components, use the scalar overloads that + # precede the vector character. + split = oloads.split("<") + if len(split) == 1: + # No vector overload. + continue + elif len(split) != 2: + raise ValueError( + f"Invalid vector overload for DXIL op {self.name}: " + f"{oloads} in '{self.oload_types}'" + ) + + # Split into scalar and vector component overloads. 
+ scalars, vector_oloads = split + check_duplicate_overloads(scalars) + if not vector_oloads: + vector_oloads = scalars + else: + check_duplicate_overloads(vector_oloads) + if not vector_oloads: + raise ValueError( + "No scalar overload types provided with vector overload " + f"for DXIL op {self.name}: '{self.oload_types}'" + ) + check_overload_chars(vector_oloads, dxil_scalar_oload_chars) + oload_types[n] = scalars + "<" + vector_oloads + # Reconstruct overload string with default vector overloads. + self.oload_types = ",".join(oload_types) + self.check_extended_oload_ops() + + def check_extended_oload_ops(self): + "Ensure ops has sequential extended overload references with $x0, $x1, etc." + if self.num_oloads < 2: + return + next_oload_idx = 0 + for i in self.ops: + if i.llvm_type.startswith("$x"): + if i.llvm_type != "$x" + str(next_oload_idx): + raise ValueError( + "Extended overloads are not sequentially referenced in " + f"DXIL op {self.name}: {i.llvm_type} != $x{next_oload_idx}" + ) + next_oload_idx += 1 + if next_oload_idx != self.num_oloads: + raise ValueError( + "Extended overloads are not referenced for all overload " + f"dimensions in DXIL op {self.name}: {next_oload_idx} != " + f"{self.num_oloads}" + ) + class db_dxil_metadata(object): "A representation for a metadata record" @@ -477,9 +626,7 @@ def populate_categories_and_models(self): "closesthit", ) for i in "GeometryIndex".split(","): - self.name_idx[ - i - ].category = ( + self.name_idx[i].category = ( "Raytracing object space uint System Values, raytracing tier 1.1" ) self.name_idx[i].shader_model = 6, 5 @@ -574,9 +721,7 @@ def populate_categories_and_models(self): self.name_idx[i].shader_model = 6, 3 self.name_idx[i].shader_stages = ("library", "intersection") for i in "CreateHandleForLib".split(","): - self.name_idx[ - i - ].category = ( + self.name_idx[i].category = ( "Library create handle from resource struct (like HL intrinsic)" ) self.name_idx[i].shader_model = 6, 3 @@ -5652,18 +5797,6 @@ 
def UFI(name, **mappings): ) for i in self.instr: self.verify_dense(i.ops, lambda x: x.pos, lambda x: i.name) - for i in self.instr: - if i.is_dxil_op: - assert i.oload_types != "", ( - "overload for DXIL operation %s should not be empty - use void if n/a" - % (i.name) - ) - assert i.oload_types == "v" or i.oload_types.find("v") < 0, ( - "void overload should be exclusive to other types (%s)" % i.name - ) - assert ( - type(i.oload_types) is str - ), "overload for %s should be a string - use empty if n/a" % (i.name) # Verify that all operations in each class have the same signature. import itertools @@ -8391,6 +8524,7 @@ def __init__( self.template_id_idx = template_id_idx # Template ID numeric value self.component_id_idx = component_id_idx # Component ID numeric value + class db_hlsl(object): "A database of HLSL language data" diff --git a/utils/hct/hctdb_instrhelp.py b/utils/hct/hctdb_instrhelp.py index 4580e6c12c..aeb32d027e 100644 --- a/utils/hct/hctdb_instrhelp.py +++ b/utils/hct/hctdb_instrhelp.py @@ -40,8 +40,10 @@ def get_hlsl_opcode_data(): g_hlsl_opcode_data = {} return g_hlsl_opcode_data + g_db_hlsl = None + def get_db_hlsl(): global g_db_hlsl if g_db_hlsl is None: @@ -51,6 +53,11 @@ def get_db_hlsl(): return g_db_hlsl +def get_max_oload_dims(): + db = get_db_dxil() + return f"const unsigned kDxilMaxOloadDims = {dxil_max_overload_dims};" + + def format_comment(prefix, val): "Formats a value with a line-comment prefix." 
result = "" @@ -507,26 +514,15 @@ def print_opfunc_props(self): OP=self.OP ) ) - print( - "// OpCode OpCode name, OpCodeClass OpCodeClass name, void, h, f, d, i1, i8, i16, i32, i64, udt, obj, function attribute" - ) - # Example formatted string: - # { OC::TempRegLoad, "TempRegLoad", OCC::TempRegLoad, "tempRegLoad", false, true, true, false, true, false, true, true, false, Attribute::ReadOnly, }, - # 012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789 - # 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 last_category = None - # overload types are a string of (v)oid, (h)alf, (f)loat, (d)ouble, (1)-bit, (8)-bit, (w)ord, (i)nt, (l)ong, u(dt) - f = lambda i, c: "true" if i.oload_types.find(c) >= 0 else "false" lower_exceptions = { "CBufferLoad": "cbufferLoad", "CBufferLoadLegacy": "cbufferLoadLegacy", "GSInstanceID": "gsInstanceID", } - lower_fn = ( - lambda t: lower_exceptions[t] - if t in lower_exceptions - else t[:1].lower() + t[1:] + lower_fn = lambda t: ( + lower_exceptions[t] if t in lower_exceptions else t[:1].lower() + t[1:] ) attr_dict = { "": "None", @@ -537,35 +533,43 @@ def print_opfunc_props(self): "nr": "NoReturn", "wv": "None", } - attr_fn = lambda i: "Attribute::" + attr_dict[i.fn_attr] + "," + attr_fn = lambda i: "Attribute::" + attr_dict[i.fn_attr] + oload_to_mask = lambda oload: sum( + [1 << dxil_all_user_oload_chars.find(c) for c in oload] + ) + oloads_fn = lambda oloads: ( + "{" + ",".join(["{0x%x}" % m for m in oloads]) + "}" + ) for i in self.instrs: if last_category != i.category: if last_category != None: print("") - print( - " // {category:118} void, h, f, d, i1, i8, i16, i32, i64, udt, obj , function attribute".format( - category=i.category - ) - ) + if not i.is_reserved: + print(f" // {i.category}") last_category = i.category + scalar_masks = [] + vector_masks = [] + if i.num_oloads > 0: + for 
n, o in enumerate(i.oload_types.split(",")): + v = o.split("<") + scalar_masks.append(oload_to_mask(v[0] + "<" if len(v) > 1 else v[0])) + vector_masks.append(oload_to_mask(v[1]) if len(v) > 1 else 0) print( - " {{ {OC}::{name:24} {quotName:27} {OCC}::{className:25} {classNameQuot:28} {{{v:>6},{h:>6},{f:>6},{d:>6},{b:>6},{e:>6},{w:>6},{i:>6},{l:>6},{u:>6},{o:>6}}}, {attr:20} }},".format( + ( + " {{ {OC}::{name:24} {quotName:27} {OCC}::{className:25} " + + "{classNameQuot:28} {attr:20}, {num_oloads}, " + + "{scalar_masks:16}, {vector_masks:16} }}, " + + "// Overloads: {oloads}" + ).format( name=i.name + ",", quotName='"' + i.name + '",', className=i.dxil_class + ",", classNameQuot='"' + lower_fn(i.dxil_class) + '",', - v=f(i, "v"), - h=f(i, "h"), - f=f(i, "f"), - d=f(i, "d"), - b=f(i, "1"), - e=f(i, "8"), - w=f(i, "w"), - i=f(i, "i"), - l=f(i, "l"), - u=f(i, "u"), - o=f(i, "o"), attr=attr_fn(i), + num_oloads=i.num_oloads, + scalar_masks=oloads_fn(scalar_masks), + vector_masks=oloads_fn(vector_masks), + oloads=i.oload_types, OC=self.OC, OCC=self.OCC, ) @@ -621,6 +625,9 @@ def print_opfunc_table(self): "nodeproperty": "A(nodeProperty);", "noderecordproperty": "A(nodeRecordProperty);", "hit_object": "A(pHit);", + # Extended overload slots, extend as needed: + "$x0": "EXT(0);", + "$x1": "EXT(1);", } last_category = None for i in self.instrs: @@ -651,14 +658,24 @@ def print_opfunc_oload_type(self): obj_ty = "obj" vec_ty = "$vec" gsptr_ty = "$gsptr" + extended_ty = "$x" last_category = None index_dict = collections.OrderedDict() ptr_index_dict = collections.OrderedDict() single_dict = collections.OrderedDict() + # extended_dict collects overloads with multiple overload types + # grouped by the set of overload parameter indices. + extended_dict = collections.OrderedDict() struct_list = [] + extended_list = [] for instr in self.instrs: + if instr.num_oloads > 1: + # Process extended overloads separately. 
+ extended_list.append(instr) + continue + ret_ty = instr.ops[0].llvm_type # Skip case return type is overload type if ret_ty == elt_ty: @@ -730,8 +747,7 @@ def print_opfunc_oload_type(self): "i": "IntegerType::get(Ctx, 32)", "l": "IntegerType::get(Ctx, 64)", "v": "Type::getVoidTy(Ctx)", - "u": "Type::getInt32PtrTy(Ctx)", - "o": "Type::getInt32PtrTy(Ctx)", + # No other types should be referenced here. } assert ty in type_code_texts, "llvm type %s is unknown" % (ty) ty_code = type_code_texts[ty] @@ -791,6 +807,61 @@ def print_opfunc_oload_type(self): line = line + "}" print(line) + for instr in extended_list: + # Collect indices for overloaded return and types, make a tuple of + # indices the key, and add the opcode to a list of opcodes for that + # key. Indices start with 0 for return type, and 1 for the first + # function parameter, which is the DXIL OpCode. + indices = [] + for index, op in enumerate(instr.ops): + # Skip dxil opcode. + if op.pos == 1: + continue + + op_type = op.llvm_type + if op_type.startswith(extended_ty): + try: + extended_index = int(op_type[2:]) + except: + raise ValueError( + "Error parsing extended operand type " + + f"'{op_type}' for DXIL op '{instr.name}'" + ) + if extended_index != len(indices): + raise ValueError( + f"'$x{extended_index}' is not in sequential " + + f"order for DXIL op '{instr.name}'" + ) + indices.append(op.pos) + + if len(indices) != instr.num_oloads: + raise ValueError( + f"DXIL op {instr.name}: extended overload count " + + "mismatches the number of overload types" + ) + extended_dict.setdefault(tuple(indices), []).append(instr.name) + + def get_type_at_index(index): + if index == 0: + return "FT->getReturnType()" + return f"FT->getParamType({index - 1})" + + for index_tuple, opcodes in extended_dict.items(): + line = "" + for opcode in opcodes: + line = line + f"case OpCode::{opcode}:\n" + if index_tuple[-1] > 0: + line += ( + f" if (FT->getNumParams() < {index_tuple[-1]})\n" + + " return nullptr;\n" + ) + line += 
( + " return llvm::StructType::get(Ctx, {" + + ", ".join([get_type_at_index(index) for index in index_tuple]) + + "});\n" + ) + print(line) + class db_valfns_gen: "A generator of validation functions." @@ -1599,6 +1670,7 @@ def get_highest_released_shader_model(): ) return result + def get_highest_shader_model(): result = """static const unsigned kHighestMajor = %d; static const unsigned kHighestMinor = %d;""" % ( @@ -1607,6 +1679,7 @@ def get_highest_shader_model(): ) return result + def get_dxil_version_minor(): return "const unsigned kDxilMinor = %d;" % highest_minor From 9002f713e2d74a437918dfe573fdadb7e3520edf Mon Sep 17 00:00:00 2001 From: Greg Roth Date: Mon, 24 Mar 2025 15:41:13 -0600 Subject: [PATCH 02/31] Allow lowering of vector load stores Enable native vector DXIL intrinsic overload for vector load/store Add a new native vector overload type to DXIL intrinsics and the corresponding generation. Add new raw buffer vector load/store intrinsics that use that overload type. Generate native vector raw buffers load/stores When the loaded/stored type is a vector of more than 1 element, the shader model is 6.9 or higher, and the operation is on a raw buffer, enable the generation of a native vector raw buffer load or store. 
Add validation of vector load stores --- include/dxc/DXIL/DxilConstants.h | 13 +- include/dxc/DXIL/DxilInstructions.h | 99 ++- include/dxc/HLSL/DxilGenerationPass.h | 2 + include/dxc/Test/WEXAdapter.h | 4 +- lib/DXIL/DxilOperations.cpp | 42 +- lib/DxilValidation/DxilValidation.cpp | 13 +- lib/HLSL/CMakeLists.txt | 1 + lib/HLSL/DxilLinker.cpp | 8 + lib/HLSL/DxilScalarizeVectorLoadStores.cpp | 241 ++++++ lib/HLSL/HLOperationLower.cpp | 68 +- tools/clang/lib/Sema/SemaHLSL.cpp | 27 +- .../intrinsics/buffer-load-stores-sm69.hlsl | 90 +++ .../hlsl/types/longvec-operators-cs.hlsl | 708 ++++++++++++++++++ .../hlsl/types/longvec-operators.hlsl | 18 - .../DXILValidation/vector-validation.hlsl | 19 + .../linker/resources/preserve_sb_types.hlsl | 4 +- .../LitDXILValidation/vector-validation.ll | 78 ++ utils/hct/hctdb.py | 87 ++- 18 files changed, 1456 insertions(+), 66 deletions(-) create mode 100644 lib/HLSL/DxilScalarizeVectorLoadStores.cpp create mode 100644 tools/clang/test/CodeGenDXIL/hlsl/intrinsics/buffer-load-stores-sm69.hlsl create mode 100644 tools/clang/test/CodeGenDXIL/hlsl/types/longvec-operators-cs.hlsl create mode 100644 tools/clang/test/DXILValidation/vector-validation.hlsl create mode 100644 tools/clang/test/LitDXILValidation/vector-validation.ll diff --git a/include/dxc/DXIL/DxilConstants.h b/include/dxc/DXIL/DxilConstants.h index 447728300b..4bf98e3771 100644 --- a/include/dxc/DXIL/DxilConstants.h +++ b/include/dxc/DXIL/DxilConstants.h @@ -487,6 +487,9 @@ inline bool IsFeedbackTexture(DXIL::ResourceKind ResourceKind) { // Enumeration for operations specified by DXIL enum class OpCode : unsigned { // + RawBufferVectorLoad = 303, // reads from a raw buffer and structured buffer + RawBufferVectorStore = + 304, // writes to a RWByteAddressBuffer or RWStructuredBuffer Reserved0 = 226, // Reserved Reserved1 = 227, // Reserved Reserved10 = 236, // Reserved @@ -1043,8 +1046,9 @@ enum class OpCode : unsigned { NumOpCodes_Dxil_1_6 = 222, NumOpCodes_Dxil_1_7 = 226, 
NumOpCodes_Dxil_1_8 = 258, + NumOpCodes_Dxil_1_9 = 305, - NumOpCodes = 303 // exclusive last value of enumeration + NumOpCodes = 305 // exclusive last value of enumeration }; // OPCODE-ENUM:END @@ -1056,6 +1060,8 @@ enum class OpCode : unsigned { // Groups for DXIL operations with equivalent function templates enum class OpCodeClass : unsigned { // + RawBufferVectorLoad, + RawBufferVectorStore, Reserved, // Amplification shader instructions @@ -1355,8 +1361,9 @@ enum class OpCodeClass : unsigned { NumOpClasses_Dxil_1_6 = 149, NumOpClasses_Dxil_1_7 = 153, NumOpClasses_Dxil_1_8 = 174, + NumOpClasses_Dxil_1_9 = 179, - NumOpClasses = 177 // exclusive last value of enumeration + NumOpClasses = 179 // exclusive last value of enumeration }; // OPCODECLASS-ENUM:END @@ -1424,7 +1431,7 @@ const unsigned kRawBufferStoreVal1OpIdx = 5; const unsigned kRawBufferStoreVal2OpIdx = 6; const unsigned kRawBufferStoreVal3OpIdx = 7; const unsigned kRawBufferStoreMaskOpIdx = 8; -const unsigned kRawBufferStoreAlignmentOpIdx = 8; +const unsigned kRawBufferStoreAlignmentOpIdx = 9; // TextureStore. 
const unsigned kTextureStoreHandleOpIdx = 1; diff --git a/include/dxc/DXIL/DxilInstructions.h b/include/dxc/DXIL/DxilInstructions.h index f8d9ae77f3..5d78336d2a 100644 --- a/include/dxc/DXIL/DxilInstructions.h +++ b/include/dxc/DXIL/DxilInstructions.h @@ -5079,15 +5079,15 @@ struct DxilInst_RawBufferLoad { bool requiresUniformInputs() const { return false; } // Operand indexes enum OperandIdx { - arg_srv = 1, + arg_buf = 1, arg_index = 2, arg_elementOffset = 3, arg_mask = 4, arg_alignment = 5, }; // Accessors - llvm::Value *get_srv() const { return Instr->getOperand(1); } - void set_srv(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_buf() const { return Instr->getOperand(1); } + void set_buf(llvm::Value *val) { Instr->setOperand(1, val); } llvm::Value *get_index() const { return Instr->getOperand(2); } void set_index(llvm::Value *val) { Instr->setOperand(2, val); } llvm::Value *get_elementOffset() const { return Instr->getOperand(3); } @@ -8923,5 +8923,98 @@ struct DxilInst_HitObject_MakeNop { // Metadata bool requiresUniformInputs() const { return false; } }; + +/// This instruction reads from a raw buffer and structured buffer +struct DxilInst_RawBufferVectorLoad { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_RawBufferVectorLoad(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst( + Instr, hlsl::OP::OpCode::RawBufferVectorLoad); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (5 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_buf = 1, + arg_index = 2, + arg_elementOffset = 3, + arg_alignment = 4, + }; + // Accessors + llvm::Value *get_buf() const { return Instr->getOperand(1); } + void set_buf(llvm::Value *val) { Instr->setOperand(1, val); } + 
llvm::Value *get_index() const { return Instr->getOperand(2); } + void set_index(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_elementOffset() const { return Instr->getOperand(3); } + void set_elementOffset(llvm::Value *val) { Instr->setOperand(3, val); } + llvm::Value *get_alignment() const { return Instr->getOperand(4); } + void set_alignment(llvm::Value *val) { Instr->setOperand(4, val); } + int32_t get_alignment_val() const { + return (int32_t)(llvm::dyn_cast(Instr->getOperand(4)) + ->getZExtValue()); + } + void set_alignment_val(int32_t val) { + Instr->setOperand(4, llvm::Constant::getIntegerValue( + llvm::IntegerType::get(Instr->getContext(), 32), + llvm::APInt(32, (uint64_t)val))); + } +}; + +/// This instruction writes to a RWByteAddressBuffer or RWStructuredBuffer +struct DxilInst_RawBufferVectorStore { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_RawBufferVectorStore(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst( + Instr, hlsl::OP::OpCode::RawBufferVectorStore); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (6 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_uav = 1, + arg_index = 2, + arg_elementOffset = 3, + arg_value0 = 4, + arg_alignment = 5, + }; + // Accessors + llvm::Value *get_uav() const { return Instr->getOperand(1); } + void set_uav(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_index() const { return Instr->getOperand(2); } + void set_index(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_elementOffset() const { return Instr->getOperand(3); } + void set_elementOffset(llvm::Value *val) { Instr->setOperand(3, val); } + llvm::Value *get_value0() const { return 
Instr->getOperand(4); } + void set_value0(llvm::Value *val) { Instr->setOperand(4, val); } + llvm::Value *get_alignment() const { return Instr->getOperand(5); } + void set_alignment(llvm::Value *val) { Instr->setOperand(5, val); } + int32_t get_alignment_val() const { + return (int32_t)(llvm::dyn_cast(Instr->getOperand(5)) + ->getZExtValue()); + } + void set_alignment_val(int32_t val) { + Instr->setOperand(5, llvm::Constant::getIntegerValue( + llvm::IntegerType::get(Instr->getContext(), 32), + llvm::APInt(32, (uint64_t)val))); + } +}; // INSTR-HELPER:END } // namespace hlsl diff --git a/include/dxc/HLSL/DxilGenerationPass.h b/include/dxc/HLSL/DxilGenerationPass.h index c77ddab3d0..9df93e9232 100644 --- a/include/dxc/HLSL/DxilGenerationPass.h +++ b/include/dxc/HLSL/DxilGenerationPass.h @@ -81,6 +81,7 @@ ModulePass *createResumePassesPass(); FunctionPass *createMatrixBitcastLowerPass(); ModulePass *createDxilCleanupAddrSpaceCastPass(); ModulePass *createDxilRenameResourcesPass(); +ModulePass *createDxilScalarizeVectorLoadStoresPass(); void initializeDxilLowerCreateHandleForLibPass(llvm::PassRegistry &); void initializeDxilAllocateResourcesForLibPass(llvm::PassRegistry &); @@ -115,6 +116,7 @@ void initializeResumePassesPass(llvm::PassRegistry &); void initializeMatrixBitcastLowerPassPass(llvm::PassRegistry &); void initializeDxilCleanupAddrSpaceCastPass(llvm::PassRegistry &); void initializeDxilRenameResourcesPass(llvm::PassRegistry &); +void initializeDxilScalarizeVectorLoadStoresPass(llvm::PassRegistry &); ModulePass *createDxilValidateWaveSensitivityPass(); void initializeDxilValidateWaveSensitivityPass(llvm::PassRegistry &); diff --git a/include/dxc/Test/WEXAdapter.h b/include/dxc/Test/WEXAdapter.h index f180c01a99..e8263eb576 100644 --- a/include/dxc/Test/WEXAdapter.h +++ b/include/dxc/Test/WEXAdapter.h @@ -178,8 +178,8 @@ inline void EndGroup(const wchar_t *name) { wprintf(L"END TEST(S): <%ls>\n", name); } inline void Comment(const wchar_t *msg) { - fputws(msg, 
stdout); - fputwc(L'\n', stdout); + fputws(msg, stderr); + fputwc(L'\n', stderr); } inline void Error(const wchar_t *msg) { fputws(msg, stderr); diff --git a/lib/DXIL/DxilOperations.cpp b/lib/DXIL/DxilOperations.cpp index a2b0432ce3..c422d86593 100644 --- a/lib/DXIL/DxilOperations.cpp +++ b/lib/DXIL/DxilOperations.cpp @@ -2633,6 +2633,22 @@ const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = { 0, {}, {}}, // Overloads: v + {OC::RawBufferVectorLoad, + "RawBufferVectorLoad", + OCC::RawBufferVectorLoad, + "rawBufferVectorLoad", + Attribute::ReadOnly, + 1, + {{0x4e7}}, + {{0xe7}}}, // Overloads: hfwidlgetNumParams() <= 4) return nullptr; return FT->getParamType(4); @@ -6135,7 +6170,8 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { case OpCode::TextureGatherRaw: case OpCode::SampleCmpLevel: case OpCode::SampleCmpGrad: - case OpCode::SampleCmpBias: { + case OpCode::SampleCmpBias: + case OpCode::RawBufferVectorLoad: { StructType *ST = cast(Ty); return ST->getElementType(0); } diff --git a/lib/DxilValidation/DxilValidation.cpp b/lib/DxilValidation/DxilValidation.cpp index cac074adc3..fd05495e86 100644 --- a/lib/DxilValidation/DxilValidation.cpp +++ b/lib/DxilValidation/DxilValidation.cpp @@ -1487,7 +1487,7 @@ static void ValidateResourceDxilOp(CallInst *CI, DXIL::OpCode opcode, DXIL::ComponentType compTy; DXIL::ResourceClass resClass; DXIL::ResourceKind resKind = - GetResourceKindAndCompTy(bufLd.get_srv(), compTy, resClass, ValCtx); + GetResourceKindAndCompTy(bufLd.get_buf(), compTy, resClass, ValCtx); if (resClass != DXIL::ResourceClass::SRV && resClass != DXIL::ResourceClass::UAV) { @@ -1496,12 +1496,9 @@ static void ValidateResourceDxilOp(CallInst *CI, DXIL::OpCode opcode, Value *offset = bufLd.get_elementOffset(); Value *align = bufLd.get_alignment(); - unsigned alignSize = 0; if (!isa(align)) { ValCtx.EmitInstrError(CI, ValidationRule::InstrCoordinateCountForRawTypedBuf); - } else { - alignSize = 
bufLd.get_alignment_val(); } switch (resKind) { case DXIL::ResourceKind::RawBuffer: @@ -1551,12 +1548,9 @@ static void ValidateResourceDxilOp(CallInst *CI, DXIL::OpCode opcode, Value *offset = bufSt.get_elementOffset(); Value *align = bufSt.get_alignment(); - unsigned alignSize = 0; if (!isa(align)) { ValCtx.EmitInstrError(CI, ValidationRule::InstrCoordinateCountForRawTypedBuf); - } else { - alignSize = bufSt.get_alignment_val(); } switch (resKind) { case DXIL::ResourceKind::RawBuffer: @@ -1683,7 +1677,9 @@ static void ValidateDxilOperationCallInProfile(CallInst *CI, case DXIL::OpCode::CBufferLoad: case DXIL::OpCode::CBufferLoadLegacy: case DXIL::OpCode::RawBufferLoad: + case DXIL::OpCode::RawBufferVectorLoad: case DXIL::OpCode::RawBufferStore: + case DXIL::OpCode::RawBufferVectorStore: ValidateResourceDxilOp(CI, opcode, ValCtx); break; // Input output. @@ -2714,8 +2710,7 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) { } // Instructions must be allowed. - if (!IsLLVMInstructionAllowed(I) || - !IsLLVMInstructionAllowedForShaderModel(I, ValCtx)) { + if (!IsLLVMInstructionAllowed(I) || !IsLLVMInstructionAllowedForShaderModel(I, ValCtx)) { if (!IsLLVMInstructionAllowedForLib(I, ValCtx)) { ValCtx.EmitInstrError(&I, ValidationRule::InstrAllowed); continue; diff --git a/lib/HLSL/CMakeLists.txt b/lib/HLSL/CMakeLists.txt index 947fc4c14f..21bb9523a7 100644 --- a/lib/HLSL/CMakeLists.txt +++ b/lib/HLSL/CMakeLists.txt @@ -25,6 +25,7 @@ add_llvm_library(LLVMHLSL DxilNoops.cpp DxilPreserveAllOutputs.cpp DxilRenameResourcesPass.cpp + DxilScalarizeVectorLoadStores.cpp DxilSimpleGVNHoist.cpp DxilSignatureValidation.cpp DxilTargetLowering.cpp diff --git a/lib/HLSL/DxilLinker.cpp b/lib/HLSL/DxilLinker.cpp index ca343662ab..007e21ff19 100644 --- a/lib/HLSL/DxilLinker.cpp +++ b/lib/HLSL/DxilLinker.cpp @@ -1247,6 +1247,10 @@ void DxilLinkJob::RunPreparePass(Module &M) { PM.add(createDxilReinsertNopsPass()); PM.add(createAlwaysInlinerPass(/*InsertLifeTime*/ 
false)); + // Need to lower vector load/stores to scalars here? + // If we need SROA and dynamicindexvector to array, it has to be here. + PM.add(createDxilScalarizeVectorLoadStoresPass()); + // Remove unused functions. PM.add(createDxilDeadFunctionEliminationPass()); @@ -1272,6 +1276,10 @@ void DxilLinkJob::RunPreparePass(Module &M) { // Clean up vectors, and run mem2reg again PM.add(createScalarizerPass()); + + // Need dxilelimvector for pre 6.9 + //PM.add(createDxilEliminateVectorPass()); + PM.add(createPromoteMemoryToRegisterPass()); PM.add(createSimplifyInstPass()); diff --git a/lib/HLSL/DxilScalarizeVectorLoadStores.cpp b/lib/HLSL/DxilScalarizeVectorLoadStores.cpp new file mode 100644 index 0000000000..5b5c43875e --- /dev/null +++ b/lib/HLSL/DxilScalarizeVectorLoadStores.cpp @@ -0,0 +1,241 @@ +/////////////////////////////////////////////////////////////////////////////// +// // +// DxilScalarizeVectorLoadStores.cpp // +// Copyright (C) Microsoft Corporation. All rights reserved. // +// This file is distributed under the University of Illinois Open Source // +// License. See LICENSE.TXT for details. // +// // +// Lowers native vector load stores to potentially multiple scalar calls. 
// +// // +/////////////////////////////////////////////////////////////////////////////// + +#include "dxc/DXIL/DxilInstructions.h" +#include "dxc/DXIL/DxilModule.h" +#include "dxc/HLSL/DxilGenerationPass.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" + +using namespace llvm; +using namespace hlsl; + +class DxilScalarizeVectorLoadStores : public ModulePass { +private: + DxilModule *m_DM; + + void scalarizeVectorLoad(hlsl::OP *HlslOP, const DataLayout &DL, CallInst *CI); + void scalarizeVectorStore(hlsl::OP *HlslOP, const DataLayout &DL, CallInst *CI); + +public: + static char ID; // Pass identification, replacement for typeid + explicit DxilScalarizeVectorLoadStores() : ModulePass(ID) {} + + StringRef getPassName() const override { + return "DXIL scalarize vector load/stores"; + } + + bool runOnModule(Module &M) override { + DxilModule &DM = M.GetOrCreateDxilModule(); + m_DM = &DM; + + // Shader Model 6.9 allows native vectors and doesn't need this pass. 
+ if (DM.GetShaderModel()->IsSM69Plus()) + return false; + + bool Changed = false; + + hlsl::OP *HlslOP = DM.GetOP(); + auto &LoadList = HlslOP->GetOpFuncList(DXIL::OpCode::RawBufferVectorLoad); + for (auto FIt = LoadList.begin(), FEnd = LoadList.end(); FIt != FEnd; FIt++) { + Function *F = FIt->second; + if (!F) + continue; + for (auto U = F->user_begin(), E = F->user_end(); U != E;) { + CallInst *CI = cast(*(U++)); + scalarizeVectorLoad(HlslOP, M.getDataLayout(), CI); + Changed = true; + } + F->eraseFromParent(); + } + + auto &StoreList = HlslOP->GetOpFuncList(DXIL::OpCode::RawBufferVectorStore); + for (auto FIt = StoreList.begin(), FEnd = StoreList.end(); FIt != FEnd; FIt++) { + Function *F = FIt->second; + if (!F) + continue; + for (auto U = F->user_begin(), E = F->user_end(); U != E;) { + CallInst *CI = cast(*(U++)); + scalarizeVectorStore(HlslOP, M.getDataLayout(), CI); + Changed = true; + } + F->eraseFromParent(); + } + return Changed; + } +}; + +static unsigned GetRawBufferMask(unsigned NumComponents) { + switch (NumComponents) { + case 0: + return 0; + case 1: + return DXIL::kCompMask_X; + case 2: + return DXIL::kCompMask_X | DXIL::kCompMask_Y; + case 3: + return DXIL::kCompMask_X | DXIL::kCompMask_Y | DXIL::kCompMask_Z; + case 4: + default: + return DXIL::kCompMask_All; + } + return DXIL::kCompMask_All; +} + +void DxilScalarizeVectorLoadStores::scalarizeVectorLoad(hlsl::OP *HlslOP, + const DataLayout &DL, + CallInst *CI) { + IRBuilder<> Builder(CI); + // Collect the information required to break this into scalar ops from args. + DxilInst_RawBufferVectorLoad VecLd(CI); + OP::OpCode OpCode = OP::OpCode::RawBufferLoad; + llvm::Constant *opArg = Builder.getInt32((unsigned)OpCode); + SmallVector Args; + Args.emplace_back(opArg); // opcode @0. + Args.emplace_back(VecLd.get_buf()); // Resource handle @1. + Args.emplace_back(VecLd.get_index()); // Index @2. + Args.emplace_back(VecLd.get_elementOffset()); // Offset @3. 
+ Args.emplace_back(nullptr); // Mask to be set later @4. + Args.emplace_back(VecLd.get_alignment()); // Alignment @5. + + // Set offset to increment depending on whether the real offset is defined. + unsigned OffsetIdx = 0; + if (isa(VecLd.get_elementOffset())) + // Byte Address Buffers can't use offset, so use index. + OffsetIdx = DXIL::OperandIndex::kRawBufferLoadIndexOpIdx; + else + OffsetIdx = DXIL::OperandIndex::kRawBufferLoadElementOffsetOpIdx; + + StructType *ResRetTy = cast(CI->getType()); + Type *Ty = ResRetTy->getElementType(0); + unsigned NumComponents = Ty->getVectorNumElements(); + Type *EltTy = Ty->getScalarType(); + unsigned EltSize = DL.getTypeAllocSize(EltTy); + + const unsigned MaxElemCount = 4; + SmallVector Elts(NumComponents); + Value *Ld = nullptr; + for (unsigned EIx = 0; EIx < NumComponents;) { + // Load 4 elements or however many less than 4 are left to load. + unsigned ChunkSize = std::min(NumComponents - EIx, MaxElemCount); + Args[DXIL::OperandIndex::kRawBufferLoadMaskOpIdx] = + HlslOP->GetI8Const(GetRawBufferMask(ChunkSize)); + // If we've loaded a chunk already, update offset to next chunk. + if (EIx > 0) + Args[OffsetIdx] = + Builder.CreateAdd(Args[OffsetIdx], HlslOP->GetU32Const(4 * EltSize)); + Function *F = HlslOP->GetOpFunc(OpCode, EltTy); + Ld = Builder.CreateCall(F, Args, OP::GetOpCodeName(OpCode)); + for (unsigned ChIx = 0; ChIx < ChunkSize; ChIx++, EIx++) + Elts[EIx] = Builder.CreateExtractValue(Ld, ChIx); + } + + Value *RetValNew = UndefValue::get(VectorType::get(EltTy, NumComponents)); + for (unsigned ElIx = 0; ElIx < NumComponents; ElIx++) + RetValNew = Builder.CreateInsertElement(RetValNew, Elts[ElIx], ElIx); + + // Replace users of the vector extracted from the vector load resret + // with our constructed one and we'll see if they can tell the difference. 
+ Value *Status = nullptr; + for (auto CU = CI->user_begin(), CE = CI->user_end(); CU != CE;) { + auto EV = cast(*(CU++)); + unsigned Ix = EV->getIndices()[0]; + if (Ix == 0) { + // Handle value uses. + EV->replaceAllUsesWith(RetValNew); + } else if (Ix == 1) { + // Handle status uses. + if (!Status) + Status = Builder.CreateExtractValue(Ld, DXIL::kResRetStatusIndex); + EV->replaceAllUsesWith(Status); + } + EV->eraseFromParent(); + } + CI->eraseFromParent(); +} + +void DxilScalarizeVectorLoadStores::scalarizeVectorStore(hlsl::OP *HlslOP, + const DataLayout &DL, + CallInst *CI) { + IRBuilder<> Builder(CI); + // Collect the information required to break this into scalar ops from args. + DxilInst_RawBufferVectorStore VecSt(CI); + OP::OpCode OpCode = OP::OpCode::RawBufferStore; + llvm::Constant *opArg = Builder.getInt32((unsigned)OpCode); + SmallVector Args; + Args.emplace_back(opArg); // opcode @0. + Args.emplace_back(VecSt.get_uav()); // Resource handle @1. + Args.emplace_back(VecSt.get_index()); // Index @2. + Args.emplace_back(VecSt.get_elementOffset()); // Offset @3. + Args.emplace_back(nullptr); // Val0 to be set later @4. + Args.emplace_back(nullptr); // Val1 to be set later @5. + Args.emplace_back(nullptr); // Val2 to be set later @6. + Args.emplace_back(nullptr); // Val3 to be set later @7. + Args.emplace_back(nullptr); // Mask to be set later @8. + Args.emplace_back(VecSt.get_alignment()); // Alignment @9. + + // Set offset to increment depending on whether the real offset is defined. + unsigned OffsetIdx = 0; + if (isa(VecSt.get_elementOffset())) + // Byte Address Buffers can't use offset, so use index. 
+ OffsetIdx = DXIL::OperandIndex::kRawBufferLoadIndexOpIdx; + else + OffsetIdx = DXIL::OperandIndex::kRawBufferLoadElementOffsetOpIdx; + + Value *VecVal = VecSt.get_value0(); + + const unsigned MaxElemCount = 4; + Type *Ty = VecVal->getType(); + const unsigned NumComponents = Ty->getVectorNumElements(); + Type *EltTy = Ty->getScalarType(); + Value *UndefVal = UndefValue::get(EltTy); + unsigned EltSize = DL.getTypeAllocSize(EltTy); + Function *F = HlslOP->GetOpFunc(OpCode, EltTy); + for (unsigned EIx = 0; EIx < NumComponents;) { + // Store 4 elements or however many less than 4 are left to store. + unsigned ChunkSize = std::min(NumComponents - EIx, MaxElemCount); + // For second and subsequent store calls, increment the resource-appropriate + // index or offset parameter. + if (EIx > 0) + Args[OffsetIdx] = + Builder.CreateAdd(Args[OffsetIdx], HlslOP->GetU32Const(4 * EltSize)); + // Populate all value arguments either with the vector or undefs. + uint8_t Mask = 0; + unsigned ChIx = 0; + for (; ChIx < ChunkSize; ChIx++, EIx++) { + Args[DXIL::OperandIndex::kRawBufferStoreVal0OpIdx + ChIx] = Builder.CreateExtractElement(VecVal, EIx); + Mask |= (1 << ChIx); + } + for (; ChIx < MaxElemCount; ChIx++) + Args[DXIL::OperandIndex::kRawBufferStoreVal0OpIdx + ChIx] = UndefVal; + + Args[DXIL::OperandIndex::kRawBufferStoreMaskOpIdx] = HlslOP->GetU8Const(Mask); + Builder.CreateCall(F, Args); + } + CI->eraseFromParent(); +} + +char DxilScalarizeVectorLoadStores::ID = 0; + +ModulePass *llvm::createDxilScalarizeVectorLoadStoresPass() { + return new DxilScalarizeVectorLoadStores(); +} + +INITIALIZE_PASS(DxilScalarizeVectorLoadStores, + "hlsl-dxil-scalarize-vector-load-stores", + "DXIL scalarize vector load/stores", false, false) + diff --git a/lib/HLSL/HLOperationLower.cpp b/lib/HLSL/HLOperationLower.cpp index 3ab1f9fdec..078a63ba48 100644 --- a/lib/HLSL/HLOperationLower.cpp +++ b/lib/HLSL/HLOperationLower.cpp @@ -3953,6 +3953,11 @@ struct ResLoadHelper { : 
intrinsicOpCode(IntrinsicOp::Num_Intrinsics), handle(h), retVal(Inst), addr(idx), offset(Offset), status(nullptr), mipLevel(mip) { opcode = LoadOpFromResKind(RK); + Type *Ty = Inst->getType(); + if (opcode == OP::OpCode::RawBufferLoad && Ty->isVectorTy() && + Ty->getVectorNumElements() > 1 && + Inst->getModule()->GetHLModule().GetShaderModel()->IsSM69Plus()) + opcode = OP::OpCode::RawBufferVectorLoad; } OP::OpCode opcode; IntrinsicOp intrinsicOpCode; @@ -4022,6 +4027,14 @@ ResLoadHelper::ResLoadHelper(CallInst *CI, DxilResource::Kind RK, if (RC == DxilResourceBase::Class::SRV) OffsetIdx = IsMS ? HLOperandIndex::kTex2DMSLoadOffsetOpIdx : HLOperandIndex::kTexLoadOffsetOpIdx; + } else if (opcode == OP::OpCode::RawBufferLoad) { + // If native vectors are available and this load had a vector + // with more than one element, convert the RawBufferLoad to the + // native vector variant RawBufferVectorLoad. + Type *Ty = CI->getType(); + if (Ty->isVectorTy() && Ty->getVectorNumElements() > 1 && + CI->getModule()->GetHLModule().GetShaderModel()->IsSM69Plus()) + opcode = OP::OpCode::RawBufferVectorLoad; } // Set offset. @@ -4079,7 +4092,7 @@ Value *GenerateRawBufLd(Value *handle, Value *bufIdx, Value *offset, // Sets up arguments for buffer load call. static SmallVector GetBufLoadArgs(ResLoadHelper helper, HLResource::Kind RK, - IRBuilder<> Builder, Type *EltTy, + IRBuilder<> Builder, unsigned LdSize) { OP::OpCode opcode = helper.opcode; llvm::Constant *opArg = Builder.getInt32((uint32_t)opcode); @@ -4127,6 +4140,7 @@ static SmallVector GetBufLoadArgs(ResLoadHelper helper, // If not TextureLoad, it could be a typed or raw buffer load. // They have mostly similar arguments. 
DXASSERT(opcode == OP::OpCode::RawBufferLoad || + opcode == OP::OpCode::RawBufferVectorLoad || opcode == OP::OpCode::BufferLoad, "Wrong opcode in get load args"); Args.emplace_back( @@ -4137,6 +4151,9 @@ static SmallVector GetBufLoadArgs(ResLoadHelper helper, // Unlike typed buffer load, raw buffer load has mask and alignment. Args.emplace_back(nullptr); // Mask will be added later %4. Args.emplace_back(alignmentVal); // alignment @5. + } else if (opcode == OP::OpCode::RawBufferVectorLoad) { + // RawBufferVectorLoad takes just alignment, no mask. + Args.emplace_back(alignmentVal); // alignment @4 } } return Args; @@ -4162,18 +4179,19 @@ Value *TranslateBufLoad(ResLoadHelper &helper, HLResource::Kind RK, if (isBool || (is64 && isTyped)) EltTy = Builder.getInt32Ty(); - // 64-bit types are stored as int32 pairs in typed buffers. + // Adjust number of components as needed. if (is64 && isTyped) { + // 64-bit types are stored as int32 pairs in typed buffers. DXASSERT(NumComponents <= 2, "Typed buffers only allow 4 dwords."); NumComponents *= 2; + } else if (opcode == OP::OpCode::RawBufferVectorLoad) { + // Native vector loads only have a single vector element in ResRet. + EltTy = VectorType::get(EltTy, NumComponents); + NumComponents = 1; } unsigned LdSize = DL.getTypeAllocSize(EltTy); - - SmallVector Elts(NumComponents); - - SmallVector Args = - GetBufLoadArgs(helper, RK, Builder, EltTy, LdSize); + SmallVector Args = GetBufLoadArgs(helper, RK, Builder, LdSize); // Keep track of the first load for debug info migration. Value *FirstLd = nullptr; @@ -4185,9 +4203,10 @@ Value *TranslateBufLoad(ResLoadHelper &helper, HLResource::Kind RK, else if (RK == DxilResource::Kind::StructuredBuffer) OffsetIdx = DXIL::OperandIndex::kRawBufferLoadElementOffsetOpIdx; - // Create calls to function object. + // Create call(s) to function object and collect results in Elts. // Typed buffer loads are limited to one load of up to 4 32-bit values. 
// Raw buffer loads might need multiple loads in chunks of 4. + SmallVector Elts(NumComponents); for (unsigned i = 0; i < NumComponents;) { // Load 4 elements or however many less than 4 are left to load. unsigned chunkSize = std::min(NumComponents - i, 4U); @@ -4197,7 +4216,7 @@ Value *TranslateBufLoad(ResLoadHelper &helper, HLResource::Kind RK, Args[DXIL::OperandIndex::kRawBufferLoadMaskOpIdx] = GetRawBufferMaskForETy(EltTy, chunkSize, OP); // If we've loaded a chunk already, update offset to next chunk. - if (FirstLd != nullptr && opcode == OP::OpCode::RawBufferLoad) + if (FirstLd != nullptr) Args[OffsetIdx] = Builder.CreateAdd(Args[OffsetIdx], OP->GetU32Const(4 * LdSize)); } @@ -4206,8 +4225,13 @@ Value *TranslateBufLoad(ResLoadHelper &helper, HLResource::Kind RK, Value *Ld = Builder.CreateCall(F, Args, OP::GetOpCodeName(opcode)); // Extract elements from returned ResRet. - for (unsigned j = 0; j < chunkSize; j++, i++) - Elts[i] = Builder.CreateExtractValue(Ld, j); + // Native vector loads just have one vector element in the ResRet. + // Others have up to four scalars that need to be individually extracted. + if (opcode == OP::OpCode::RawBufferVectorLoad) + Elts[i++] = Builder.CreateExtractValue(Ld, 0); + else + for (unsigned j = 0; j < chunkSize; j++, i++) + Elts[i] = Builder.CreateExtractValue(Ld, j); // Update status. UpdateStatus(Ld, helper.status, Builder, OP); @@ -4245,9 +4269,10 @@ Value *TranslateBufLoad(ResLoadHelper &helper, HLResource::Kind RK, } } - // Package elements into a vector. + // Package elements into a vector as needed. Value *retValNew = nullptr; - if (!Ty->isVectorTy()) { + // Scalar or native vector loads need not construct vectors from elements. 
+ if (!Ty->isVectorTy() || opcode == OP::OpCode::RawBufferVectorLoad) { retValNew = Elts[0]; } else { retValNew = UndefValue::get(VectorType::get(EltTy, NumComponents)); @@ -4345,6 +4370,10 @@ void TranslateStore(DxilResource::Kind RK, Value *handle, Value *val, case DxilResource::Kind::StructuredBuffer: IsTyped = false; opcode = OP::OpCode::RawBufferStore; + // Where shader model and type allows, use vector store intrinsic. + if (OP->GetModule()->GetHLModule().GetShaderModel()->IsSM69Plus() && + Ty->isVectorTy() && Ty->getVectorNumElements() > 1) + opcode = OP::OpCode::RawBufferVectorStore; break; case DxilResource::Kind::TypedBuffer: opcode = OP::OpCode::BufferStore; @@ -4387,7 +4416,6 @@ void TranslateStore(DxilResource::Kind RK, Value *handle, Value *val, EltTy = i32Ty; } - Function *F = OP->GetOpFunc(opcode, EltTy); llvm::Constant *opArg = OP->GetU32Const((unsigned)opcode); llvm::Value *undefI = @@ -4401,6 +4429,7 @@ void TranslateStore(DxilResource::Kind RK, Value *handle, Value *val, unsigned OffsetIdx = 0; if (opcode == OP::OpCode::RawBufferStore || + opcode == OP::OpCode::RawBufferVectorStore || opcode == OP::OpCode::BufferStore) { // Append Coord0 (Index) value. if (Idx->getType()->isVectorTy()) { @@ -4420,7 +4449,6 @@ void TranslateStore(DxilResource::Kind RK, Value *handle, Value *val, OffsetIdx = storeArgs.size() - 1; // Coord1 (Offset). - // Only relevant when storing more than 4 elements to structured buffers. storeArgs.emplace_back(offset); } else { // texture store @@ -4441,6 +4469,16 @@ void TranslateStore(DxilResource::Kind RK, Value *handle, Value *val, // TODO: support mip for texture ST } + // RawBufferVectorStore only takes a single value and alignment arguments. 
+ if (opcode == DXIL::OpCode::RawBufferVectorStore) { + storeArgs.emplace_back(val); + storeArgs.emplace_back(Alignment); + Function *F = OP->GetOpFunc(DXIL::OpCode::RawBufferVectorStore, Ty); + Builder.CreateCall(F, storeArgs); + return; + } + Function *F = OP->GetOpFunc(opcode, EltTy); + constexpr unsigned MaxStoreElemCount = 4; const unsigned CompCount = Ty->isVectorTy() ? Ty->getVectorNumElements() : 1; const unsigned StoreInstCount = diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index d20daa0ac0..18d0bfec01 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -344,6 +344,7 @@ enum ArBasicKind { #define BPROP_FEEDBACKTEXTURE \ 0x00800000 // Whether the type is a feedback texture. #define BPROP_ENUM 0x01000000 // Whether the type is a enum +#define BPROP_RAWBUFFER 0x02000000 // Whether the type is a raw buffer. #define GET_BPROP_PRIM_KIND(_Props) \ ((_Props) & (BPROP_BOOLEAN | BPROP_INTEGER | BPROP_FLOATING)) @@ -384,6 +385,7 @@ enum ArBasicKind { (IS_BPROP_AINT(_Props) && GET_BPROP_BITS(_Props) != BPROP_BITS12) #define IS_BPROP_ENUM(_Props) (((_Props)&BPROP_ENUM) != 0) +#define IS_BPROP_RAWBUFFER(_Props) (((_Props)&BPROP_RAWBUFFER) != 0) const UINT g_uBasicKindProps[] = { BPROP_PRIMITIVE | BPROP_BOOLEAN | BPROP_INTEGER | BPROP_NUMERIC | @@ -512,22 +514,22 @@ const UINT g_uBasicKindProps[] = { BPROP_OBJECT | BPROP_RWBUFFER | BPROP_TEXTURE, // AR_OBJECT_RWTEXTURE3D BPROP_OBJECT | BPROP_RWBUFFER, // AR_OBJECT_RWBUFFER - BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_BYTEADDRESS_BUFFER - BPROP_OBJECT | BPROP_RWBUFFER, // AR_OBJECT_RWBYTEADDRESS_BUFFER - BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_STRUCTURED_BUFFER - BPROP_OBJECT | BPROP_RWBUFFER, // AR_OBJECT_RWSTRUCTURED_BUFFER - BPROP_OBJECT | BPROP_RWBUFFER, // AR_OBJECT_RWSTRUCTURED_BUFFER_ALLOC - BPROP_OBJECT | BPROP_RWBUFFER, // AR_OBJECT_RWSTRUCTURED_BUFFER_CONSUME - BPROP_OBJECT | BPROP_RWBUFFER, // AR_OBJECT_APPEND_STRUCTURED_BUFFER - 
BPROP_OBJECT | BPROP_RWBUFFER, // AR_OBJECT_CONSUME_STRUCTURED_BUFFER + BPROP_OBJECT | BPROP_RBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_BYTEADDRESS_BUFFER + BPROP_OBJECT | BPROP_RWBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_RWBYTEADDRESS_BUFFER + BPROP_OBJECT | BPROP_RBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_STRUCTURED_BUFFER + BPROP_OBJECT | BPROP_RWBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_RWSTRUCTURED_BUFFER + BPROP_OBJECT | BPROP_RWBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_RWSTRUCTURED_BUFFER_ALLOC + BPROP_OBJECT | BPROP_RWBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_RWSTRUCTURED_BUFFER_CONSUME + BPROP_OBJECT | BPROP_RWBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_APPEND_STRUCTURED_BUFFER + BPROP_OBJECT | BPROP_RWBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_CONSUME_STRUCTURED_BUFFER BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_CONSTANT_BUFFER BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_TEXTURE_BUFFER BPROP_OBJECT | BPROP_RWBUFFER | BPROP_ROVBUFFER, // AR_OBJECT_ROVBUFFER - BPROP_OBJECT | BPROP_RWBUFFER | + BPROP_OBJECT | BPROP_RWBUFFER | BPROP_RAWBUFFER | BPROP_ROVBUFFER, // AR_OBJECT_ROVBYTEADDRESS_BUFFER - BPROP_OBJECT | BPROP_RWBUFFER | + BPROP_OBJECT | BPROP_RWBUFFER | BPROP_RAWBUFFER | BPROP_ROVBUFFER, // AR_OBJECT_ROVSTRUCTURED_BUFFER BPROP_OBJECT | BPROP_RWBUFFER | BPROP_ROVBUFFER, // AR_OBJECT_ROVTEXTURE1D BPROP_OBJECT | BPROP_RWBUFFER | @@ -641,6 +643,8 @@ C_ASSERT(ARRAYSIZE(g_uBasicKindProps) == AR_BASIC_MAXIMUM_COUNT); #define IS_BASIC_ENUM(_Kind) IS_BPROP_ENUM(GetBasicKindProps(_Kind)) +#define IS_BASIC_RAWBUFFER(_Kind) IS_BPROP_RAWBUFFER(GetBasicKindProps(_Kind)) + #define BITWISE_ENUM_OPS(_Type) \ inline _Type operator|(_Type F1, _Type F2) { \ return (_Type)((UINT)F1 | (UINT)F2); \ @@ -15071,7 +15075,8 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC, Expr *BitWidth, } // Disallow long vecs from $Global cbuffers. 
- if (isGlobal && !isStatic && !isGroupShared) { + if (isGlobal && !isStatic && !isGroupShared && + !IS_BASIC_RAWBUFFER(basicKind)) { // Suppress actual emitting of errors for incompletable types here // They are redundant to those produced in ActOnUninitializedDecl. struct SilentDiagnoser : public TypeDiagnoser { diff --git a/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/buffer-load-stores-sm69.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/buffer-load-stores-sm69.hlsl new file mode 100644 index 0000000000..b1e3b92f79 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/buffer-load-stores-sm69.hlsl @@ -0,0 +1,90 @@ +// RUN: %dxc -DTYPE=float -DNUM=4 -T vs_6_9 %s | FileCheck %s +// RUN: %dxc -DTYPE=bool -DNUM=4 -T vs_6_9 %s | FileCheck %s --check-prefixes=CHECK,I1 +// RUN: %dxc -DTYPE=uint64_t -DNUM=2 -T vs_6_9 %s | FileCheck %s +// RUN: %dxc -DTYPE=double -DNUM=2 -T vs_6_9 %s | FileCheck %s + +// RUN: %dxc -DTYPE=float -DNUM=6 -T vs_6_9 %s | FileCheck %s +// RUN: %dxc -DTYPE=bool -DNUM=13 -T vs_6_9 %s | FileCheck %s --check-prefixes=CHECK,I1 +// RUN: %dxc -DTYPE=uint64_t -DNUM=24 -T vs_6_9 %s | FileCheck %s +// RUN: %dxc -DTYPE=double -DNUM=32 -T vs_6_9 %s | FileCheck %s + +/////////////////////////////////////////////////////////////////////// +// Test codegen for various load and store operations and conversions +// for different scalar/vector buffer types and indices. 
+/////////////////////////////////////////////////////////////////////// + +// CHECK: %dx.types.ResRet.[[VTY:v[0-9]*[a-z][0-9][0-9]]] = type { <[[NUM:[0-9]*]] x [[TYPE:[a-z_0-9]*]]>, i32 } + + ByteAddressBuffer RoByBuf : register(t1); +RWByteAddressBuffer RwByBuf : register(u1); + +StructuredBuffer< vector > RoStBuf : register(t2); +RWStructuredBuffer< vector > RwStBuf : register(u2); + +ConsumeStructuredBuffer > CnStBuf : register(u4); +AppendStructuredBuffer > ApStBuf : register(u5); + +// CHECK-LABEL: define void @main +void main(uint ix[2] : IX) { + // ByteAddressBuffer Tests + + // CHECK-DAG: [[HDLROBY:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 1, i32 1, i32 0, i8 0 }, i32 1, i1 false) + // CHECK-DAG: [[HDLRWBY:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 1, i32 1, i32 0, i8 1 }, i32 1, i1 false) + + // CHECK-DAG: [[HDLROST:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 2, i32 2, i32 0, i8 0 }, i32 2, i1 false) + // CHECK-DAG: [[HDLRWST:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 2, i32 2, i32 0, i8 1 }, i32 2, i1 false) + + // CHECK-DAG: [[HDLCON:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 4, i32 4, i32 0, i8 1 }, i32 4, i1 false) + // CHECK-DAG: [[HDLAPP:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 5, i32 5, i32 0, i8 1 }, i32 5, i1 false) + + // CHECK: [[IX0:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, + + // CHECK: [[ANHDLRWBY:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[HDLRWBY]] + // CHECK: call %dx.types.ResRet.[[VTY]] @dx.op.rawBufferVectorLoad.[[VTY]](i32 303, %dx.types.Handle [[ANHDLRWBY]], i32 [[IX0]] + // I1: icmp ne <[[NUM]] x i32> %{{.*}}, zeroinitializer + vector babElt1 = RwByBuf.Load< vector >(ix[0]); + + // CHECK: 
[[ANHDLROBY:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[HDLROBY]] + // CHECK: call %dx.types.ResRet.[[VTY]] @dx.op.rawBufferVectorLoad.[[VTY]](i32 303, %dx.types.Handle [[ANHDLROBY]], i32 [[IX0]] + // I1: icmp ne <[[NUM]] x i32> %{{.*}}, zeroinitializer + vector babElt2 = RoByBuf.Load< vector >(ix[0]); + + // I1: zext <[[NUM]] x i1> %{{.*}} to <[[NUM]] x i32> + // CHECK: call void @dx.op.rawBufferVectorStore.[[VTY]](i32 304, %dx.types.Handle [[ANHDLRWBY]], i32 [[IX0]] + RwByBuf.Store< vector >(ix[0], babElt1 + babElt2); + + // StructuredBuffer Tests + // CHECK: [[ANHDLRWST:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[HDLRWST]] + // CHECK: call %dx.types.ResRet.[[VTY]] @dx.op.rawBufferVectorLoad.[[VTY]](i32 303, %dx.types.Handle [[ANHDLRWST]], i32 [[IX0]] + // I1: icmp ne <[[NUM]] x i32> %{{.*}}, zeroinitializer + vector stbElt1 = RwStBuf.Load(ix[0]); + // CHECK: [[IX1:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, + // CHECK: call %dx.types.ResRet.[[VTY]] @dx.op.rawBufferVectorLoad.[[VTY]](i32 303, %dx.types.Handle [[ANHDLRWST]], i32 [[IX1]] + // I1: icmp ne <[[NUM]] x i32> %{{.*}}, zeroinitializer + vector stbElt2 = RwStBuf[ix[1]]; + + // CHECK: [[ANHDLROST:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[HDLROST]] + // CHECK: call %dx.types.ResRet.[[VTY]] @dx.op.rawBufferVectorLoad.[[VTY]](i32 303, %dx.types.Handle [[ANHDLROST]], i32 [[IX0]] + // I1: icmp ne <[[NUM]] x i32> %{{.*}}, zeroinitializer + vector stbElt3 = RoStBuf.Load(ix[0]); + // CHECK: call %dx.types.ResRet.[[VTY]] @dx.op.rawBufferVectorLoad.[[VTY]](i32 303, %dx.types.Handle [[ANHDLROST]], i32 [[IX1]] + // I1: icmp ne <[[NUM]] x i32> %{{.*}}, zeroinitializer + vector stbElt4 = RoStBuf[ix[1]]; + + // I1: zext <[[NUM]] x i1> %{{.*}} to <[[NUM]] x i32> + // CHECK: call void @dx.op.rawBufferVectorStore.[[VTY]](i32 304, %dx.types.Handle [[ANHDLRWST]], i32 [[IX0]] + RwStBuf[ix[0]] = stbElt1 + stbElt2 +
stbElt3 + stbElt4; + + // {Append/Consume}StructuredBuffer Tests + // CHECK: [[ANHDLCON:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[HDLCON]] + // CHECK: [[CONIX:%.*]] = call i32 @dx.op.bufferUpdateCounter(i32 70, %dx.types.Handle [[ANHDLCON]], i8 -1) + // CHECK: call %dx.types.ResRet.[[VTY]] @dx.op.rawBufferVectorLoad.[[VTY]](i32 303, %dx.types.Handle [[ANHDLCON]], i32 [[CONIX]] + // I1: icmp ne <[[NUM]] x i32> %{{.*}}, zeroinitializer + vector cnElt = CnStBuf.Consume(); + + // CHECK: [[ANHDLAPP:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[HDLAPP]] + // CHECK: [[APPIX:%.*]] = call i32 @dx.op.bufferUpdateCounter(i32 70, %dx.types.Handle [[ANHDLAPP]], i8 1) + // I1: zext <[[NUM]] x i1> %{{.*}} to <[[NUM]] x i32> + // CHECK: call void @dx.op.rawBufferVectorStore.[[VTY]](i32 304, %dx.types.Handle [[ANHDLAPP]], i32 [[APPIX]] + ApStBuf.Append(cnElt); +} diff --git a/tools/clang/test/CodeGenDXIL/hlsl/types/longvec-operators-cs.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/types/longvec-operators-cs.hlsl new file mode 100644 index 0000000000..e6a5def3b6 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/types/longvec-operators-cs.hlsl @@ -0,0 +1,708 @@ +// RUN: %dxc -HV 2018 -T cs_6_9 -DTYPE=float -DNUM=2 %s | FileCheck %s --check-prefixes=CHECK,NODBL,NOINT +// RUN: %dxc -HV 2018 -T cs_6_9 -DTYPE=float -DNUM=17 %s | FileCheck %s --check-prefixes=CHECK,NODBL,NOINT +// RUN: %dxc -HV 2018 -T cs_6_9 -DTYPE=int -DNUM=2 -DINT %s | FileCheck %s --check-prefixes=CHECK,NODBL,INT,SIG +// RUN: %dxc -HV 2018 -T cs_6_9 -DTYPE=uint -DNUM=5 -DINT %s | FileCheck %s --check-prefixes=CHECK,NODBL,INT,UNSIG +// RUN: %dxc -HV 2018 -T cs_6_9 -DTYPE=double -DNUM=3 -DDBL %s | FileCheck %s --check-prefixes=CHECK,DBL,NOINT +// RUN: %dxc -HV 2018 -T cs_6_9 -DTYPE=uint64_t -DNUM=9 -DINT %s | FileCheck %s --check-prefixes=CHECK,NODBL,INT,UNSIG +// RUN: %dxc -HV 2018 -T cs_6_9 -DTYPE=float16_t -DNUM=17 -enable-16bit-types %s |
FileCheck %s --check-prefixes=CHECK,NODBL,NOINT +// RUN: %dxc -HV 2018 -T cs_6_9 -DTYPE=int16_t -DNUM=33 -DINT -enable-16bit-types %s | FileCheck %s --check-prefixes=CHECK,NODBL,INT,SIG + +// Linking tests. +// RUN: %dxc -HV 2018 -T lib_6_9 -DTYPE=float -DNUM=6 -Fo %t.1 %s +// RUN: %dxl -T cs_6_9 %t.1 | FileCheck %s --check-prefixes=CHECK,NODBL,NOINT +// RUN: %dxc -HV 2018 -T lib_6_9 -DTYPE=double -DNUM=3 -DDBL -Fo %t.2 %s +// RUN: %dxl -T cs_6_9 %t.2 | FileCheck %s --check-prefixes=CHECK,DBL,NOINT +// RUN: %dxc -HV 2018 -T lib_6_9 -DTYPE=uint16_t -DNUM=12 -DINT -enable-16bit-types -Fo %t.3 %s +// RUN: %dxl -T cs_6_9 %t.3 | FileCheck %s --check-prefixes=CHECK,NODBL,INT,UNSIG + +// Test relevant operators on an assortment of vector sizes and types with 6.9 native vectors. +// Tests in a CS environment where vector operations were previously disallowed to confirm that they are retained. + +// Just a trick to capture the needed type spellings since the DXC version of FileCheck can't do that explicitly. +// Uses non vector buffer to avoid interacting with that implementation. +// CHECK-DAG: %dx.types.ResRet.[[TY:v[0-9]*[a-z][0-9]*]] = type { <[[NUM:[0-9]*]] x [[TYPE:[a-z_0-9]*]]> +// CHECK-DAG: %dx.types.ResRet.[[STY:[a-z][0-9]*]] = type { [[STYPE:[a-z0-9_]*]] +// CHECK-DAG: %dx.types.ResRet.[[ITY:v[0-9]*i32]] = type { <[[NUM]] x i32> + +export void assignments(inout vector things[11], TYPE scales[10]); +export vector arithmetic(inout vector things[11])[11]; +export vector scarithmetic(vector things[11], TYPE scales[10])[11]; +export vector logic(vector truth[10], vector consequences[11])[10]; +export vector index(vector things[11], int i, TYPE val)[11]; +export void bittwiddlers(inout vector things[13]); + +struct Viface { + vector values[11]; +}; + +struct Siface { + TYPE values[10]; +}; + +struct Liface { + vector values[10]; +}; + +struct Biface { + vector values[13]; +}; + +// Requires vector loading support. Enable when available.
+RWStructuredBuffer Input : register(u11); +RWStructuredBuffer Output : register(u12); +RWStructuredBuffer Scales : register(u13); +RWStructuredBuffer Truths : register(u14); +RWStructuredBuffer Bits : register(u15); +RWStructuredBuffer > Offsets : register(u16); + +TYPE g_val; + +[shader("compute")] +[numthreads(8,1,1)] +// CHECK-LABEL: define void @main +void main(uint3 GID : SV_GroupThreadID) { + + // CHECK-DAG: [[Input:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 11, i32 11, i32 0, i8 1 }, i32 11 + // CHECK-DAG: [[Output:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 12, i32 12, i32 0, i8 1 }, i32 12 + // CHECK-DAG: [[Scales:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 13, i32 13, i32 0, i8 1 }, i32 13 + // CHECK-DAG: [[Truths:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 14, i32 14, i32 0, i8 1 }, i32 14 + // INT-DAG: [[Bits:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 15, i32 15, i32 0, i8 1 }, i32 15 + + // CHECK: [[InIx1:%.*]] = call i32 @dx.op.threadIdInGroup.i32(i32 95, i32 0) + // CHECK: [[InIx2:%.*]] = call i32 @dx.op.threadIdInGroup.i32(i32 95, i32 1) + // CHECK: [[OutIx:%.*]] = call i32 @dx.op.threadIdInGroup.i32(i32 95, i32 2) + // CHECK: [[scratch1:%.*]] = alloca [11 x <[[NUM]] x [[TYPE]]>] + // CHECK: [[scratch2:%.*]] = alloca [11 x <[[NUM]] x [[TYPE]]>] + + uint InIx1 = GID[0]; + uint InIx2 = GID[1]; + uint OutIx = GID[2]; + + // Assign vector offsets to capture the expected values. 
+ // CHECK: call void @dx.op.rawBufferVectorStore.v13i32(i32 304, %dx.types.Handle {{%.*}}, i32 0, i32 0, <13 x i32> + Offsets[0] = vector(sizeof(vector)*0, + sizeof(vector)*1, + sizeof(vector)*2, + sizeof(vector)*3, + sizeof(vector)*4, + sizeof(vector)*5, + sizeof(vector)*6, + sizeof(vector)*7, + sizeof(vector)*8, + sizeof(vector)*9, + sizeof(vector)*10, + sizeof(vector)*11, + sizeof(vector)*12); + + // Assign scalar offsets to capture the expected values. + // CHECK: call void @dx.op.rawBufferVectorStore.v13i32(i32 304, %dx.types.Handle {{%.*}}, i32 1, i32 0, <13 x i32> + Offsets[1] = vector(sizeof(TYPE)*0, + sizeof(TYPE)*1, + sizeof(TYPE)*2, + sizeof(TYPE)*3, + sizeof(TYPE)*4, + sizeof(TYPE)*5, + sizeof(TYPE)*6, + sizeof(TYPE)*7, + sizeof(TYPE)*8, + sizeof(TYPE)*9, + sizeof(TYPE)*10, + sizeof(TYPE)*11, + sizeof(TYPE));// Effectively alignof. + + // Assign boolean offsets to capture the expected values. + // CHECK: call void @dx.op.rawBufferVectorStore.v13i32(i32 304, %dx.types.Handle {{%.*}}, i32 2, i32 0, <13 x i32> + Offsets[2] = vector(sizeof(vector)*0, + sizeof(vector)*1, + sizeof(vector)*2, + sizeof(vector)*3, + sizeof(vector)*4, + sizeof(vector)*5, + sizeof(vector)*6, + sizeof(vector)*7, + sizeof(vector)*8, + sizeof(vector)*9, + sizeof(vector)*10, + sizeof(vector)*11, + sizeof(vector)*12); + + assignments(Input[InIx1+1].values, Scales[InIx2+1].values); + Output[OutIx+2].values = arithmetic(Input[InIx1+2].values); + Output[OutIx+3].values = scarithmetic(Input[InIx1+3].values, Scales[InIx2+3].values); + Truths[OutIx+4].values = logic(Truths[InIx2+4].values, Input[InIx1+4].values); + Output[OutIx+5].values = index(Input[InIx1+5].values, InIx2+5, g_val); +#ifdef INT + bittwiddlers(Bits[InIx1+6].values); +#endif +} + +// A mixed-type overload to test overload resolution and mingle different vector element types in ops +// Test assignment operators. 
+void assignments(inout vector things[11], TYPE scales[10]) { + + // CHECK: [[VcIx:%.*]] = add i32 [[InIx1]], 1 + // CHECK: [[InHdl:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[Input]] + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF1]], i32 8) + // CHECK: [[vec1:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF2]], i32 8) + // CHECK: [[vec2:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF3]], i32 8) + // CHECK: [[vec3:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF4]], i32 8) + // CHECK: [[vec4:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF5]], i32 8) + // CHECK: [[vec5:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF6]], i32 8) + // CHECK: [[vec6:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF7]], i32 8) + // CHECK: [[vec7:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], 
i32 [[OFF8]], i32 8) + // CHECK: [[vec8:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF9]], i32 8) + // CHECK: [[vec9:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + + // CHECK: [[ScIx:%.*]] = add i32 [[InIx2]], 1 + // CHECK: [[ScHdl:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[Scales]] + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[STY]] @dx.op.rawBufferLoad.[[STY]](i32 139, %dx.types.Handle [[ScHdl]], i32 [[ScIx]], i32 [[OFF0]], i8 1, i32 [[ALN]]) + // CHECK: [[scl0:%.*]] = extractvalue %dx.types.ResRet.[[STY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[STY]] @dx.op.rawBufferLoad.[[STY]](i32 139, %dx.types.Handle [[ScHdl]], i32 [[ScIx]], i32 [[SOFF1]], i8 1, i32 [[ALN]]) + // CHECK: [[scl1:%.*]] = extractvalue %dx.types.ResRet.[[STY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[STY]] @dx.op.rawBufferLoad.[[STY]](i32 139, %dx.types.Handle [[ScHdl]], i32 [[ScIx]], i32 [[SOFF2]], i8 1, i32 [[ALN]]) + // CHECK: [[scl2:%.*]] = extractvalue %dx.types.ResRet.[[STY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[STY]] @dx.op.rawBufferLoad.[[STY]](i32 139, %dx.types.Handle [[ScHdl]], i32 [[ScIx]], i32 [[SOFF3]], i8 1, i32 [[ALN]]) + // CHECK: [[scl3:%.*]] = extractvalue %dx.types.ResRet.[[STY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[STY]] @dx.op.rawBufferLoad.[[STY]](i32 139, %dx.types.Handle [[ScHdl]], i32 [[ScIx]], i32 [[SOFF4]], i8 1, i32 [[ALN]]) + // CHECK: [[scl4:%.*]] = extractvalue %dx.types.ResRet.[[STY]] [[ld]], 0 + + + // CHECK: [[spt:%[0-9]*]] = insertelement <[[NUM]] x [[TYPE]]> undef, [[TYPE]] [[scl0]], i32 0 + // CHECK: [[res0:%[0-9]*]] = shufflevector <[[NUM]] x [[TYPE]]> [[spt]], <[[NUM]] x [[TYPE]]> undef, <[[NUM]] x i32> zeroinitializer + things[0] = scales[0]; + + // CHECK: [[res1:%[0-9]*]] 
= [[ADD:f?add( fast)?]] <[[NUM]] x [[TYPE]]> [[vec5]], [[vec1]] + things[1] += things[5]; + + // CHECK: [[res2:%[0-9]*]] = [[SUB:f?sub( fast)?]] <[[NUM]] x [[TYPE]]> [[vec2]], [[vec6]] + things[2] -= things[6]; + + // CHECK: [[res3:%[0-9]*]] = [[MUL:f?mul( fast)?]] <[[NUM]] x [[TYPE]]> [[vec7]], [[vec3]] + things[3] *= things[7]; + + // CHECK: [[res4:%[0-9]*]] = [[DIV:[ufs]?div( fast)?]] <[[NUM]] x [[TYPE]]> [[vec4]], [[vec8]] + things[4] /= things[8]; + +#ifdef DBL + // DBL can't use remainder operator, do something anyway to keep the rest consistent. + // DBL: [[fvec9:%[0-9]*]] = fptrunc <[[NUM]] x double> [[vec9]] to <[[NUM]] x float> + // DBL: [[fvec5:%[0-9]*]] = fptrunc <[[NUM]] x double> [[vec5]] to <[[NUM]] x float> + // DBL: [[fres5:%[0-9]*]] = [[REM:[ufs]?rem( fast)?]] <[[NUM]] x float> [[fvec5]], [[fvec9]] + // DBL: [[res5:%[0-9]*]] = fpext <[[NUM]] x float> [[fres5]] to <[[NUM]] x double> + vector f9 = (vector)things[9]; + vector f5 = (vector)things[5]; + f5 %= f9; + things[5] = f5; +#else + // NODBL: [[res5:%[0-9]*]] = [[REM:[ufs]?rem( fast)?]] <[[NUM]] x [[TYPE]]> [[vec5]], [[vec9]] + things[5] %= things[9]; +#endif + + // CHECK: [[spt:%[0-9]*]] = insertelement <[[NUM]] x [[TYPE]]> undef, [[TYPE]] [[scl1]], i32 0 + // CHECK: [[spt1:%[0-9]*]] = shufflevector <[[NUM]] x [[TYPE]]> [[spt]], <[[NUM]] x [[TYPE]]> undef, <[[NUM]] x i32> zeroinitializer + // CHECK: [[res6:%[0-9]*]] = [[ADD]] <[[NUM]] x [[TYPE]]> [[spt1]], [[vec6]] + things[6] += scales[1]; + + // CHECK: [[spt:%[0-9]*]] = insertelement <[[NUM]] x [[TYPE]]> undef, [[TYPE]] [[scl2]], i32 0 + // CHECK: [[spt2:%[0-9]*]] = shufflevector <[[NUM]] x [[TYPE]]> [[spt]], <[[NUM]] x [[TYPE]]> undef, <[[NUM]] x i32> zeroinitializer + // CHECK: [[res7:%[0-9]*]] = [[SUB]] <[[NUM]] x [[TYPE]]> [[vec7]], [[spt2]] + things[7] -= scales[2]; + + // CHECK: [[spt:%[0-9]*]] = insertelement <[[NUM]] x [[TYPE]]> undef, [[TYPE]] [[scl3]], i32 0 + // CHECK: [[spt3:%[0-9]*]] = shufflevector <[[NUM]] x [[TYPE]]> [[spt]], 
<[[NUM]] x [[TYPE]]> undef, <[[NUM]] x i32> zeroinitializer + // CHECK: [[res8:%[0-9]*]] = [[MUL]] <[[NUM]] x [[TYPE]]> [[spt3]], [[vec8]] + things[8] *= scales[3]; + + // CHECK: [[spt:%[0-9]*]] = insertelement <[[NUM]] x [[TYPE]]> undef, [[TYPE]] [[scl4]], i32 0 + // CHECK: [[spt4:%[0-9]*]] = shufflevector <[[NUM]] x [[TYPE]]> [[spt]], <[[NUM]] x [[TYPE]]> undef, <[[NUM]] x i32> zeroinitializer + // CHECK: [[res9:%[0-9]*]] = [[DIV]] <[[NUM]] x [[TYPE]]> [[vec9]], [[spt4]] + things[9] /= scales[4]; + + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF0]], <[[NUM]] x [[TYPE]]> [[res0]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF1]], <[[NUM]] x [[TYPE]]> [[res1]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF2]], <[[NUM]] x [[TYPE]]> [[res2]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF3]], <[[NUM]] x [[TYPE]]> [[res3]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF4]], <[[NUM]] x [[TYPE]]> [[res4]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF5]], <[[NUM]] x [[TYPE]]> [[res5]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF6]], <[[NUM]] x [[TYPE]]> [[res6]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF7]], <[[NUM]] x [[TYPE]]> [[res7]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF8]], <[[NUM]] x [[TYPE]]> [[res8]], i32 [[ALN]]) 
+ // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF9]], <[[NUM]] x [[TYPE]]> [[res9]], i32 [[ALN]]) + +} + +// Test arithmetic operators. +vector arithmetic(inout vector things[11])[11] { + vector res[11]; + + // CHECK: [[ResIx:%.*]] = add i32 [[OutIx]], 2 + // CHECK: [[ResHdl:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[Output]] + // CHECK: [[VecIx:%.*]] = add i32 [[InIx1]], 2 + // CHECK: [[InHdl:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[Input]] + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF0]], i32 8) + // CHECK: [[vec0:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF1]], i32 8) + // CHECK: [[vec1:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF2]], i32 8) + // CHECK: [[vec2:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF3]], i32 8) + // CHECK: [[vec3:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF4]], i32 8) + // CHECK: [[vec4:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF5]], i32 8) + // CHECK: [[vec5:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 
+ // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF6]], i32 8) + // CHECK: [[vec6:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF7]], i32 8) + // CHECK: [[vec7:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF8]], i32 8) + // CHECK: [[vec8:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF9]], i32 8) + // CHECK: [[vec9:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF10]], i32 8) + // CHECK: [[vec10:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + + // NOINT: [[res0:%[0-9]*]] = [[SUB]] <[[NUM]] x [[TYPE]]> <[[TYPE]] {{-?(0|0\.0*e\+0*|0xH8000),.*}}>, [[vec0]] + // INT: [[res0:%[0-9]*]] = [[SUB]] <[[NUM]] x [[TYPE]]> zeroinitializer, [[vec0]] + res[0] = -things[0]; + res[1] = +things[0]; + + // CHECK: [[res2:%[0-9]*]] = [[ADD]] <[[NUM]] x [[TYPE]]> [[vec2]], [[vec1]] + res[2] = things[1] + things[2]; + + // CHECK: [[res3:%[0-9]*]] = [[SUB]] <[[NUM]] x [[TYPE]]> [[vec2]], [[vec3]] + res[3] = things[2] - things[3]; + + // CHECK: [[res4:%[0-9]*]] = [[MUL]] <[[NUM]] x [[TYPE]]> [[vec4]], [[vec3]] + res[4] = things[3] * things[4]; + + // CHECK: [[res5:%[0-9]*]] = [[DIV]] <[[NUM]] x [[TYPE]]> [[vec4]], [[vec5]] + res[5] = things[4] / things[5]; + + // DBL: [[fvec5:%[0-9]*]] = fptrunc <[[NUM]] x double> [[vec5]] to <[[NUM]] x float> +#ifdef DBL + // DBL can't 
use remainder operator, do something anyway to keep the rest consistent. + // DBL: [[fvec6:%[0-9]*]] = fptrunc <[[NUM]] x double> [[vec6]] to <[[NUM]] x float> + // DBL: [[fres6:%[0-9]*]] = [[REM]] <[[NUM]] x float> [[fvec5]], [[fvec6]] + // DBL: [[res6:%[0-9]*]] = fpext <[[NUM]] x float> [[fres6]] to <[[NUM]] x double> + res[6] = (vector)things[5] % (vector)things[6]; +#else + // NODBL: [[res6:%[0-9]*]] = [[REM]] <[[NUM]] x [[TYPE]]> [[vec5]], [[vec6]] + res[6] = things[5] % things[6]; +#endif + + // CHECK: [[res7:%[0-9]*]] = [[ADD]] <[[NUM]] x [[TYPE]]> [[vec7]], <[[TYPE]] [[POS1:(1|1\.0*e\+0*|0xH3C00)]] + res[7] = things[7]++; + + // CHECK: [[res8:%[0-9]*]] = [[ADD]] <[[NUM]] x [[TYPE]]> [[vec8]], <[[TYPE]] [[NEG1:(-1|-1\.0*e\+0*|0xHBC00)]] + res[8] = things[8]--; + + // CHECK: [[res9:%[0-9]*]] = [[ADD]] <[[NUM]] x [[TYPE]]> [[vec9]], <[[TYPE]] [[POS1]] + res[9] = ++things[9]; + + // CHECK: [[res10:%[0-9]*]] = [[ADD]] <[[NUM]] x [[TYPE]]> [[vec10]], <[[TYPE]] [[NEG1]] + res[10] = --things[10]; + + // Things[] input gets all the result values since pre/post inc/decrements don't change the end result. 
+ // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF7]], <[[NUM]] x [[TYPE]]> [[res7]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF8]], <[[NUM]] x [[TYPE]]> [[res8]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF9]], <[[NUM]] x [[TYPE]]> [[res9]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF10]], <[[NUM]] x [[TYPE]]> [[res10]], i32 [[ALN]]) + + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF0]], <[[NUM]] x [[TYPE]]> [[res0]], i32 [[ALN]]) + // res1 is just vec0 since it was just the unary + operator. + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF1]], <[[NUM]] x [[TYPE]]> [[vec0]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF2]], <[[NUM]] x [[TYPE]]> [[res2]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF3]], <[[NUM]] x [[TYPE]]> [[res3]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF4]], <[[NUM]] x [[TYPE]]> [[res4]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF5]], <[[NUM]] x [[TYPE]]> [[res5]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF6]], <[[NUM]] x [[TYPE]]> [[res6]], i32 [[ALN]]) + // res[] input gets either the original or the preincremented value. 
+ // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF7]], <[[NUM]] x [[TYPE]]> [[vec7]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF8]], <[[NUM]] x [[TYPE]]> [[vec8]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF9]], <[[NUM]] x [[TYPE]]> [[res9]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF10]], <[[NUM]] x [[TYPE]]> [[res10]], i32 [[ALN]]) + + return res; +} + +// Test arithmetic operators with scalars. +vector scarithmetic(vector things[11], TYPE scales[10])[11] { + vector res[11]; + + // CHECK: [[ResIx:%.*]] = add i32 [[OutIx]], 3 + // CHECK: [[ResHdl:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[Output]] + // CHECK: [[VecIx:%.*]] = add i32 [[InIx1]], 3 + // CHECK: [[InHdl:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[Input]] + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF0]], i32 8) + // CHECK: [[vec0:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF1]], i32 8) + // CHECK: [[vec1:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF2]], i32 8) + // CHECK: [[vec2:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF3]], 
i32 8) + // CHECK: [[vec3:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF4]], i32 8) + // CHECK: [[vec4:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF5]], i32 8) + // CHECK: [[vec5:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF6]], i32 8) + // CHECK: [[vec6:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + + // CHECK: [[SclIx:%.*]] = add i32 [[InIx2]], 3 + // CHECK: [[SclHdl:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[Scales]] + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[STY]] @dx.op.rawBufferLoad.[[STY]](i32 139, %dx.types.Handle [[SclHdl]], i32 [[SclIx]], i32 [[OFF0]], i8 1, i32 [[ALN]]) + // CHECK: [[scl0:%.*]] = extractvalue %dx.types.ResRet.[[STY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[STY]] @dx.op.rawBufferLoad.[[STY]](i32 139, %dx.types.Handle [[SclHdl]], i32 [[SclIx]], i32 [[SOFF1]], i8 1, i32 [[ALN]]) + // CHECK: [[scl1:%.*]] = extractvalue %dx.types.ResRet.[[STY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[STY]] @dx.op.rawBufferLoad.[[STY]](i32 139, %dx.types.Handle [[SclHdl]], i32 [[SclIx]], i32 [[SOFF2]], i8 1, i32 [[ALN]]) + // CHECK: [[scl2:%.*]] = extractvalue %dx.types.ResRet.[[STY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[STY]] @dx.op.rawBufferLoad.[[STY]](i32 139, %dx.types.Handle [[SclHdl]], i32 [[SclIx]], i32 [[SOFF3]], i8 1, i32 [[ALN]]) + // CHECK: [[scl3:%.*]] = extractvalue %dx.types.ResRet.[[STY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[STY]] 
@dx.op.rawBufferLoad.[[STY]](i32 139, %dx.types.Handle [[SclHdl]], i32 [[SclIx]], i32 [[SOFF4]], i8 1, i32 [[ALN]]) + // CHECK: [[scl4:%.*]] = extractvalue %dx.types.ResRet.[[STY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[STY]] @dx.op.rawBufferLoad.[[STY]](i32 139, %dx.types.Handle [[SclHdl]], i32 [[SclIx]], i32 [[SOFF5]], i8 1, i32 [[ALN]]) + // CHECK: [[scl5:%.*]] = extractvalue %dx.types.ResRet.[[STY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[STY]] @dx.op.rawBufferLoad.[[STY]](i32 139, %dx.types.Handle [[SclHdl]], i32 [[SclIx]], i32 [[SOFF6]], i8 1, i32 [[ALN]]) + // CHECK: [[scl6:%.*]] = extractvalue %dx.types.ResRet.[[STY]] [[ld]], 0 + + // CHECK: [[spt:%[0-9]*]] = insertelement <[[NUM]] x [[TYPE]]> undef, [[TYPE]] [[scl0]], i32 0 + // CHECK: [[spt0:%[0-9]*]] = shufflevector <[[NUM]] x [[TYPE]]> [[spt]], <[[NUM]] x [[TYPE]]> undef, <[[NUM]] x i32> zeroinitializer + // CHECK: [[res0:%[0-9]*]] = [[ADD]] <[[NUM]] x [[TYPE]]> [[spt0]], [[vec0]] + res[0] = things[0] + scales[0]; + + // CHECK: [[spt:%[0-9]*]] = insertelement <[[NUM]] x [[TYPE]]> undef, [[TYPE]] [[scl1]], i32 0 + // CHECK: [[spt1:%[0-9]*]] = shufflevector <[[NUM]] x [[TYPE]]> [[spt]], <[[NUM]] x [[TYPE]]> undef, <[[NUM]] x i32> zeroinitializer + // CHECK: [[res1:%[0-9]*]] = [[SUB]] <[[NUM]] x [[TYPE]]> [[vec1]], [[spt1]] + res[1] = things[1] - scales[1]; + + + // CHECK: [[spt:%[0-9]*]] = insertelement <[[NUM]] x [[TYPE]]> undef, [[TYPE]] [[scl2]], i32 0 + // CHECK: [[spt2:%[0-9]*]] = shufflevector <[[NUM]] x [[TYPE]]> [[spt]], <[[NUM]] x [[TYPE]]> undef, <[[NUM]] x i32> zeroinitializer + // CHECK: [[res2:%[0-9]*]] = [[MUL]] <[[NUM]] x [[TYPE]]> [[spt2]], [[vec2]] + res[2] = things[2] * scales[2]; + + // CHECK: [[spt:%[0-9]*]] = insertelement <[[NUM]] x [[TYPE]]> undef, [[TYPE]] [[scl3]], i32 0 + // CHECK: [[spt3:%[0-9]*]] = shufflevector <[[NUM]] x [[TYPE]]> [[spt]], <[[NUM]] x [[TYPE]]> undef, <[[NUM]] x i32> zeroinitializer + // CHECK: [[res3:%[0-9]*]] = [[DIV]] 
<[[NUM]] x [[TYPE]]> [[vec3]], [[spt3]] + res[3] = things[3] / scales[3]; + + // CHECK: [[spt:%[0-9]*]] = insertelement <[[NUM]] x [[TYPE]]> undef, [[TYPE]] [[scl4]], i32 0 + // CHECK: [[spt4:%[0-9]*]] = shufflevector <[[NUM]] x [[TYPE]]> [[spt]], <[[NUM]] x [[TYPE]]> undef, <[[NUM]] x i32> zeroinitializer + // CHECK: [[res4:%[0-9]*]] = [[ADD]] <[[NUM]] x [[TYPE]]> [[spt4]], [[vec4]] + res[4] = scales[4] + things[4]; + + // CHECK: [[spt:%[0-9]*]] = insertelement <[[NUM]] x [[TYPE]]> undef, [[TYPE]] [[scl5]], i32 0 + // CHECK: [[spt5:%[0-9]*]] = shufflevector <[[NUM]] x [[TYPE]]> [[spt]], <[[NUM]] x [[TYPE]]> undef, <[[NUM]] x i32> zeroinitializer + // CHECK: [[res5:%[0-9]*]] = [[SUB]] <[[NUM]] x [[TYPE]]> [[spt5]], [[vec5]] + res[5] = scales[5] - things[5]; + + // CHECK: [[spt:%[0-9]*]] = insertelement <[[NUM]] x [[TYPE]]> undef, [[TYPE]] [[scl6]], i32 0 + // CHECK: [[spt6:%[0-9]*]] = shufflevector <[[NUM]] x [[TYPE]]> [[spt]], <[[NUM]] x [[TYPE]]> undef, <[[NUM]] x i32> zeroinitializer + // CHECK: [[res6:%[0-9]*]] = [[MUL]] <[[NUM]] x [[TYPE]]> [[spt6]], [[vec6]] + res[6] = scales[6] * things[6]; + res[7] = res[8] = res[9] = res[10] = 0; + + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF0]], <[[NUM]] x [[TYPE]]> [[res0]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF1]], <[[NUM]] x [[TYPE]]> [[res1]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF2]], <[[NUM]] x [[TYPE]]> [[res2]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF3]], <[[NUM]] x [[TYPE]]> [[res3]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF4]], <[[NUM]] x [[TYPE]]> [[res4]], i32 [[ALN]]) 
+ // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF5]], <[[NUM]] x [[TYPE]]> [[res5]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF6]], <[[NUM]] x [[TYPE]]> [[res6]], i32 [[ALN]]) + + return res; +} + +// Test logic operators. +// Only permissible in pre-HLSL2021 +vector logic(vector truth[10], vector consequences[11])[10] { + vector res[10]; + // CHECK: [[ResIx:%.*]] = add i32 [[OutIx]], 4 + // CHECK: [[TruHdl:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[Truths]] + // CHECK: [[TruIx:%.*]] = add i32 [[InIx2]], 4 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[ITY]] @dx.op.rawBufferVectorLoad.[[ITY]](i32 303, %dx.types.Handle [[TruHdl]], i32 [[TruIx]], i32 [[BOFF0]], i32 8) + // CHECK: [[ivec0:%.*]] = extractvalue %dx.types.ResRet.[[ITY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[ITY]] @dx.op.rawBufferVectorLoad.[[ITY]](i32 303, %dx.types.Handle [[TruHdl]], i32 [[TruIx]], i32 [[BOFF1]], i32 8) + // CHECK: [[ivec1:%.*]] = extractvalue %dx.types.ResRet.[[ITY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[ITY]] @dx.op.rawBufferVectorLoad.[[ITY]](i32 303, %dx.types.Handle [[TruHdl]], i32 [[TruIx]], i32 [[BOFF2]], i32 8) + // CHECK: [[ivec2:%.*]] = extractvalue %dx.types.ResRet.[[ITY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[ITY]] @dx.op.rawBufferVectorLoad.[[ITY]](i32 303, %dx.types.Handle [[TruHdl]], i32 [[TruIx]], i32 [[BOFF3]], i32 8) + // CHECK: [[ivec3:%.*]] = extractvalue %dx.types.ResRet.[[ITY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[ITY]] @dx.op.rawBufferVectorLoad.[[ITY]](i32 303, %dx.types.Handle [[TruHdl]], i32 [[TruIx]], i32 [[BOFF4]], i32 8) + // CHECK: [[ivec4:%.*]] = extractvalue %dx.types.ResRet.[[ITY]] [[ld]], 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[ITY]] 
@dx.op.rawBufferVectorLoad.[[ITY]](i32 303, %dx.types.Handle [[TruHdl]], i32 [[TruIx]], i32 [[BOFF5]], i32 8) + // CHECK: [[ivec5:%.*]] = extractvalue %dx.types.ResRet.[[ITY]] [[ld]], 0 + + // CHECK: [[VecIx:%.*]] = add i32 [[InIx1]], 4 + // CHECK: [[InHdl:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[Input]] + //CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF0]], i32 8) + //CHECK: [[vec0:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + //CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF1]], i32 8) + //CHECK: [[vec1:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + //CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF2]], i32 8) + //CHECK: [[vec2:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + //CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF3]], i32 8) + //CHECK: [[vec3:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + //CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF4]], i32 8) + //CHECK: [[vec4:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + //CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF5]], i32 8) + //CHECK: [[vec5:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + //CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF6]], i32 8) + //CHECK: [[vec6:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + + + // CHECK: 
[[cmp:%[0-9]*]] = icmp ne <[[NUM]] x i32> [[ivec0]], zeroinitializer + // CHECK: [[cmp0:%[0-9]*]] = icmp eq <[[NUM]] x i1> [[cmp]], zeroinitializer + // CHECK: [[res0:%[0-9]*]] = zext <[[NUM]] x i1> [[cmp0]] to <[[NUM]] x i32> + res[0] = !truth[0]; + + // CHECK: [[bvec1:%[0-9]*]] = icmp ne <[[NUM]] x i32> [[ivec1]], zeroinitializer + // CHECK: [[bvec2:%[0-9]*]] = icmp ne <[[NUM]] x i32> [[ivec2]], zeroinitializer + // CHECK: [[bres1:%[0-9]*]] = or <[[NUM]] x i1> [[bvec2]], [[bvec1]] + // CHECK: [[res1:%[0-9]*]] = zext <[[NUM]] x i1> [[bres1]] to <[[NUM]] x i32> + res[1] = truth[1] || truth[2]; + + // CHECK: [[bvec3:%[0-9]*]] = icmp ne <[[NUM]] x i32> [[ivec3]], zeroinitializer + // CHECK: [[bres2:%[0-9]*]] = and <[[NUM]] x i1> [[bvec3]], [[bvec2]] + // CHECK: [[res2:%[0-9]*]] = zext <[[NUM]] x i1> [[bres2]] to <[[NUM]] x i32> + res[2] = truth[2] && truth[3]; + + // CHECK: [[bvec4:%[0-9]*]] = icmp ne <[[NUM]] x i32> [[ivec4]], zeroinitializer + // CHECK: [[bvec5:%[0-9]*]] = icmp ne <[[NUM]] x i32> [[ivec5]], zeroinitializer + // CHECK: [[bres3:%[0-9]*]] = select <[[NUM]] x i1> [[bvec3]], <[[NUM]] x i1> [[bvec4]], <[[NUM]] x i1> [[bvec5]] + // CHECK: [[res3:%[0-9]*]] = zext <[[NUM]] x i1> [[bres3]] to <[[NUM]] x i32> + res[3] = truth[3] ? 
truth[4] : truth[5]; + + // CHECK: [[cmp4:%[0-9]*]] = [[CMP:[fi]?cmp( fast)?]] {{o?}}eq <[[NUM]] x [[TYPE]]> [[vec0]], [[vec1]] + // CHECK: [[res4:%[0-9]*]] = zext <[[NUM]] x i1> [[cmp4]] to <[[NUM]] x i32> + res[4] = consequences[0] == consequences[1]; + + // CHECK: [[cmp5:%[0-9]*]] = [[CMP]] {{u?}}ne <[[NUM]] x [[TYPE]]> [[vec1]], [[vec2]] + // CHECK: [[res5:%[0-9]*]] = zext <[[NUM]] x i1> [[cmp5]] to <[[NUM]] x i32> + res[5] = consequences[1] != consequences[2]; + + // CHECK: [[cmp6:%[0-9]*]] = [[CMP]] {{[osu]?}}lt <[[NUM]] x [[TYPE]]> [[vec2]], [[vec3]] + // CHECK: [[res6:%[0-9]*]] = zext <[[NUM]] x i1> [[cmp6]] to <[[NUM]] x i32> + res[6] = consequences[2] < consequences[3]; + + // CHECK: [[cmp7:%[0-9]*]] = [[CMP]] {{[osu]?}}gt <[[NUM]] x [[TYPE]]> [[vec3]], [[vec4]] + // CHECK: [[res7:%[0-9]*]] = zext <[[NUM]] x i1> [[cmp7]] to <[[NUM]] x i32> + res[7] = consequences[3] > consequences[4]; + + // CHECK: [[cmp8:%[0-9]*]] = [[CMP]] {{[osu]?}}le <[[NUM]] x [[TYPE]]> [[vec4]], [[vec5]] + // CHECK: [[res8:%[0-9]*]] = zext <[[NUM]] x i1> [[cmp8]] to <[[NUM]] x i32> + res[8] = consequences[4] <= consequences[5]; + + // CHECK: [[cmp9:%[0-9]*]] = [[CMP]] {{[osu]?}}ge <[[NUM]] x [[TYPE]]> [[vec5]], [[vec6]] + // CHECK: [[res9:%[0-9]*]] = zext <[[NUM]] x i1> [[cmp9]] to <[[NUM]] x i32> + res[9] = consequences[5] >= consequences[6]; + + // CHECK: call void @dx.op.rawBufferVectorStore.[[ITY]](i32 304, %dx.types.Handle [[TruHdl]], i32 [[ResIx]], i32 [[BOFF0]], <[[NUM]] x i32> [[res0]], i32 4) + // CHECK: call void @dx.op.rawBufferVectorStore.[[ITY]](i32 304, %dx.types.Handle [[TruHdl]], i32 [[ResIx]], i32 [[BOFF1]], <[[NUM]] x i32> [[res1]], i32 4) + // CHECK: call void @dx.op.rawBufferVectorStore.[[ITY]](i32 304, %dx.types.Handle [[TruHdl]], i32 [[ResIx]], i32 [[BOFF2]], <[[NUM]] x i32> [[res2]], i32 4) + // CHECK: call void @dx.op.rawBufferVectorStore.[[ITY]](i32 304, %dx.types.Handle [[TruHdl]], i32 [[ResIx]], i32 [[BOFF3]], <[[NUM]] x i32> [[res3]], i32 4) + // CHECK: 
call void @dx.op.rawBufferVectorStore.[[ITY]](i32 304, %dx.types.Handle [[TruHdl]], i32 [[ResIx]], i32 [[BOFF4]], <[[NUM]] x i32> [[res4]], i32 4) + // CHECK: call void @dx.op.rawBufferVectorStore.[[ITY]](i32 304, %dx.types.Handle [[TruHdl]], i32 [[ResIx]], i32 [[BOFF5]], <[[NUM]] x i32> [[res5]], i32 4) + // CHECK: call void @dx.op.rawBufferVectorStore.[[ITY]](i32 304, %dx.types.Handle [[TruHdl]], i32 [[ResIx]], i32 [[BOFF6]], <[[NUM]] x i32> [[res6]], i32 4) + // CHECK: call void @dx.op.rawBufferVectorStore.[[ITY]](i32 304, %dx.types.Handle [[TruHdl]], i32 [[ResIx]], i32 [[BOFF7]], <[[NUM]] x i32> [[res7]], i32 4) + // CHECK: call void @dx.op.rawBufferVectorStore.[[ITY]](i32 304, %dx.types.Handle [[TruHdl]], i32 [[ResIx]], i32 [[BOFF8]], <[[NUM]] x i32> [[res8]], i32 4) + // CHECK: call void @dx.op.rawBufferVectorStore.[[ITY]](i32 304, %dx.types.Handle [[TruHdl]], i32 [[ResIx]], i32 [[BOFF9]], <[[NUM]] x i32> [[res9]], i32 4) + + return res; +} + +static const int Ix = 2; + +// Test indexing operators +vector index(vector things[11], int i, TYPE val)[11] { + vector res[11]; + + // CHECK: [[ResIx:%.*]] = add i32 [[OutIx]], 5 + // CHECK: [[ResHdl:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[Output]] + // CHECK: [[VecIx:%.*]] = add i32 [[InIx1]], 5 + // CHECK: [[InHdl:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[Input]] + + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch2]], i32 0, i32 0 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF0]], i32 8) + // CHECK: [[vec0:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: store <[[NUM]] x [[TYPE]]> [[vec0]], <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch2]], 
i32 0, i32 1 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF1]], i32 8) + // CHECK: [[vec1:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: store <[[NUM]] x [[TYPE]]> [[vec1]], <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch2]], i32 0, i32 2 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF2]], i32 8) + // CHECK: [[vec2:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: store <[[NUM]] x [[TYPE]]> [[vec2]], <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch2]], i32 0, i32 3 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF3]], i32 8) + // CHECK: [[vec3:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: store <[[NUM]] x [[TYPE]]> [[vec3]], <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch2]], i32 0, i32 4 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF4]], i32 8) + // CHECK: [[vec4:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: store <[[NUM]] x [[TYPE]]> [[vec4]], <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch2]], i32 0, i32 5 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 
[[VecIx]], i32 [[OFF5]], i32 8) + // CHECK: [[vec5:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: store <[[NUM]] x [[TYPE]]> [[vec5]], <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch2]], i32 0, i32 6 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF6]], i32 8) + // CHECK: [[vec6:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: store <[[NUM]] x [[TYPE]]> [[vec6]], <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch2]], i32 0, i32 7 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF7]], i32 8) + // CHECK: [[vec7:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: store <[[NUM]] x [[TYPE]]> [[vec7]], <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch2]], i32 0, i32 8 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF8]], i32 8) + // CHECK: [[vec8:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: store <[[NUM]] x [[TYPE]]> [[vec8]], <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch2]], i32 0, i32 9 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF9]], i32 8) + // CHECK: [[vec9:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: store <[[NUM]] x [[TYPE]]> 
[[vec9]], <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch2]], i32 0, i32 10 + // CHECK: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VecIx]], i32 [[OFF10]], i32 8) + // CHECK: [[vec10:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // CHECK: store <[[NUM]] x [[TYPE]]> [[vec10]], <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + + // CHECK: [[Ix:%.*]] = add i32 [[InIx2]], 5 + + // CHECK: [[adr0:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch1]], i32 0, i32 0 + // CHECK: store <[[NUM]] x [[TYPE]]> zeroinitializer, <[[NUM]] x [[TYPE]]>* [[adr0]], align [[ALN]] + res[0] = 0; + + + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch1]], i32 0, i32 [[Ix]] + // CHECK: store <[[NUM]] x [[TYPE]]> <[[TYPE]] [[POS1]],{{[^>]*}}>, <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + res[i] = 1; + + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch1]], i32 0, i32 2 + // CHECK: store <[[NUM]] x [[TYPE]]> <[[TYPE]] [[TWO:(2|2\.?0*e?\+?0*|0xH4000)]],{{[^>]*}}>, <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + res[Ix] = 2; + + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch1]], i32 0, i32 3 + // CHECK: store <[[NUM]] x [[TYPE]]> [[vec0]], <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + res[3] = things[0]; + + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch2]], i32 0, i32 [[Ix]] + // CHECK: [[ldix:%.*]] = load <[[NUM]] x [[TYPE]]>, <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch1]], 
i32 0, i32 4 + // CHECK: store <[[NUM]] x [[TYPE]]> [[ldix]], <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + res[4] = things[i]; + + + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch1]], i32 0, i32 5 + // CHECK: store <[[NUM]] x [[TYPE]]> [[vec2]], <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + res[5] = things[Ix]; + + // CHECK: [[ld:%.*]] = load <[[NUM]] x [[TYPE]]>, <[[NUM]] x [[TYPE]]>* [[adr0]], align [[ALN]] + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 0, <[[NUM]] x [[TYPE]]> [[ld]], i32 [[ALN]]) + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch1]], i32 0, i32 1 + // CHECK: [[ld:%.*]] = load <[[NUM]] x [[TYPE]]>, <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF1]], <[[NUM]] x [[TYPE]]> [[ld]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF2]], <[[NUM]] x [[TYPE]]> <[[TYPE]] [[TWO]] + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF3]], <[[NUM]] x [[TYPE]]> [[vec0]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF4]], <[[NUM]] x [[TYPE]]> [[ldix]], i32 [[ALN]]) + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF5]], <[[NUM]] x [[TYPE]]> [[vec2]], i32 [[ALN]]) + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch1]], i32 0, i32 6 + // CHECK: [[ld:%.*]] = load <[[NUM]] x [[TYPE]]>, <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 
304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF6]], <[[NUM]] x [[TYPE]]> [[ld]], i32 [[ALN]]) + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch1]], i32 0, i32 7 + // CHECK: [[ld:%.*]] = load <[[NUM]] x [[TYPE]]>, <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF7]], <[[NUM]] x [[TYPE]]> [[ld]], i32 [[ALN]]) + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch1]], i32 0, i32 8 + // CHECK: [[ld:%.*]] = load <[[NUM]] x [[TYPE]]>, <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF8]], <[[NUM]] x [[TYPE]]> [[ld]], i32 [[ALN]]) + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch1]], i32 0, i32 9 + // CHECK: [[ld:%.*]] = load <[[NUM]] x [[TYPE]]>, <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF9]], <[[NUM]] x [[TYPE]]> [[ld]], i32 [[ALN]]) + // CHECK: [[adr:%.*]] = getelementptr inbounds [11 x <[[NUM]] x [[TYPE]]>], [11 x <[[NUM]] x [[TYPE]]>]* [[scratch1]], i32 0, i32 10 + // CHECK: [[ld:%.*]] = load <[[NUM]] x [[TYPE]]>, <[[NUM]] x [[TYPE]]>* [[adr]], align [[ALN]] + // CHECK: call void @dx.op.rawBufferVectorStore.[[TY]](i32 304, %dx.types.Handle [[ResHdl]], i32 [[ResIx]], i32 [[OFF10]], <[[NUM]] x [[TYPE]]> [[ld]], i32 [[ALN]]) + + return res; +} + +#ifdef INT +// Test bit twiddling operators. 
+void bittwiddlers(inout vector things[13]) { + // INT: [[VcIx:%.*]] = add i32 [[InIx1]], 6 + // INT: [[InHdl:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[Bits]] + // INT: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF1]], i32 8) + // INT: [[vec1:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // INT: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF2]], i32 8) + // INT: [[vec2:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // INT: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF3]], i32 8) + // INT: [[vec3:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // INT: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF4]], i32 8) + // INT: [[vec4:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // INT: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF5]], i32 8) + // INT: [[vec5:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // INT: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF6]], i32 8) + // INT: [[vec6:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // INT: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF7]], i32 8) + // INT: [[vec7:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // INT: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF8]], i32 8) + // INT: [[vec8:%.*]] = 
extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // INT: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF9]], i32 8) + // INT: [[vec9:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // INT: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF10]], i32 8) + // INT: [[vec10:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // INT: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF11]], i32 8) + // INT: [[vec11:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + // INT: [[ld:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[InHdl]], i32 [[VcIx]], i32 [[OFF12]], i32 8) + // INT: [[vec12:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[ld]], 0 + + // INT: [[res0:%[0-9]*]] = xor <[[NUM]] x [[TYPE]]> [[vec1]], <[[TYPE]] -1 + things[0] = ~things[1]; + + // INT: [[res1:%[0-9]*]] = or <[[NUM]] x [[TYPE]]> [[vec3]], [[vec2]] + things[1] = things[2] | things[3]; + + // INT: [[res2:%[0-9]*]] = and <[[NUM]] x [[TYPE]]> [[vec4]], [[vec3]] + things[2] = things[3] & things[4]; + + // INT: [[res3:%[0-9]*]] = xor <[[NUM]] x [[TYPE]]> [[vec4]], [[vec5]] + things[3] = things[4] ^ things[5]; + + // INT: [[shv6:%[0-9]*]] = and <[[NUM]] x [[TYPE]]> [[vec6]] + // INT: [[res4:%[0-9]*]] = shl <[[NUM]] x [[TYPE]]> [[vec5]], [[shv6]] + things[4] = things[5] << things[6]; + + // INT: [[shv7:%[0-9]*]] = and <[[NUM]] x [[TYPE]]> [[vec7]] + // UNSIG: [[res5:%[0-9]*]] = lshr <[[NUM]] x [[TYPE]]> [[vec6]], [[shv7]] + // SIG: [[res5:%[0-9]*]] = ashr <[[NUM]] x [[TYPE]]> [[vec6]], [[shv7]] + things[5] = things[6] >> things[7]; + + // INT: [[res6:%[0-9]*]] = or <[[NUM]] x [[TYPE]]> [[vec8]], [[vec6]] + things[6] |= things[8]; + + // INT: [[res7:%[0-9]*]] = and <[[NUM]] x [[TYPE]]> 
[[vec9]], [[vec7]] + things[7] &= things[9]; + + // INT: [[res8:%[0-9]*]] = xor <[[NUM]] x [[TYPE]]> [[vec8]], [[vec10]] + things[8] ^= things[10]; + + // INT: [[shv11:%[0-9]*]] = and <[[NUM]] x [[TYPE]]> [[vec11]] + // INT: [[res9:%[0-9]*]] = shl <[[NUM]] x [[TYPE]]> [[vec9]], [[shv11]] + things[9] <<= things[11]; + + // INT: [[shv12:%[0-9]*]] = and <[[NUM]] x [[TYPE]]> [[vec12]] + // UNSIG: [[res10:%[0-9]*]] = lshr <[[NUM]] x [[TYPE]]> [[vec10]], [[shv12]] + // SIG: [[res10:%[0-9]*]] = ashr <[[NUM]] x [[TYPE]]> [[vec10]], [[shv12]] + things[10] >>= things[12]; + + // INT: ret void +} +#endif // INT diff --git a/tools/clang/test/CodeGenDXIL/hlsl/types/longvec-operators.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/types/longvec-operators.hlsl index ed7a2bff25..ba76eca619 100644 --- a/tools/clang/test/CodeGenDXIL/hlsl/types/longvec-operators.hlsl +++ b/tools/clang/test/CodeGenDXIL/hlsl/types/longvec-operators.hlsl @@ -48,24 +48,6 @@ struct Interface { TYPE scales[10]; }; -#if 0 -// Requires vector loading support. Enable when available. -RWStructuredBuffer Input; -RWStructuredBuffer Output; - -TYPE g_val; - -[shader("compute")] -[numthreads(8,1,1)] -void main(uint GI : SV_GroupIndex) { - assignments(Output[GI].assigned, Input[GI].scales); - Output[GI].arithmeticked = arithmetic(Input[GI].arithmeticked); - Output[GI].scarithmeticked = scarithmetic(Input[GI].scarithmeticked, Input[GI].scales); - Output[GI].logicked = logic(Input[GI].logicked, Input[GI].assigned); - Output[GI].indexed = index(Input[GI].indexed, GI, g_val); -} -#endif - // A mixed-type overload to test overload resolution and mingle different vector element types in ops // Test assignment operators. 
// CHECK-LABEL: define void @"\01?assignments diff --git a/tools/clang/test/DXILValidation/vector-validation.hlsl b/tools/clang/test/DXILValidation/vector-validation.hlsl new file mode 100644 index 0000000000..87f24b2b0b --- /dev/null +++ b/tools/clang/test/DXILValidation/vector-validation.hlsl @@ -0,0 +1,19 @@ +// RUN: %dxc -T vs_6_9 %s -Od | FileCheck %s + +// Just HLSL source for validation that vector operations produce errors pre-6.9 +// Output is modified to have 6.8 instead. + +struct Vector { int i; float4 f;}; + +RWStructuredBuffer VecBuf; +RWStructuredBuffer StrBuf; +RWStructuredBuffer ScalBuf; + +// some simple ways to generate the vector ops in question. +// CHECK-LABEL: define void @main +float4 main(float val : VAL) :SV_Position { + float4 vec = VecBuf[1]; + VecBuf[0] = val; + return vec[2]; +} + diff --git a/tools/clang/test/HLSLFileCheck/hlsl/linker/resources/preserve_sb_types.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/linker/resources/preserve_sb_types.hlsl index f9c75a9381..82dd3586c1 100644 --- a/tools/clang/test/HLSLFileCheck/hlsl/linker/resources/preserve_sb_types.hlsl +++ b/tools/clang/test/HLSLFileCheck/hlsl/linker/resources/preserve_sb_types.hlsl @@ -155,5 +155,7 @@ export float4 xform(float4 v) { [shader("vertex")] float4 main(float3 pos : Position) : SV_Position { - return xform(float4(pos, 1)) * StructBuf[0].f; + float4 res = xform(float4(pos, 1)); + res *=StructBuf[0].f; + return res ; } diff --git a/tools/clang/test/LitDXILValidation/vector-validation.ll b/tools/clang/test/LitDXILValidation/vector-validation.ll new file mode 100644 index 0000000000..74e8116e88 --- /dev/null +++ b/tools/clang/test/LitDXILValidation/vector-validation.ll @@ -0,0 +1,78 @@ +; RUN: not %dxv %s 2>&1 | FileCheck %s + +; Confirm that 6.9 specific LLVM operations and DXIL intrinsics fail in 6.8 + +target datalayout = "e-m:e-p:32:32-i1:32-i8:32-i16:32-i32:32-i64:64-f16:32-f32:32-f64:64-n8:16:32:64" +target triple = "dxil-ms-dx" + +%dx.types.Handle = type { i8* } 
+%dx.types.ResBind = type { i32, i32, i32, i8 } +%dx.types.ResourceProperties = type { i32, i32 } +%dx.types.ResRet.v4f32 = type { <4 x float>, i32 } +%"class.RWStructuredBuffer >" = type { <4 x float> } + +; CHECK: Function: main: error: Instructions must be of an allowed type. +; CHECK: note: at '%6 = insertelement <4 x float> undef, float %2, i32 0 +; CHECK: Function: main: error: Instructions must be of an allowed type. +; CHECK: note: at '%7 = shufflevector <4 x float> %6, <4 x float> undef, <4 x i32> zeroinitializer +; CHECK: Function: main: error: Instructions must be of an allowed type. +; CHECK: note: at '%8 = extractelement <4 x float> %5, i32 2 +; CHECK: Function: main: error: Opcode RawBufferVectorLoad not valid in shader model vs_6_8. +; CHECK: note: at '%4 = call %dx.types.ResRet.v4f32 @dx.op.rawBufferVectorLoad.v4f32(i32 303, %dx.types.Handle %3, i32 1, i32 0, i32 8)' +; CHECK: Function: main: error: Opcode RawBufferVectorStore not valid in shader model vs_6_8. +; CHECK: note: at 'call void @dx.op.rawBufferVectorStore.v4f32(i32 304, %dx.types.Handle %3, i32 0, i32 0, <4 x float> %7, i32 4)' +; CHECK: Function: main: error: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details. +; CHECK: Function: main: error: Function uses features incompatible with the shader model. 
+define void @main() { + %1 = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 0, i32 0, i32 0, i8 1 }, i32 0, i1 false) + %2 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0, i32 undef) + %3 = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle %1, %dx.types.ResourceProperties { i32 4108, i32 16 }) + %4 = call %dx.types.ResRet.v4f32 @dx.op.rawBufferVectorLoad.v4f32(i32 303, %dx.types.Handle %3, i32 1, i32 0, i32 8) + %5 = extractvalue %dx.types.ResRet.v4f32 %4, 0 + %6 = insertelement <4 x float> undef, float %2, i32 0 + %7 = shufflevector <4 x float> %6, <4 x float> undef, <4 x i32> zeroinitializer + call void @dx.op.rawBufferVectorStore.v4f32(i32 304, %dx.types.Handle %3, i32 0, i32 0, <4 x float> %7, i32 4) + %8 = extractelement <4 x float> %5, i32 2 + call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %8) + call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float %8) + call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float %8) + call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float %8) + ret void +} + +declare float @dx.op.loadInput.f32(i32, i32, i32, i8, i32) #0 +declare void @dx.op.storeOutput.f32(i32, i32, i32, i8, float) #1 +declare %dx.types.ResRet.v4f32 @dx.op.rawBufferVectorLoad.v4f32(i32, %dx.types.Handle, i32, i32, i32) #2 +declare void @dx.op.rawBufferVectorStore.v4f32(i32, %dx.types.Handle, i32, i32, <4 x float>, i32) #1 +declare %dx.types.Handle @dx.op.annotateHandle(i32, %dx.types.Handle, %dx.types.ResourceProperties) #0 +declare %dx.types.Handle @dx.op.createHandleFromBinding(i32, %dx.types.ResBind, i32, i1) #0 + +attributes #0 = { nounwind readnone } +attributes #1 = { nounwind } +attributes #2 = { nounwind readonly } + +!dx.version = !{!1} +!dx.valver = !{!1} +!dx.shaderModel = !{!2} +!dx.resources = !{!3} +!dx.viewIdState = !{!7} +!dx.entryPoints = !{!8} + +!1 = !{i32 1, i32 8} +!2 = !{!"vs", i32 6, i32 8} +!3 = !{null, !4, null, 
null} +!4 = !{!5} +!5 = !{i32 0, %"class.RWStructuredBuffer >"* undef, !"", i32 0, i32 0, i32 1, i32 12, i1 false, i1 false, i1 false, !6} +!6 = !{i32 1, i32 16} +!7 = !{[3 x i32] [i32 1, i32 4, i32 0]} +!8 = !{void ()* @main, !"main", !9, !3, !17} +!9 = !{!10, !14, null} +!10 = !{!11} +!11 = !{i32 0, !"VAL", i8 9, i8 0, !12, i8 0, i32 1, i8 1, i32 0, i8 0, !13} +!12 = !{i32 0} +!13 = !{i32 3, i32 1} +!14 = !{!15} +!15 = !{i32 0, !"SV_Position", i8 9, i8 3, !12, i8 4, i32 1, i8 4, i32 0, i8 0, !16} +!16 = !{i32 3, i32 15} +!17 = !{i32 0, i64 8590000144} + diff --git a/utils/hct/hctdb.py b/utils/hct/hctdb.py index 57faee2fb2..0bf255ed65 100644 --- a/utils/hct/hctdb.py +++ b/utils/hct/hctdb.py @@ -604,6 +604,8 @@ def populate_categories_and_models(self): for i in "RawBufferLoad,RawBufferStore".split(","): self.name_idx[i].shader_model = 6, 2 self.name_idx[i].shader_model_translated = 6, 0 + for i in "RawBufferVectorLoad,RawBufferVectorStore,InsertElement,ShuffleVector,ExtractValue".split(","): + self.name_idx[i].shader_model = 6, 9 for i in "DispatchRaysIndex,DispatchRaysDimensions".split(","): self.name_idx[i].category = "Ray Dispatch Arguments" self.name_idx[i].shader_model = 6, 3 @@ -3504,7 +3506,7 @@ def UFI(name, **mappings): "ro", [ db_dxil_param(0, "$r", "", "the loaded value"), - db_dxil_param(2, "res", "srv", "handle of TypedBuffer SRV to sample"), + db_dxil_param(2, "res", "buf", "handle of Raw Buffer to load from"), db_dxil_param( 3, "i32", @@ -5776,6 +5778,83 @@ def UFI(name, **mappings): # Reserved block C next_op_idx = self.reserve_dxil_op_range("ReservedC", next_op_idx, 10) + # Long Vectors + self.add_dxil_op( + "RawBufferVectorLoad", + next_op_idx, + "RawBufferVectorLoad", + "reads from a raw buffer and structured buffer", + "hfwidl<", + "ro", + [ + db_dxil_param(0, "$r", "", "the loaded value"), + db_dxil_param(2, "res", "buf", "handle of Raw Buffer to load from"), + db_dxil_param( + 3, + "i32", + "index", + "element index for StructuredBuffer, or 
byte offset for ByteAddressBuffer", + ), + db_dxil_param( + 4, + "i32", + "elementOffset", + "offset into element for StructuredBuffer, or undef for ByteAddressBuffer", + ), + db_dxil_param( + 5, + "i32", + "alignment", + "relative load access alignment", + is_const=True, + ), + ], + counters=("tex_load",), + ) + next_op_idx += 1 + + self.add_dxil_op( + "RawBufferVectorStore", + next_op_idx, + "RawBufferVectorStore", + "writes to a RWByteAddressBuffer or RWStructuredBuffer", + "hfwidl<", + "", + [ + db_dxil_param(0, "v", "", ""), + db_dxil_param(2, "res", "uav", "handle of UAV to store to"), + db_dxil_param( + 3, + "i32", + "index", + "element index for StructuredBuffer, or byte offset for ByteAddressBuffer", + ), + db_dxil_param( + 4, + "i32", + "elementOffset", + "offset into element for StructuredBuffer, or undef for ByteAddressBuffer", + ), + db_dxil_param(5, "$o", "value0", "value"), + db_dxil_param( + 6, + "i32", + "alignment", + "relative store access alignment", + is_const=True, + ), + ], + counters=("tex_store",), + ) + next_op_idx += 1 + + # End of DXIL 1.9 opcodes. + self.set_op_count_for_version(1, 9, next_op_idx) + assert next_op_idx == 305, ( + "305 is expected next operation index but encountered %d and thus opcodes are broken" + % next_op_idx + ) + # Set interesting properties. 
self.build_indices() for ( @@ -6383,6 +6462,12 @@ def add_pass(name, type_name, doc, opts): "DXIL Lower createHandleForLib", [], ) + add_pass( + "hlsl-dxil-scalarize-vector-load-stores", + "DxilScalarizeVectorLoadStores", + "DXIL scalarize vector load/stores", + [], + ) add_pass( "hlsl-dxil-cleanup-dynamic-resource-handle", "DxilCleanupDynamicResourceHandle", From 9ec8bcc94ae43afe229c004b9428f4211bbd354b Mon Sep 17 00:00:00 2001 From: Greg Roth Date: Mon, 2 Dec 2024 22:31:31 -1000 Subject: [PATCH 03/31] Enable select native vector intrinsics actually allow the given ops to take vectors add vector overload type and apply to the relevant builtins Build lowering functions to allow vector supporting intrinsics through Preliminary groupshared support. keep groupshared as vectors for 6.9. They are no longer represented as individual groupshared scalars. adds groupshared to the test and performs the switch to CS to allow it. Support dot product on long vecs by expanding the intrinsic into mul/mad ops like is done with integer dot products Since the or() and and() intrinsics did their own scalarization, the or/and operators would never be applied to full vectors.
This leaves the scalarization for the scalarization pass, which will skip it for 6.9 --- lib/DXIL/DxilOperations.cpp | 100 ++++---- lib/HLSL/HLOperationLower.cpp | 216 +++++++++++++----- .../test/CodeGenDXIL/hlsl/types/longvecs.hlsl | 154 +++++++++++++ utils/hct/hctdb.py | 8 +- 4 files changed, 363 insertions(+), 115 deletions(-) create mode 100644 tools/clang/test/CodeGenDXIL/hlsl/types/longvecs.hlsl diff --git a/lib/DXIL/DxilOperations.cpp b/lib/DXIL/DxilOperations.cpp index c422d86593..5e0757aac1 100644 --- a/lib/DXIL/DxilOperations.cpp +++ b/lib/DXIL/DxilOperations.cpp @@ -144,112 +144,112 @@ const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = { "unary", Attribute::ReadNone, 1, - {{0x3}}, - {{0x0}}}, // Overloads: hf + {{0x403}}, + {{0x3}}}, // Overloads: hf refArgs, return TrivialDxilOperation(opcode, refArgs, Ty, Inst->getType(), hlslOP, B); } +Value *TrivialDxilVectorOperation(Function *dxilFunc, OP::OpCode opcode, + ArrayRef refArgs, Type *Ty, + OP *hlslOP, IRBuilder<> &Builder) { + if (!Ty->isVoidTy()) { + Value *retVal = + Builder.CreateCall(dxilFunc, refArgs, hlslOP->GetOpCodeName(opcode)); + return retVal; + } else { + // Cannot add name to void. 
+ return Builder.CreateCall(dxilFunc, refArgs); + } +} + + +Value *TrivialDxilVectorUnaryOperationRet(OP::OpCode opcode, Value *src, Type *Ty, + OP *hlslOP, IRBuilder<> &Builder) { + + Constant *opArg = hlslOP->GetU32Const((unsigned)opcode); + Value *args[] = {opArg, src}; + + Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty); + + return TrivialDxilVectorOperation(dxilFunc, opcode, args, Ty, hlslOP, Builder); +} + +Value *TrivialDxilVectorBinaryOperation(OP::OpCode opcode, Value *src0, Value *src1, + hlsl::OP *hlslOP, IRBuilder<> &Builder) { + Type *Ty = src0->getType(); + + Constant *opArg = hlslOP->GetU32Const((unsigned)opcode); + Value *args[] = {opArg, src0, src1}; + + Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty); + + return TrivialDxilVectorOperation(dxilFunc, opcode, args, Ty, hlslOP, Builder); +} + Value *TrivialDxilUnaryOperationRet(OP::OpCode opcode, Value *src, Type *RetTy, hlsl::OP *hlslOP, IRBuilder<> &Builder) { Type *Ty = src->getType(); @@ -507,17 +544,26 @@ Value *TrivialDxilBinaryOperation(OP::OpCode opcode, Value *src0, Value *src1, return TrivialDxilOperation(opcode, args, Ty, Ty, hlslOP, Builder); } -Value *TrivialDxilTrinaryOperation(OP::OpCode opcode, Value *src0, Value *src1, - Value *src2, hlsl::OP *hlslOP, - IRBuilder<> &Builder) { - Type *Ty = src0->getType(); - +Value *TrivialDxilTrinaryOperationRet(OP::OpCode opcode, Value *src0, Value *src1, + Value *src2, Type *Ty, hlsl::OP *hlslOP, + IRBuilder<> &Builder) { Constant *opArg = hlslOP->GetU32Const((unsigned)opcode); Value *args[] = {opArg, src0, src1, src2}; return TrivialDxilOperation(opcode, args, Ty, Ty, hlslOP, Builder); } +Value *TrivialDxilVectorTrinaryOperationRet(OP::OpCode opcode, Value *src0, Value *src1, + Value *src2, Type *Ty, hlsl::OP *hlslOP, + IRBuilder<> &Builder) { + Constant *opArg = hlslOP->GetU32Const((unsigned)opcode); + Value *args[] = {opArg, src0, src1, src2}; + + Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty); + + return 
TrivialDxilVectorOperation(dxilFunc, opcode, args, Ty, hlslOP, Builder); +} + Value *TrivialUnaryOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, HLOperationLowerHelper &helper, HLObjectOperationLowerHelper *pObjHelper, @@ -530,6 +576,24 @@ Value *TrivialUnaryOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, return retVal; } +Value *TrivialVectorizableUnaryOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, + HLOperationLowerHelper &helper, + HLObjectOperationLowerHelper *pObjHelper, + bool &Translated) { + Value *src0 = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx); + Type *Ty = CI->getType(); + IRBuilder<> Builder(CI); + hlsl::OP *hlslOP = &helper.hlslOP; + + if (Ty->isVectorTy() && + helper.M.GetShaderModel()->IsSM69Plus()) + return TrivialDxilVectorUnaryOperationRet(opcode, src0, Ty, + hlslOP, Builder); + else + return TrivialDxilUnaryOperationRet(opcode, src0, Ty, + hlslOP, Builder); +} + Value *TrivialBinaryOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, HLOperationLowerHelper &helper, HLObjectOperationLowerHelper *pObjHelper, @@ -544,19 +608,36 @@ Value *TrivialBinaryOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, return binOp; } -Value *TrivialTrinaryOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, - HLOperationLowerHelper &helper, - HLObjectOperationLowerHelper *pObjHelper, - bool &Translated) { +Value *TrivialVectorBinaryOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, + HLOperationLowerHelper &helper, + HLObjectOperationLowerHelper *pObjHelper, + bool &Translated) { hlsl::OP *hlslOP = &helper.hlslOP; + Value *src0 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx); + Value *src1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx); + IRBuilder<> Builder(CI); + + Value *binOp = + TrivialDxilVectorBinaryOperation(opcode, src0, src1, hlslOP, Builder); + return binOp; +} + +Value *TranslateFMA(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, + HLOperationLowerHelper &helper, 
+ HLObjectOperationLowerHelper *pObjHelper, + bool &Translated) { + hlsl::OP *hlslOP = &helper.hlslOP; + Type *Ty = CI->getType(); Value *src0 = CI->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx); Value *src1 = CI->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx); Value *src2 = CI->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx); IRBuilder<> Builder(CI); - Value *triOp = - TrivialDxilTrinaryOperation(opcode, src0, src1, src2, hlslOP, Builder); - return triOp; + if (Ty->isVectorTy() && + helper.M.GetShaderModel()->IsSM69Plus()) + return TrivialDxilVectorTrinaryOperationRet(opcode, src0, src1, src2, Ty, hlslOP, Builder); + else + return TrivialDxilTrinaryOperationRet(opcode, src0, src1, src2, Ty, hlslOP, Builder); } Value *TrivialIsSpecialFloat(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -735,6 +816,12 @@ bool CanUseFxcMulOnlyPatternForPow(IRBuilder<> &Builder, Value *x, Value *pow, } } + // Only apply on aggregates of 16 or fewer elements, + // representing the max 4x4 matrix size. + Type *xTy = x->getType(); + if (xTy->isVectorTy() && xTy->getVectorNumElements() > 16) + return false; + APFloat powAPF = isa(pow) ? cast(pow)->getElementAsAPFloat(0) : // should be a splat value @@ -1896,9 +1983,16 @@ Value *TranslateClamp(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, IRBuilder<> Builder(CI); // min(max(x, minVal), maxVal). 
- Value *maxXMinVal = + if (Ty->isVectorTy() && + helper.M.GetShaderModel()->IsSM69Plus()) { + Value *maxXMinVal = + TrivialDxilVectorBinaryOperation(maxOp, x, minVal, hlslOP, Builder); + return TrivialDxilVectorBinaryOperation(minOp, maxXMinVal, maxVal, hlslOP, Builder); + } else { + Value *maxXMinVal = TrivialDxilBinaryOperation(maxOp, x, minVal, hlslOP, Builder); - return TrivialDxilBinaryOperation(minOp, maxXMinVal, maxVal, hlslOP, Builder); + return TrivialDxilBinaryOperation(minOp, maxXMinVal, maxVal, hlslOP, Builder); + } } Value *TranslateClip(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -2211,8 +2305,11 @@ Value *TranslateExp(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, ConstantVector::getSplat(Ty->getVectorNumElements(), log2eConst); } val = Builder.CreateFMul(log2eConst, val); - Value *exp = TrivialDxilUnaryOperation(OP::OpCode::Exp, val, hlslOP, Builder); - return exp; + if (Ty->isVectorTy() && + helper.M.GetShaderModel()->IsSM69Plus()) + return TrivialDxilVectorUnaryOperationRet(OP::OpCode::Exp, val, Ty, hlslOP, Builder); + else + return TrivialDxilUnaryOperationRet(OP::OpCode::Exp, val, Ty, hlslOP, Builder); } Value *TranslateLog(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -2227,7 +2324,12 @@ Value *TranslateLog(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, if (Ty != Ty->getScalarType()) { ln2Const = ConstantVector::getSplat(Ty->getVectorNumElements(), ln2Const); } - Value *log = TrivialDxilUnaryOperation(OP::OpCode::Log, val, hlslOP, Builder); + Value *log = nullptr; + if (Ty->isVectorTy() && + helper.M.GetShaderModel()->IsSM69Plus()) + log = TrivialDxilVectorUnaryOperationRet(OP::OpCode::Log, val, Ty, hlslOP, Builder); + else + log = TrivialDxilUnaryOperationRet(OP::OpCode::Log, val, Ty, hlslOP, Builder); return Builder.CreateFMul(ln2Const, log); } @@ -2287,8 +2389,13 @@ Value *TranslateFUIBinary(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, break; } } - return TrivialBinaryOperation(CI, IOP, opcode, helper, pObjHelper, 
- Translated); + if (CI->getType()->isVectorTy() && + helper.M.GetShaderModel()->IsSM69Plus()) + return TrivialVectorBinaryOperation(CI, IOP, opcode, helper, pObjHelper, + Translated); + else + return TrivialBinaryOperation(CI, IOP, opcode, helper, pObjHelper, + Translated); } Value *TranslateFUITrinary(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -2305,8 +2412,15 @@ Value *TranslateFUITrinary(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, break; } } - return TrivialTrinaryOperation(CI, IOP, opcode, helper, pObjHelper, - Translated); + + hlsl::OP *hlslOP = &helper.hlslOP; + Type *Ty = CI->getType(); + Value *src0 = CI->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx); + Value *src1 = CI->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx); + Value *src2 = CI->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx); + IRBuilder<> Builder(CI); + + return TrivialDxilTrinaryOperationRet(opcode, src0, src1, src2, Ty, hlslOP, Builder); } Value *TranslateFrexp(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -2428,18 +2542,22 @@ Value *TrivialDotOperation(OP::OpCode opcode, Value *src0, Value *src1, return dotOP; } -Value *TranslateIDot(Value *arg0, Value *arg1, unsigned vecSize, - hlsl::OP *hlslOP, IRBuilder<> &Builder, - bool Unsigned = false) { +// Instead of using a DXIL intrinsic, implement a dot product operation using +// multiply and add operations. Used for integer dots and long vectors. +Value *ExpandDot(Value *arg0, Value *arg1, unsigned vecSize, + hlsl::OP *hlslOP, IRBuilder<> &Builder, + bool Unsigned = false) { auto madOpCode = Unsigned ? 
DXIL::OpCode::UMad : DXIL::OpCode::IMad; + if (arg0->getType()->getScalarType()->isFloatingPointTy()) + madOpCode = DXIL::OpCode::FMad; Value *Elt0 = Builder.CreateExtractElement(arg0, (uint64_t)0); Value *Elt1 = Builder.CreateExtractElement(arg1, (uint64_t)0); Value *Result = Builder.CreateMul(Elt0, Elt1); - for (unsigned iVecElt = 1; iVecElt < vecSize; ++iVecElt) { - Elt0 = Builder.CreateExtractElement(arg0, iVecElt); - Elt1 = Builder.CreateExtractElement(arg1, iVecElt); - Result = TrivialDxilTrinaryOperation(madOpCode, Elt0, Elt1, Result, hlslOP, - Builder); + for (unsigned Elt = 1; Elt < vecSize; ++Elt) { + Elt0 = Builder.CreateExtractElement(arg0, Elt); + Elt1 = Builder.CreateExtractElement(arg1, Elt); + Result = TrivialDxilTrinaryOperationRet(madOpCode, Elt0, Elt1, Result, Elt0->getType(), hlslOP, + Builder); } return Result; @@ -2477,10 +2595,10 @@ Value *TranslateDot(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, unsigned vecSize = Ty->getVectorNumElements(); Value *arg1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx); IRBuilder<> Builder(CI); - if (Ty->getScalarType()->isFloatingPointTy()) { + if (Ty->getScalarType()->isFloatingPointTy() && Ty->getVectorNumElements() <= 4) { return TranslateFDot(arg0, arg1, vecSize, hlslOP, Builder); } else { - return TranslateIDot(arg0, arg1, vecSize, hlslOP, Builder, + return ExpandDot(arg0, arg1, vecSize, hlslOP, Builder, IOP == IntrinsicOp::IOP_udot); } } @@ -2664,8 +2782,8 @@ Value *TranslateMSad4(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, byteSrc = Builder.CreateInsertElement(byteSrc, byteSrcElt, 3); // Msad on vecref and byteSrc. 
- return TrivialDxilTrinaryOperation(DXIL::OpCode::Msad, vecRef, byteSrc, accum, - hlslOP, Builder); + return TrivialDxilTrinaryOperationRet(DXIL::OpCode::Msad, vecRef, byteSrc, accum, + vecRef->getType(), hlslOP, Builder); } Value *TranslateRCP(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -3029,7 +3147,7 @@ Value *TranslateMul(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, if (arg0Ty->getScalarType()->isFloatingPointTy()) { return TranslateFDot(arg0, arg1, vecSize, hlslOP, Builder); } else { - return TranslateIDot(arg0, arg1, vecSize, hlslOP, Builder, + return ExpandDot(arg0, arg1, vecSize, hlslOP, Builder, IOP == IntrinsicOp::IOP_umul); } } else { @@ -6145,20 +6263,8 @@ Value *TranslateAnd(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, bool &Translated) { Value *x = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx); Value *y = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx); - Type *Ty = CI->getType(); - Type *EltTy = Ty->getScalarType(); IRBuilder<> Builder(CI); - if (Ty != EltTy) { - Value *Result = UndefValue::get(Ty); - for (unsigned i = 0; i < Ty->getVectorNumElements(); i++) { - Value *EltX = Builder.CreateExtractElement(x, i); - Value *EltY = Builder.CreateExtractElement(y, i); - Value *tmp = Builder.CreateAnd(EltX, EltY); - Result = Builder.CreateInsertElement(Result, tmp, i); - } - return Result; - } return Builder.CreateAnd(x, y); } Value *TranslateOr(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -6166,20 +6272,8 @@ Value *TranslateOr(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, HLObjectOperationLowerHelper *pObjHelper, bool &Translated) { Value *x = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx); Value *y = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx); - Type *Ty = CI->getType(); - Type *EltTy = Ty->getScalarType(); IRBuilder<> Builder(CI); - if (Ty != EltTy) { - Value *Result = UndefValue::get(Ty); - for (unsigned i = 0; i < Ty->getVectorNumElements(); i++) { - Value *EltX = Builder.CreateExtractElement(x, 
i); - Value *EltY = Builder.CreateExtractElement(y, i); - Value *tmp = Builder.CreateOr(EltX, EltY); - Result = Builder.CreateInsertElement(Result, tmp, i); - } - return Result; - } return Builder.CreateOr(x, y); } Value *TranslateSelect(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -6455,7 +6549,7 @@ IntrinsicLower gLowerTable[] = { {IntrinsicOp::IOP_asint16, TranslateBitcast, DXIL::OpCode::NumOpCodes}, {IntrinsicOp::IOP_asuint, TranslateAsUint, DXIL::OpCode::SplitDouble}, {IntrinsicOp::IOP_asuint16, TranslateAsUint, DXIL::OpCode::NumOpCodes}, - {IntrinsicOp::IOP_atan, TrivialUnaryOperation, DXIL::OpCode::Atan}, + {IntrinsicOp::IOP_atan, TrivialVectorizableUnaryOperation, DXIL::OpCode::Atan}, {IntrinsicOp::IOP_atan2, TranslateAtan2, DXIL::OpCode::NumOpCodes}, {IntrinsicOp::IOP_ceil, TrivialUnaryOperation, DXIL::OpCode::Round_pi}, {IntrinsicOp::IOP_clamp, TranslateClamp, DXIL::OpCode::NumOpCodes}, @@ -6498,7 +6592,7 @@ IntrinsicLower gLowerTable[] = { {IntrinsicOp::IOP_firstbitlow, TranslateFirstbitLo, DXIL::OpCode::FirstbitLo}, {IntrinsicOp::IOP_floor, TrivialUnaryOperation, DXIL::OpCode::Round_ni}, - {IntrinsicOp::IOP_fma, TrivialTrinaryOperation, DXIL::OpCode::Fma}, + {IntrinsicOp::IOP_fma, TranslateFMA, DXIL::OpCode::Fma}, {IntrinsicOp::IOP_fmod, TranslateFMod, DXIL::OpCode::NumOpCodes}, {IntrinsicOp::IOP_frac, TrivialUnaryOperation, DXIL::OpCode::Frc}, {IntrinsicOp::IOP_frexp, TranslateFrexp, DXIL::OpCode::NumOpCodes}, @@ -6546,7 +6640,7 @@ IntrinsicLower gLowerTable[] = { {IntrinsicOp::IOP_sqrt, TrivialUnaryOperation, DXIL::OpCode::Sqrt}, {IntrinsicOp::IOP_step, TranslateStep, DXIL::OpCode::NumOpCodes}, {IntrinsicOp::IOP_tan, TrivialUnaryOperation, DXIL::OpCode::Tan}, - {IntrinsicOp::IOP_tanh, TrivialUnaryOperation, DXIL::OpCode::Htan}, + {IntrinsicOp::IOP_tanh, TrivialVectorizableUnaryOperation, DXIL::OpCode::Htan}, {IntrinsicOp::IOP_tex1D, EmptyLower, DXIL::OpCode::NumOpCodes}, {IntrinsicOp::IOP_tex1Dbias, EmptyLower, DXIL::OpCode::NumOpCodes}, 
{IntrinsicOp::IOP_tex1Dgrad, EmptyLower, DXIL::OpCode::NumOpCodes}, diff --git a/tools/clang/test/CodeGenDXIL/hlsl/types/longvecs.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/types/longvecs.hlsl new file mode 100644 index 0000000000..1910e08a25 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/types/longvecs.hlsl @@ -0,0 +1,154 @@ +// RUN: %dxc -Wno-conversion -T cs_6_9 %s | FileCheck %s --check-prefixes=CHECK,F32 +// RUN: %dxc -Wno-conversion -T cs_6_9 -DF64 %s | FileCheck %s --check-prefixes=CHECK,F64 + +RWByteAddressBuffer buf; + +// "TYPE" is the mainly focused test type. +// "UNTYPE" is the other type used for mixed precision testing. +#ifdef F64 +typedef double TYPE; +typedef float UNTYPE; +#else +typedef float TYPE; +typedef double UNTYPE; +#endif + +// Two main test function overloads. One expects matching element types. +// The other uses different types to test ops and overload resolution. +template vector dostuff(vector thing1, vector thing2, vector thing3); +vector dostuff(vector thing1, vector thing2, vector thing3); + +// Just a trick to capture the needed type spellings since the DXC version of FileCheck can't do that explicitly. 
+// F32-DAG: %dx.types.ResRet.[[TY:v8f32]] = type { [[TYPE:<8 x float>]] +// F32-DAG: %dx.types.ResRet.[[UNTY:v8f64]] = type { [[UNTYPE:<8 x double>]] +// F64-DAG: %dx.types.ResRet.[[TY:v8f64]] = type { [[TYPE:<8 x double>]] +// F64-DAG: %dx.types.ResRet.[[UNTY:v8f32]] = type { [[UNTYPE:<8 x float>]] + +// Verify that groupshared vectors are kept as aggregates +// CHECK: @"\01?gs_vec1@@3V?$vector@{{M|N}}$07@@A" = external addrspace(3) global [[TYPE]] +// CHECK: @"\01?gs_vec2@@3V?$vector@{{M|N}}$07@@A" = external addrspace(3) global [[TYPE]] +// CHECK: @"\01?gs_vec3@@3V?$vector@{{M|N}}$07@@A" = external addrspace(3) global [[TYPE]] +groupshared vector gs_vec1, gs_vec2, gs_vec3; + +[numthreads(8,1,1)] +void main() { + // CHECK: [[buf:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle %1, %dx.types.ResourceProperties { i32 4107, i32 0 }) ; AnnotateHandle(res,props) resource: RWByteAddressBuffer + + // CHECK: [[vec1_res:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[buf]], i32 0 + // CHECK-DAG: [[vec1:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[vec1_res]], 0 + // F32-DAG: [[vec1_32:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[vec1_res]], 0 + // F64-DAG: [[vec1_64:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[vec1_res]], 0 + vector vec1 = buf.Load >(0); + + // CHECK: [[vec2_res:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[buf]], i32 60 + // CHECK-DAG: [[vec2:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[vec2_res]], 0 + // F32-DAG: [[vec2_32:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[vec2_res]], 0 + // F64-DAG: [[vec2_64:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[vec2_res]], 0 + vector vec2 = buf.Load >(60); + + // CHECK: [[vec3_res:%.*]] = call %dx.types.ResRet.[[TY]] @dx.op.rawBufferVectorLoad.[[TY]](i32 303, %dx.types.Handle [[buf]], i32 120 + // CHECK-DAG: [[vec3:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[vec3_res]], 
0 + // F64-DAG: [[vec3_64:%.*]] = extractvalue %dx.types.ResRet.[[TY]] [[vec3_res]], 0 + vector vec3 = buf.Load >(120); + + // CHECK: [[unvec_res:%.*]] = call %dx.types.ResRet.[[UNTY]] @dx.op.rawBufferVectorLoad.[[UNTY]](i32 303, %dx.types.Handle [[buf]], i32 180 + // CHECK-DAG: [[unvec:%.*]] = extractvalue %dx.types.ResRet.[[UNTY]] [[unvec_res]], 0 + // F32-DAG: [[unvec_64:%.*]] = extractvalue %dx.types.ResRet.[[UNTY]] [[unvec_res]], 0 + // F64-DAG: [[unvec_32:%.*]] = extractvalue %dx.types.ResRet.[[UNTY]] [[unvec_res]], 0 + vector unvec = buf.Load >(180); + + vec1 = dostuff(vec1, vec2, vec3); + + // Test mixed type operations + vec2 = dostuff(vec2, unvec, vec3); + + gs_vec2 = dostuff(gs_vec1, gs_vec2, gs_vec3); + + // mix groupshared and non + //vec1 = dostuff(vec1, gs_vec2, vec3); + + buf.Store >(240, vec1 * vec2 - vec3 * gs_vec1 + gs_vec2 / gs_vec3); +} + +// Test the required ops on long vectors and confirm correct lowering. +template +vector dostuff(vector thing1, vector thing2, vector thing3) { + vector res = 0; + + // CHECK: call [[TYPE]] @dx.op.binary.[[TY]](i32 36, [[TYPE]] [[vec1]], [[TYPE]] [[vec2]]) ; FMin(a,b) + res += min(thing1, thing2); + // CHECK: call [[TYPE]] @dx.op.binary.[[TY]](i32 35, [[TYPE]] [[vec1]], [[TYPE]] [[vec3]]) ; FMax(a,b) + res += max(thing1, thing3); + + // CHECK: [[tmp:%.*]] = call [[TYPE]] @dx.op.binary.[[TY]](i32 35, [[TYPE]] [[vec1]], [[TYPE]] [[vec2]]) ; FMax(a,b) + // CHECK: call [[TYPE]] @dx.op.binary.[[TY]](i32 36, [[TYPE]] [[tmp]], [[TYPE]] [[vec3]]) ; FMin(a,b) + res += clamp(thing1, thing2, thing3); + + // F32: [[vec3_64:%.*]] = fpext <8 x float> [[vec3]] to <8 x double> + // F32: [[vec2_64:%.*]] = fpext <8 x float> [[vec2]] to <8 x double> + // F32: [[vec1_64:%.*]] = fpext <8 x float> [[vec1]] to <8 x double> + // CHECK: call <8 x double> @dx.op.tertiary.v8f64(i32 47, <8 x double> [[vec1_64]], <8 x double> [[vec2_64]], <8 x double> [[vec3_64]]) ; Fma(a,b,c) + res += (vector)fma((vector)thing1, (vector)(thing2), 
(vector)thing3); + + // Even in the double test, these will be downconverted because these builtins only take floats. + // F64: [[vec2_32:%.*]] = fptrunc <8 x double> [[vec2]] to <8 x float> + // F64: [[vec1_32:%.*]] = fptrunc <8 x double> [[vec1]] to <8 x float> + + // CHECK: [[tmp:%.*]] = fcmp fast olt <8 x float> [[vec2_32]], [[vec1_32]] + // CHECK: select <8 x i1> [[tmp]], [[TYPE]] zeroinitializer, [[TYPE]] + res += step(thing1, thing2); + + // CHECK: [[tmp:%.*]] = fmul fast <8 x float> [[vec1_32]], @dx.op.unary.v8f32(i32 21, <8 x float> [[tmp]]) ; Exp(value) + res += exp(thing1); + + // CHECK: [[tmp:%.*]] = call <8 x float> @dx.op.unary.v8f32(i32 23, <8 x float> [[vec1_32]]) ; Log(value) + // CHECK: fmul fast <8 x float> [[tmp]], @dx.op.unary.v8f32(i32 20, <8 x float> [[vec1_32]]) ; Htan(value) + res += tanh(thing1); + // CHECK: call <8 x float> @dx.op.unary.v8f32(i32 17, <8 x float> [[vec1_32]]) ; Atan(value) + res += atan(thing1); + + return res; +} + +// A mixed-type overload to test overload resolution and mingle different vector element types in ops +vector dostuff(vector thing1, vector thing2, vector thing3) { + vector res = 0; + + // F64: [[unvec_64:%.*]] = fpext <8 x float> [[unvec]] to <8 x double> + // CHECK: call <8 x double> @dx.op.binary.v8f64(i32 36, <8 x double> [[vec2_64]], <8 x double> [[unvec_64]]) ; FMin(a,b) + res += min(thing1, thing2); + + // CHECK: call [[TYPE]] @dx.op.binary.[[TY]](i32 35, [[TYPE]] [[vec2]], [[TYPE]] [[vec3]]) ; FMax(a,b) + res += max(thing1, thing3); + + // CHECK: [[tmp:%.*]] = call <8 x double> @dx.op.binary.v8f64(i32 35, <8 x double> [[vec2_64]], <8 x double> [[unvec_64]]) ; FMax(a,b) + // CHECK: call <8 x double> @dx.op.binary.v8f64(i32 36, <8 x double> [[tmp]], <8 x double> [[vec3_64]]) ; FMin(a,b) + res += clamp(thing1, thing2, thing3); + + // CHECK: call <8 x double> @dx.op.tertiary.v8f64(i32 47, <8 x double> [[vec2_64]], <8 x double> [[unvec_64]], <8 x double> [[vec3_64]]) ; Fma(a,b,c) + res += 
(vector)fma((vector)thing1, (vector)(thing2), (vector)thing3); + + // F32: [[unvec_32:%.*]] = fptrunc <8 x double> [[unvec]] to <8 x float> + // CHECK: [[tmp:%.*]] = fcmp fast olt <8 x float> [[unvec_32]], [[vec2_32]] + // CHECK: select <8 x i1> [[tmp]], [[TYPE]] zeroinitializer, [[TYPE]] + res += step(thing1, thing2); + + // CHECK: [[tmp:%.*]] = fmul fast <8 x float> [[vec2_32]], @dx.op.unary.v8f32(i32 21, <8 x float> [[tmp]]) ; Exp(value) + res += exp(thing1); + + // CHECK: [[tmp:%.*]] = call <8 x float> @dx.op.unary.v8f32(i32 23, <8 x float> [[vec2_32]]) ; Log(value) + // CHECK: fmul fast <8 x float> [[tmp]], @dx.op.unary.v8f32(i32 20, <8 x float> [[vec2_32]]) ; Htan(value) + res += tanh(thing1); + // CHECK: call <8 x float> @dx.op.unary.v8f32(i32 17, <8 x float> [[vec2_32]]) ; Atan(value) + res += atan(thing1); + + return res; +} diff --git a/utils/hct/hctdb.py b/utils/hct/hctdb.py index 0bf255ed65..8f6887b5d4 100644 --- a/utils/hct/hctdb.py +++ b/utils/hct/hctdb.py @@ -1535,7 +1535,7 @@ def UFI(name, **mappings): next_op_idx, "Unary", "returns the " + i, - "hf", + "hf<", "rn", [ db_dxil_param(0, "$o", "", "operation result"), @@ -1599,7 +1599,7 @@ def UFI(name, **mappings): next_op_idx, "Binary", "returns the " + i + " of the input values", - "hfd", + "hfd<", "rn", [ db_dxil_param(0, "$o", "", "operation result"), @@ -1617,7 +1617,7 @@ def UFI(name, **mappings): next_op_idx, "Binary", "returns the " + i + " of the input values", - "wil", + "wil<", "rn", [ db_dxil_param(0, "$o", "", "operation result"), @@ -1689,7 +1689,7 @@ def UFI(name, **mappings): next_op_idx, "Tertiary", "performs a fused multiply add (FMA) of the form a * b + c", - "d", + "d<", "rn", [ db_dxil_param( From 51635fd07a3361b028a54d22efff54acd866adb2 Mon Sep 17 00:00:00 2001 From: Joshua Batista Date: Fri, 4 Apr 2025 10:43:25 -0700 Subject: [PATCH 04/31] Use internal validator for execution test that requires SM 6.8 (#7309) This PR modifies WaveSizeRange test which depends on shader model 6.8. 
The compiler needs -select-validator internal. This will allow the tests to be run in different testing environments when an external validator that isn't sufficient is available. --- .../unittests/HLSLExec/ExecutionTest.cpp | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 91b42f6b79..0ab6759d95 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -11641,20 +11641,23 @@ void ExecuteWaveSizeRangeInstance(UINT minWaveSize, UINT maxWaveSize, })"; // format compiler args - char compilerOptions[64]; + char compilerOptions[70]; if (usePreferred) { // putting spaces in between the %d's below will cause compilation issues. - VERIFY_IS_TRUE(sprintf_s(compilerOptions, sizeof(compilerOptions), - "-D WAVE_SIZE_ATTR=[wavesize(%d,%d,%d)]", - minShaderWaveSize, maxShaderWaveSize, - prefShaderWaveSize) != -1); + VERIFY_IS_TRUE( + sprintf_s( + compilerOptions, sizeof(compilerOptions), + "-D WAVE_SIZE_ATTR=[wavesize(%d,%d,%d)] -select-validator internal", + minShaderWaveSize, maxShaderWaveSize, prefShaderWaveSize) != -1); LogCommentFmt(L"Verifying wave size range test results for (min, max, " L"preferred): (%d, %d, %d)", minShaderWaveSize, maxShaderWaveSize, prefShaderWaveSize); } else { - VERIFY_IS_TRUE(sprintf_s(compilerOptions, sizeof(compilerOptions), - "-D WAVE_SIZE_ATTR=[wavesize(%d,%d)]", - minShaderWaveSize, maxShaderWaveSize) != -1); + VERIFY_IS_TRUE( + sprintf_s( + compilerOptions, sizeof(compilerOptions), + "-D WAVE_SIZE_ATTR=[wavesize(%d,%d)] -select-validator internal", + minShaderWaveSize, maxShaderWaveSize) != -1); LogCommentFmt( L"Verifying wave size range test results for (min, max): (%d, %d)", minShaderWaveSize, maxShaderWaveSize); From 73c42089dc2ca05012fab35e8c74a5934dc7eed8 Mon Sep 17 00:00:00 2001 From: Greg Roth Date: Mon, 7 Apr 2025 10:31:01 -0700 
Subject: [PATCH 05/31] Add staging branches to azure pipeline testing (#7285) The branch-name patterns that trigger azure pipelines runs excluded staging branches. This adds them in. Not sure this is desired, but it's here if we want it. --- azure-pipelines.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 33c5349f9e..285fc4028a 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -1,10 +1,12 @@ trigger: - main - release* + - staging* pr: - main - release* + - staging* resources: - repo: self From f69f2810e3afe9b54fd6c9fb7aecd5f5fb4634d5 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 9 Apr 2025 15:48:17 +0000 Subject: [PATCH 06/31] chore: autopublish 2025-04-09T15:48:17Z --- lib/DxilValidation/DxilValidation.cpp | 3 +- lib/HLSL/DxilLinker.cpp | 2 +- lib/HLSL/HLOperationLower.cpp | 149 ++++++++++++++------------ tools/clang/lib/Sema/SemaHLSL.cpp | 26 +++-- 4 files changed, 100 insertions(+), 80 deletions(-) diff --git a/lib/DxilValidation/DxilValidation.cpp b/lib/DxilValidation/DxilValidation.cpp index 9e8f8574ac..a788f21d4e 100644 --- a/lib/DxilValidation/DxilValidation.cpp +++ b/lib/DxilValidation/DxilValidation.cpp @@ -2723,7 +2723,8 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) { } // Instructions must be allowed. 
- if (!IsLLVMInstructionAllowed(I) || !IsLLVMInstructionAllowedForShaderModel(I, ValCtx)) { + if (!IsLLVMInstructionAllowed(I) || + !IsLLVMInstructionAllowedForShaderModel(I, ValCtx)) { if (!IsLLVMInstructionAllowedForLib(I, ValCtx)) { ValCtx.EmitInstrError(&I, ValidationRule::InstrAllowed); continue; diff --git a/lib/HLSL/DxilLinker.cpp b/lib/HLSL/DxilLinker.cpp index c4dae4b69f..c58a2e909a 100644 --- a/lib/HLSL/DxilLinker.cpp +++ b/lib/HLSL/DxilLinker.cpp @@ -1278,7 +1278,7 @@ void DxilLinkJob::RunPreparePass(Module &M) { PM.add(createScalarizerPass()); // Need dxilelimvector for pre 6.9 - //PM.add(createDxilEliminateVectorPass()); + // PM.add(createDxilEliminateVectorPass()); PM.add(createPromoteMemoryToRegisterPass()); diff --git a/lib/HLSL/HLOperationLower.cpp b/lib/HLSL/HLOperationLower.cpp index 6078455805..a68bddaf32 100644 --- a/lib/HLSL/HLOperationLower.cpp +++ b/lib/HLSL/HLOperationLower.cpp @@ -485,11 +485,11 @@ Value *TrivialDxilOperation(OP::OpCode opcode, ArrayRef refArgs, } Value *TrivialDxilVectorOperation(Function *dxilFunc, OP::OpCode opcode, - ArrayRef refArgs, Type *Ty, - OP *hlslOP, IRBuilder<> &Builder) { + ArrayRef refArgs, Type *Ty, + OP *hlslOP, IRBuilder<> &Builder) { if (!Ty->isVoidTy()) { Value *retVal = - Builder.CreateCall(dxilFunc, refArgs, hlslOP->GetOpCodeName(opcode)); + Builder.CreateCall(dxilFunc, refArgs, hlslOP->GetOpCodeName(opcode)); return retVal; } else { // Cannot add name to void. 
@@ -497,20 +497,22 @@ Value *TrivialDxilVectorOperation(Function *dxilFunc, OP::OpCode opcode, } } - -Value *TrivialDxilVectorUnaryOperationRet(OP::OpCode opcode, Value *src, Type *Ty, - OP *hlslOP, IRBuilder<> &Builder) { +Value *TrivialDxilVectorUnaryOperationRet(OP::OpCode opcode, Value *src, + Type *Ty, OP *hlslOP, + IRBuilder<> &Builder) { Constant *opArg = hlslOP->GetU32Const((unsigned)opcode); Value *args[] = {opArg, src}; Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty); - return TrivialDxilVectorOperation(dxilFunc, opcode, args, Ty, hlslOP, Builder); + return TrivialDxilVectorOperation(dxilFunc, opcode, args, Ty, hlslOP, + Builder); } -Value *TrivialDxilVectorBinaryOperation(OP::OpCode opcode, Value *src0, Value *src1, - hlsl::OP *hlslOP, IRBuilder<> &Builder) { +Value *TrivialDxilVectorBinaryOperation(OP::OpCode opcode, Value *src0, + Value *src1, hlsl::OP *hlslOP, + IRBuilder<> &Builder) { Type *Ty = src0->getType(); Constant *opArg = hlslOP->GetU32Const((unsigned)opcode); @@ -518,7 +520,8 @@ Value *TrivialDxilVectorBinaryOperation(OP::OpCode opcode, Value *src0, Value *s Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty); - return TrivialDxilVectorOperation(dxilFunc, opcode, args, Ty, hlslOP, Builder); + return TrivialDxilVectorOperation(dxilFunc, opcode, args, Ty, hlslOP, + Builder); } Value *TrivialDxilUnaryOperationRet(OP::OpCode opcode, Value *src, Type *RetTy, @@ -547,24 +550,26 @@ Value *TrivialDxilBinaryOperation(OP::OpCode opcode, Value *src0, Value *src1, return TrivialDxilOperation(opcode, args, Ty, Ty, hlslOP, Builder); } -Value *TrivialDxilTrinaryOperationRet(OP::OpCode opcode, Value *src0, Value *src1, - Value *src2, Type *Ty, hlsl::OP *hlslOP, - IRBuilder<> &Builder) { +Value *TrivialDxilTrinaryOperationRet(OP::OpCode opcode, Value *src0, + Value *src1, Value *src2, Type *Ty, + hlsl::OP *hlslOP, IRBuilder<> &Builder) { Constant *opArg = hlslOP->GetU32Const((unsigned)opcode); Value *args[] = {opArg, src0, src1, src2}; return 
TrivialDxilOperation(opcode, args, Ty, Ty, hlslOP, Builder); } -Value *TrivialDxilVectorTrinaryOperationRet(OP::OpCode opcode, Value *src0, Value *src1, - Value *src2, Type *Ty, hlsl::OP *hlslOP, - IRBuilder<> &Builder) { +Value *TrivialDxilVectorTrinaryOperationRet(OP::OpCode opcode, Value *src0, + Value *src1, Value *src2, Type *Ty, + hlsl::OP *hlslOP, + IRBuilder<> &Builder) { Constant *opArg = hlslOP->GetU32Const((unsigned)opcode); Value *args[] = {opArg, src0, src1, src2}; Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty); - return TrivialDxilVectorOperation(dxilFunc, opcode, args, Ty, hlslOP, Builder); + return TrivialDxilVectorOperation(dxilFunc, opcode, args, Ty, hlslOP, + Builder); } Value *TrivialUnaryOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -579,22 +584,20 @@ Value *TrivialUnaryOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, return retVal; } -Value *TrivialVectorizableUnaryOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, - HLOperationLowerHelper &helper, - HLObjectOperationLowerHelper *pObjHelper, - bool &Translated) { +Value *TrivialVectorizableUnaryOperation( + CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, + HLOperationLowerHelper &helper, HLObjectOperationLowerHelper *pObjHelper, + bool &Translated) { Value *src0 = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx); Type *Ty = CI->getType(); IRBuilder<> Builder(CI); hlsl::OP *hlslOP = &helper.hlslOP; - if (Ty->isVectorTy() && - helper.M.GetShaderModel()->IsSM69Plus()) - return TrivialDxilVectorUnaryOperationRet(opcode, src0, Ty, - hlslOP, Builder); + if (Ty->isVectorTy() && helper.M.GetShaderModel()->IsSM69Plus()) + return TrivialDxilVectorUnaryOperationRet(opcode, src0, Ty, hlslOP, + Builder); else - return TrivialDxilUnaryOperationRet(opcode, src0, Ty, - hlslOP, Builder); + return TrivialDxilUnaryOperationRet(opcode, src0, Ty, hlslOP, Builder); } Value *TrivialBinaryOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -611,10 +614,11 @@ 
Value *TrivialBinaryOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, return binOp; } -Value *TrivialVectorBinaryOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, - HLOperationLowerHelper &helper, - HLObjectOperationLowerHelper *pObjHelper, - bool &Translated) { +Value *TrivialVectorBinaryOperation(CallInst *CI, IntrinsicOp IOP, + OP::OpCode opcode, + HLOperationLowerHelper &helper, + HLObjectOperationLowerHelper *pObjHelper, + bool &Translated) { hlsl::OP *hlslOP = &helper.hlslOP; Value *src0 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx); Value *src1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx); @@ -626,9 +630,9 @@ Value *TrivialVectorBinaryOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode op } Value *TranslateFMA(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, - HLOperationLowerHelper &helper, - HLObjectOperationLowerHelper *pObjHelper, - bool &Translated) { + HLOperationLowerHelper &helper, + HLObjectOperationLowerHelper *pObjHelper, + bool &Translated) { hlsl::OP *hlslOP = &helper.hlslOP; Type *Ty = CI->getType(); Value *src0 = CI->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx); @@ -636,11 +640,12 @@ Value *TranslateFMA(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, Value *src2 = CI->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx); IRBuilder<> Builder(CI); - if (Ty->isVectorTy() && - helper.M.GetShaderModel()->IsSM69Plus()) - return TrivialDxilVectorTrinaryOperationRet(opcode, src0, src1, src2, Ty, hlslOP, Builder); + if (Ty->isVectorTy() && helper.M.GetShaderModel()->IsSM69Plus()) + return TrivialDxilVectorTrinaryOperationRet(opcode, src0, src1, src2, Ty, + hlslOP, Builder); else - return TrivialDxilTrinaryOperationRet(opcode, src0, src1, src2, Ty, hlslOP, Builder); + return TrivialDxilTrinaryOperationRet(opcode, src0, src1, src2, Ty, hlslOP, + Builder); } Value *TrivialIsSpecialFloat(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -1986,15 +1991,16 @@ Value *TranslateClamp(CallInst *CI, IntrinsicOp IOP, 
OP::OpCode opcode, IRBuilder<> Builder(CI); // min(max(x, minVal), maxVal). - if (Ty->isVectorTy() && - helper.M.GetShaderModel()->IsSM69Plus()) { + if (Ty->isVectorTy() && helper.M.GetShaderModel()->IsSM69Plus()) { Value *maxXMinVal = - TrivialDxilVectorBinaryOperation(maxOp, x, minVal, hlslOP, Builder); - return TrivialDxilVectorBinaryOperation(minOp, maxXMinVal, maxVal, hlslOP, Builder); + TrivialDxilVectorBinaryOperation(maxOp, x, minVal, hlslOP, Builder); + return TrivialDxilVectorBinaryOperation(minOp, maxXMinVal, maxVal, hlslOP, + Builder); } else { Value *maxXMinVal = - TrivialDxilBinaryOperation(maxOp, x, minVal, hlslOP, Builder); - return TrivialDxilBinaryOperation(minOp, maxXMinVal, maxVal, hlslOP, Builder); + TrivialDxilBinaryOperation(maxOp, x, minVal, hlslOP, Builder); + return TrivialDxilBinaryOperation(minOp, maxXMinVal, maxVal, hlslOP, + Builder); } } @@ -2308,11 +2314,12 @@ Value *TranslateExp(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, ConstantVector::getSplat(Ty->getVectorNumElements(), log2eConst); } val = Builder.CreateFMul(log2eConst, val); - if (Ty->isVectorTy() && - helper.M.GetShaderModel()->IsSM69Plus()) - return TrivialDxilVectorUnaryOperationRet(OP::OpCode::Exp, val, Ty, hlslOP, Builder); + if (Ty->isVectorTy() && helper.M.GetShaderModel()->IsSM69Plus()) + return TrivialDxilVectorUnaryOperationRet(OP::OpCode::Exp, val, Ty, hlslOP, + Builder); else - return TrivialDxilUnaryOperationRet(OP::OpCode::Exp, val, Ty, hlslOP, Builder); + return TrivialDxilUnaryOperationRet(OP::OpCode::Exp, val, Ty, hlslOP, + Builder); } Value *TranslateLog(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -2328,11 +2335,12 @@ Value *TranslateLog(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, ln2Const = ConstantVector::getSplat(Ty->getVectorNumElements(), ln2Const); } Value *log = nullptr; - if (Ty->isVectorTy() && - helper.M.GetShaderModel()->IsSM69Plus()) - log = TrivialDxilVectorUnaryOperationRet(OP::OpCode::Log, val, Ty, hlslOP, Builder); + if 
(Ty->isVectorTy() && helper.M.GetShaderModel()->IsSM69Plus()) + log = TrivialDxilVectorUnaryOperationRet(OP::OpCode::Log, val, Ty, hlslOP, + Builder); else - log = TrivialDxilUnaryOperationRet(OP::OpCode::Log, val, Ty, hlslOP, Builder); + log = + TrivialDxilUnaryOperationRet(OP::OpCode::Log, val, Ty, hlslOP, Builder); return Builder.CreateFMul(ln2Const, log); } @@ -2392,13 +2400,12 @@ Value *TranslateFUIBinary(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, break; } } - if (CI->getType()->isVectorTy() && - helper.M.GetShaderModel()->IsSM69Plus()) + if (CI->getType()->isVectorTy() && helper.M.GetShaderModel()->IsSM69Plus()) return TrivialVectorBinaryOperation(CI, IOP, opcode, helper, pObjHelper, - Translated); + Translated); else return TrivialBinaryOperation(CI, IOP, opcode, helper, pObjHelper, - Translated); + Translated); } Value *TranslateFUITrinary(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -2423,7 +2430,8 @@ Value *TranslateFUITrinary(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, Value *src2 = CI->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx); IRBuilder<> Builder(CI); - return TrivialDxilTrinaryOperationRet(opcode, src0, src1, src2, Ty, hlslOP, Builder); + return TrivialDxilTrinaryOperationRet(opcode, src0, src1, src2, Ty, hlslOP, + Builder); } Value *TranslateFrexp(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -2547,9 +2555,8 @@ Value *TrivialDotOperation(OP::OpCode opcode, Value *src0, Value *src1, // Instead of using a DXIL intrinsic, implement a dot product operation using // multiply and add operations. Used for integer dots and long vectors. -Value *ExpandDot(Value *arg0, Value *arg1, unsigned vecSize, - hlsl::OP *hlslOP, IRBuilder<> &Builder, - bool Unsigned = false) { +Value *ExpandDot(Value *arg0, Value *arg1, unsigned vecSize, hlsl::OP *hlslOP, + IRBuilder<> &Builder, bool Unsigned = false) { auto madOpCode = Unsigned ? 
DXIL::OpCode::UMad : DXIL::OpCode::IMad; if (arg0->getType()->getScalarType()->isFloatingPointTy()) madOpCode = DXIL::OpCode::FMad; @@ -2559,8 +2566,8 @@ Value *ExpandDot(Value *arg0, Value *arg1, unsigned vecSize, for (unsigned Elt = 1; Elt < vecSize; ++Elt) { Elt0 = Builder.CreateExtractElement(arg0, Elt); Elt1 = Builder.CreateExtractElement(arg1, Elt); - Result = TrivialDxilTrinaryOperationRet(madOpCode, Elt0, Elt1, Result, Elt0->getType(), hlslOP, - Builder); + Result = TrivialDxilTrinaryOperationRet(madOpCode, Elt0, Elt1, Result, + Elt0->getType(), hlslOP, Builder); } return Result; @@ -2598,11 +2605,12 @@ Value *TranslateDot(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, unsigned vecSize = Ty->getVectorNumElements(); Value *arg1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx); IRBuilder<> Builder(CI); - if (Ty->getScalarType()->isFloatingPointTy() && Ty->getVectorNumElements() <= 4) { + if (Ty->getScalarType()->isFloatingPointTy() && + Ty->getVectorNumElements() <= 4) { return TranslateFDot(arg0, arg1, vecSize, hlslOP, Builder); } else { return ExpandDot(arg0, arg1, vecSize, hlslOP, Builder, - IOP == IntrinsicOp::IOP_udot); + IOP == IntrinsicOp::IOP_udot); } } @@ -2785,8 +2793,9 @@ Value *TranslateMSad4(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, byteSrc = Builder.CreateInsertElement(byteSrc, byteSrcElt, 3); // Msad on vecref and byteSrc. 
- return TrivialDxilTrinaryOperationRet(DXIL::OpCode::Msad, vecRef, byteSrc, accum, - vecRef->getType(), hlslOP, Builder); + return TrivialDxilTrinaryOperationRet(DXIL::OpCode::Msad, vecRef, byteSrc, + accum, vecRef->getType(), hlslOP, + Builder); } Value *TranslateRCP(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, @@ -3151,7 +3160,7 @@ Value *TranslateMul(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, return TranslateFDot(arg0, arg1, vecSize, hlslOP, Builder); } else { return ExpandDot(arg0, arg1, vecSize, hlslOP, Builder, - IOP == IntrinsicOp::IOP_umul); + IOP == IntrinsicOp::IOP_umul); } } else { // mul(vector, scalar) == vector * scalar-splat @@ -6554,7 +6563,8 @@ IntrinsicLower gLowerTable[] = { {IntrinsicOp::IOP_asint16, TranslateBitcast, DXIL::OpCode::NumOpCodes}, {IntrinsicOp::IOP_asuint, TranslateAsUint, DXIL::OpCode::SplitDouble}, {IntrinsicOp::IOP_asuint16, TranslateAsUint, DXIL::OpCode::NumOpCodes}, - {IntrinsicOp::IOP_atan, TrivialVectorizableUnaryOperation, DXIL::OpCode::Atan}, + {IntrinsicOp::IOP_atan, TrivialVectorizableUnaryOperation, + DXIL::OpCode::Atan}, {IntrinsicOp::IOP_atan2, TranslateAtan2, DXIL::OpCode::NumOpCodes}, {IntrinsicOp::IOP_ceil, TrivialUnaryOperation, DXIL::OpCode::Round_pi}, {IntrinsicOp::IOP_clamp, TranslateClamp, DXIL::OpCode::NumOpCodes}, @@ -6645,7 +6655,8 @@ IntrinsicLower gLowerTable[] = { {IntrinsicOp::IOP_sqrt, TrivialUnaryOperation, DXIL::OpCode::Sqrt}, {IntrinsicOp::IOP_step, TranslateStep, DXIL::OpCode::NumOpCodes}, {IntrinsicOp::IOP_tan, TrivialUnaryOperation, DXIL::OpCode::Tan}, - {IntrinsicOp::IOP_tanh, TrivialVectorizableUnaryOperation, DXIL::OpCode::Htan}, + {IntrinsicOp::IOP_tanh, TrivialVectorizableUnaryOperation, + DXIL::OpCode::Htan}, {IntrinsicOp::IOP_tex1D, EmptyLower, DXIL::OpCode::NumOpCodes}, {IntrinsicOp::IOP_tex1Dbias, EmptyLower, DXIL::OpCode::NumOpCodes}, {IntrinsicOp::IOP_tex1Dgrad, EmptyLower, DXIL::OpCode::NumOpCodes}, diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp 
b/tools/clang/lib/Sema/SemaHLSL.cpp index 2ad97dcd9e..1ef555c6df 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -389,7 +389,7 @@ enum ArBasicKind { (IS_BPROP_AINT(_Props) && GET_BPROP_BITS(_Props) != BPROP_BITS12) #define IS_BPROP_ENUM(_Props) (((_Props)&BPROP_ENUM) != 0) -#define IS_BPROP_RAWBUFFER(_Props) (((_Props)&BPROP_RAWBUFFER) != 0) +#define IS_BPROP_RAWBUFFER(_Props) (((_Props) & BPROP_RAWBUFFER) != 0) const UINT g_uBasicKindProps[] = { BPROP_PRIMITIVE | BPROP_BOOLEAN | BPROP_INTEGER | BPROP_NUMERIC | @@ -518,14 +518,22 @@ const UINT g_uBasicKindProps[] = { BPROP_OBJECT | BPROP_RWBUFFER | BPROP_TEXTURE, // AR_OBJECT_RWTEXTURE3D BPROP_OBJECT | BPROP_RWBUFFER, // AR_OBJECT_RWBUFFER - BPROP_OBJECT | BPROP_RBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_BYTEADDRESS_BUFFER - BPROP_OBJECT | BPROP_RWBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_RWBYTEADDRESS_BUFFER - BPROP_OBJECT | BPROP_RBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_STRUCTURED_BUFFER - BPROP_OBJECT | BPROP_RWBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_RWSTRUCTURED_BUFFER - BPROP_OBJECT | BPROP_RWBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_RWSTRUCTURED_BUFFER_ALLOC - BPROP_OBJECT | BPROP_RWBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_RWSTRUCTURED_BUFFER_CONSUME - BPROP_OBJECT | BPROP_RWBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_APPEND_STRUCTURED_BUFFER - BPROP_OBJECT | BPROP_RWBUFFER | BPROP_RAWBUFFER, // AR_OBJECT_CONSUME_STRUCTURED_BUFFER + BPROP_OBJECT | BPROP_RBUFFER | + BPROP_RAWBUFFER, // AR_OBJECT_BYTEADDRESS_BUFFER + BPROP_OBJECT | BPROP_RWBUFFER | + BPROP_RAWBUFFER, // AR_OBJECT_RWBYTEADDRESS_BUFFER + BPROP_OBJECT | BPROP_RBUFFER | + BPROP_RAWBUFFER, // AR_OBJECT_STRUCTURED_BUFFER + BPROP_OBJECT | BPROP_RWBUFFER | + BPROP_RAWBUFFER, // AR_OBJECT_RWSTRUCTURED_BUFFER + BPROP_OBJECT | BPROP_RWBUFFER | + BPROP_RAWBUFFER, // AR_OBJECT_RWSTRUCTURED_BUFFER_ALLOC + BPROP_OBJECT | BPROP_RWBUFFER | + BPROP_RAWBUFFER, // AR_OBJECT_RWSTRUCTURED_BUFFER_CONSUME + BPROP_OBJECT | BPROP_RWBUFFER | + 
BPROP_RAWBUFFER, // AR_OBJECT_APPEND_STRUCTURED_BUFFER + BPROP_OBJECT | BPROP_RWBUFFER | + BPROP_RAWBUFFER, // AR_OBJECT_CONSUME_STRUCTURED_BUFFER BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_CONSTANT_BUFFER BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_TEXTURE_BUFFER From 3e18075813b5c99970ccd97e567f1a819a0aeed9 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 10 Apr 2025 00:53:54 +0000 Subject: [PATCH 07/31] chore: autopublish 2025-04-10T00:53:54Z --- tools/clang/lib/Sema/SemaHLSL.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index 72dd6d41aa..2d668aace7 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -387,7 +387,7 @@ enum ArBasicKind { #define IS_BPROP_UNSIGNABLE(_Props) \ (IS_BPROP_AINT(_Props) && GET_BPROP_BITS(_Props) != BPROP_BITS12) -#define IS_BPROP_ENUM(_Props) (((_Props)&BPROP_ENUM) != 0) +#define IS_BPROP_ENUM(_Props) (((_Props) & BPROP_ENUM) != 0) const UINT g_uBasicKindProps[] = { BPROP_PRIMITIVE | BPROP_BOOLEAN | BPROP_INTEGER | BPROP_NUMERIC | From 1db8c5b30b41f600c4c014fad7669d0e8f154a45 Mon Sep 17 00:00:00 2001 From: Anupama Chandrasekhar Date: Fri, 18 Apr 2025 14:34:10 -0700 Subject: [PATCH 08/31] Implementation of the CoopVec Inference and Training builtin intrinsics (#7290) Implements HLSL: __builtin_MatVecMul __builtin_MatVecMulAdd __builtin_OuterProductAccumulate __builtin_VectorAccumulate Lowered to DXIL: @dx.op.matVecMul @dx.op.matVecMulAdd @dx.op.outerProductAccumulate @dx.op.vectorAccumulate --------- Co-authored-by: github-actions[bot] Co-authored-by: Damyan Pepper Co-authored-by: Simon Moll Co-authored-by: Tex Riddell Co-authored-by: Chris B --- docs/DXIL.rst | 12 + include/dxc/DXIL/DxilConstants.h | 96 +++++-- include/dxc/DXIL/DxilInstructions.h | 230 +++++++++++++++ .../dxc/DxilContainer/RDAT_LibraryTypes.inl | 6 +- include/dxc/HLSL/HLOperations.h | 48 ++++ include/dxc/HlslIntrinsicOp.h | 6 
+- lib/DXIL/DxilOperations.cpp | 108 ++++++- lib/DxilValidation/DxilValidation.cpp | 271 ++++++++++++++++++ lib/HLSL/HLOperationLower.cpp | 203 +++++++++++++ tools/clang/lib/Sema/SemaHLSL.cpp | 12 + .../linalg_builtins/check-shader-stages.hlsl | 135 +++++++++ .../linalg_builtins/linalg-builtins.hlsl | 79 +++++ .../intrinsics/linalg_builtins/lit.local.cfg | 1 + .../mat-vec-mul-add_multioverload.hlsl | 108 +++++++ .../mat-vec-mul_multioverload.hlsl | 104 +++++++ ...uter-product-accumulate-multioverload.hlsl | 70 +++++ .../linalg_builtins/vector-accumulate.hlsl | 16 ++ .../DXC/Passes/DxilGen/linalg-builtins.ll | 189 ++++++++++++ .../hlsl/linalg/unavailable-pre-sm69.hlsl | 59 ++++ utils/hct/gen_intrin_main.txt | 8 + utils/hct/hctdb.py | 151 ++++++++++ utils/hct/hlsl_intrinsic_opcodes.json | 8 +- 22 files changed, 1894 insertions(+), 26 deletions(-) create mode 100644 tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/check-shader-stages.hlsl create mode 100644 tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/linalg-builtins.hlsl create mode 100644 tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/lit.local.cfg create mode 100644 tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/mat-vec-mul-add_multioverload.hlsl create mode 100644 tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/mat-vec-mul_multioverload.hlsl create mode 100644 tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/outer-product-accumulate-multioverload.hlsl create mode 100644 tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/vector-accumulate.hlsl create mode 100644 tools/clang/test/DXC/Passes/DxilGen/linalg-builtins.ll create mode 100644 tools/clang/test/SemaHLSL/hlsl/linalg/unavailable-pre-sm69.hlsl diff --git a/docs/DXIL.rst b/docs/DXIL.rst index a1c5055085..c77dfa184a 100644 --- a/docs/DXIL.rst +++ b/docs/DXIL.rst @@ -2419,6 +2419,10 @@ ID Name Description 302 ReservedC9 reserved 303 RawBufferVectorLoad reads from a raw 
buffer and structured buffer 304 RawBufferVectorStore writes to a RWByteAddressBuffer or RWStructuredBuffer +305 MatVecMul Multiplies a MxK dimension matrix and a K sized input vector +306 MatVecMulAdd multiplies a MxK dimension matrix and a K sized input vector and adds an M-sized bias vector +307 OuterProductAccumulate Computes the outer product between column vectors and an MxN matrix is accumulated component-wise atomically (with device scope) in memory +308 VectorAccumulate Accumulates the components of a vector component-wise atomically (with device scope) to the corresponding elements of an array in memory === ===================================================== ======================================================================================================================================================================================================================= @@ -3134,6 +3138,14 @@ INSTR.ILLEGALDXILOPCODE DXILOpCode must be [0..%0] INSTR.ILLEGALDXILOPFUNCTION '%0' is not a DXILOpFuncition for DXILOpcode '%1'. INSTR.IMMBIASFORSAMPLEB bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate. INSTR.INBOUNDSACCESS Access to out-of-bounds memory is disallowed. +INSTR.LINALGINTERPRETATIONPARAMARECONST In Linalg operations, Interpretation value is a constant. +INSTR.LINALGINVALIDMATRIXLAYOUTVALUEFORMATVECOPS Matrix Layout for Linalg Mul/MulAdd operation must be valid. +INSTR.LINALGINVALIDMEMORYINTERPVALUE In Memory Interpolation value must be valid. +INSTR.LINALGINVALIDREGISTERINTERPVALUE From Register Interpretation value must be valid. +INSTR.LINALGMATRIXLAYOUTNOTTRANSPOSABLE Row Major and Column Major matrix layouts are not transposable. +INSTR.LINALGMATRIXSHAPEPARAMSARECONST Matrix Layout, Dimensions and isTranspose are constants +INSTR.LINALGNOTANUNSIGNEDTYPE Unsigned flag set for a float signed type +INSTR.MATVECOPISUNSIGNEDFLAGSARECONST In Linalg Mul/MulAdd functions, IsUnsigned flag is a constant. 
INSTR.MAYREORDERTHREADUNDEFCOHERENCEHINTPARAM Use of undef coherence hint or num coherence hint bits in MaybeReorderThread. INSTR.MINPRECISIONNOTPRECISE Instructions marked precise may not refer to minprecision values. INSTR.MINPRECISONBITCAST Bitcast on minprecison types is not allowed. diff --git a/include/dxc/DXIL/DxilConstants.h b/include/dxc/DXIL/DxilConstants.h index 8c73328fbd..7fa4875070 100644 --- a/include/dxc/DXIL/DxilConstants.h +++ b/include/dxc/DXIL/DxilConstants.h @@ -162,24 +162,32 @@ const unsigned kDxilMaxOloadDims = 2; enum class ComponentType : uint32_t { Invalid = 0, - I1, - I16, - U16, - I32, - U32, - I64, - U64, - F16, - F32, - F64, - SNormF16, - UNormF16, - SNormF32, - UNormF32, - SNormF64, - UNormF64, - PackedS8x32, - PackedU8x32, + I1 = 1, + I16 = 2, + U16 = 3, + I32 = 4, + U32 = 5, + I64 = 6, + U64 = 7, + F16 = 8, + F32 = 9, + F64 = 10, + SNormF16 = 11, + UNormF16 = 12, + SNormF32 = 13, + UNormF32 = 14, + SNormF64 = 15, + UNormF64 = 16, + PackedS8x32 = 17, + PackedU8x32 = 18, + + // BEGIN NEW FOR SM 6.9 + U8 = 19, + I8 = 20, + F8_E4M3 = 21, + F8_E5M2 = 22, + // END + LastEntry }; @@ -743,6 +751,19 @@ enum class OpCode : unsigned { CreateHandleForLib = 160, // create resource handle from resource struct for library + // Linear Algebra Operations + MatVecMul = + 305, // Multiplies a MxK dimension matrix and a K sized input vector + MatVecMulAdd = 306, // multiplies a MxK dimension matrix and a K sized input + // vector and adds an M-sized bias vector + OuterProductAccumulate = + 307, // Computes the outer product between column vectors and an MxN + // matrix is accumulated component-wise atomically (with device + // scope) in memory + VectorAccumulate = 308, // Accumulates the components of a vector + // component-wise atomically (with device scope) to + // the corresponding elements of an array in memory + // Mesh shader instructions EmitIndices = 169, // emit a primitive's vertex indices in a mesh shader GetMeshPayload = @@ -1060,7 
+1081,7 @@ enum class OpCode : unsigned { NumOpCodes_Dxil_1_7 = 226, NumOpCodes_Dxil_1_8 = 258, - NumOpCodes = 305 // exclusive last value of enumeration + NumOpCodes = 309 // exclusive last value of enumeration }; // OPCODE-ENUM:END @@ -1201,6 +1222,12 @@ enum class OpCodeClass : unsigned { // Library create handle from resource struct (like HL intrinsic) CreateHandleForLib, + // Linear Algebra Operations + MatVecMul, + MatVecMulAdd, + OuterProductAccumulate, + VectorAccumulate, + // Mesh shader instructions EmitIndices, GetMeshPayload, @@ -1385,7 +1412,7 @@ enum class OpCodeClass : unsigned { NumOpClasses_Dxil_1_7 = 153, NumOpClasses_Dxil_1_8 = 174, - NumOpClasses = 190 // exclusive last value of enumeration + NumOpClasses = 194 // exclusive last value of enumeration }; // OPCODECLASS-ENUM:END @@ -1556,6 +1583,28 @@ const unsigned kMSStoreOutputColOpIdx = 3; const unsigned kMSStoreOutputVIdxOpIdx = 4; const unsigned kMSStoreOutputValOpIdx = 5; +// MatVec Ops +const unsigned kMatVecMulInputVectorIdx = 1; +const unsigned kMatVecMulIsInputUnsignedIdx = 2; +const unsigned kMatVecMulInputInterpretationIdx = 3; +const unsigned kMatVecMulMatrixBufferIdx = 4; +const unsigned kMatVecMulMatrixOffsetIdx = 5; +const unsigned kMatVecMulMatrixInterpretationIdx = 6; +const unsigned kMatVecMulMatrixMIdx = 7; +const unsigned kMatVecMulMatrixKIdx = 8; +const unsigned kMatVecMulMatrixLayoutIdx = 9; +const unsigned kMatVecMulMatrixTransposeIdx = 10; +const unsigned kMatVecMulMatrixStrideIdx = 11; +const unsigned kMatVecMulIsOutputUnsignedIdx = 12; + +// MatVecAdd +const unsigned kMatVecMulAddBiasInterpretation = 14; +const unsigned kMatVecMulAddIsOutputUnsignedIdx = 15; + +// Outer Product Accumulate +const unsigned kOuterProdAccMatrixInterpretation = 5; +const unsigned kOuterProdAccMatrixLayout = 6; + // TODO: add operand index for all the OpCodeClass. 
} // namespace OperandIndex @@ -2127,6 +2176,13 @@ extern const char *kHostLayoutTypePrefix; extern const char *kWaveOpsIncludeHelperLanesString; +enum class LinalgMatrixLayout : uint32_t { + RowMajor = 0, + ColumnMajor = 1, + MulOptimal = 2, + OuterProductOptimal = 3, +}; + } // namespace DXIL } // namespace hlsl diff --git a/include/dxc/DXIL/DxilInstructions.h b/include/dxc/DXIL/DxilInstructions.h index a99c5360d4..9a4030fd8e 100644 --- a/include/dxc/DXIL/DxilInstructions.h +++ b/include/dxc/DXIL/DxilInstructions.h @@ -9918,5 +9918,235 @@ struct DxilInst_RawBufferVectorStore { llvm::APInt(32, (uint64_t)val))); } }; + +/// This instruction Multiplies a MxK dimension matrix and a K sized input +/// vector +struct DxilInst_MatVecMul { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatVecMul(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::MatVecMul); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (13 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_inputVector = 1, + arg_isInputUnsigned = 2, + arg_inputInterpretation = 3, + arg_matrixBuffer = 4, + arg_matrixOffset = 5, + arg_matrixIntepretation = 6, + arg_matrixM = 7, + arg_matrixK = 8, + arg_matrixLayout = 9, + arg_matrixTranspose = 10, + arg_matrixStride = 11, + arg_isOutputUnsigned = 12, + }; + // Accessors + llvm::Value *get_inputVector() const { return Instr->getOperand(1); } + void set_inputVector(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_isInputUnsigned() const { return Instr->getOperand(2); } + void set_isInputUnsigned(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_inputInterpretation() const { return Instr->getOperand(3); } + void 
set_inputInterpretation(llvm::Value *val) { Instr->setOperand(3, val); } + llvm::Value *get_matrixBuffer() const { return Instr->getOperand(4); } + void set_matrixBuffer(llvm::Value *val) { Instr->setOperand(4, val); } + llvm::Value *get_matrixOffset() const { return Instr->getOperand(5); } + void set_matrixOffset(llvm::Value *val) { Instr->setOperand(5, val); } + llvm::Value *get_matrixIntepretation() const { return Instr->getOperand(6); } + void set_matrixIntepretation(llvm::Value *val) { Instr->setOperand(6, val); } + llvm::Value *get_matrixM() const { return Instr->getOperand(7); } + void set_matrixM(llvm::Value *val) { Instr->setOperand(7, val); } + llvm::Value *get_matrixK() const { return Instr->getOperand(8); } + void set_matrixK(llvm::Value *val) { Instr->setOperand(8, val); } + llvm::Value *get_matrixLayout() const { return Instr->getOperand(9); } + void set_matrixLayout(llvm::Value *val) { Instr->setOperand(9, val); } + llvm::Value *get_matrixTranspose() const { return Instr->getOperand(10); } + void set_matrixTranspose(llvm::Value *val) { Instr->setOperand(10, val); } + llvm::Value *get_matrixStride() const { return Instr->getOperand(11); } + void set_matrixStride(llvm::Value *val) { Instr->setOperand(11, val); } + llvm::Value *get_isOutputUnsigned() const { return Instr->getOperand(12); } + void set_isOutputUnsigned(llvm::Value *val) { Instr->setOperand(12, val); } +}; + +/// This instruction multiplies a MxK dimension matrix and a K sized input +/// vector and adds an M-sized bias vector +struct DxilInst_MatVecMulAdd { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatVecMulAdd(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst(Instr, + hlsl::OP::OpCode::MatVecMulAdd); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (16 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } 
+ // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_inputVector = 1, + arg_isInputUnsigned = 2, + arg_inputInterpretation = 3, + arg_matrixBuffer = 4, + arg_matrixOffset = 5, + arg_matrixIntepretation = 6, + arg_matrixM = 7, + arg_matrixK = 8, + arg_matrixLayout = 9, + arg_matrixTranspose = 10, + arg_matrixStride = 11, + arg_biasBuffer = 12, + arg_biasOffset = 13, + arg_biasIntepretation = 14, + arg_isOutputUnsigned = 15, + }; + // Accessors + llvm::Value *get_inputVector() const { return Instr->getOperand(1); } + void set_inputVector(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_isInputUnsigned() const { return Instr->getOperand(2); } + void set_isInputUnsigned(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_inputInterpretation() const { return Instr->getOperand(3); } + void set_inputInterpretation(llvm::Value *val) { Instr->setOperand(3, val); } + llvm::Value *get_matrixBuffer() const { return Instr->getOperand(4); } + void set_matrixBuffer(llvm::Value *val) { Instr->setOperand(4, val); } + llvm::Value *get_matrixOffset() const { return Instr->getOperand(5); } + void set_matrixOffset(llvm::Value *val) { Instr->setOperand(5, val); } + llvm::Value *get_matrixIntepretation() const { return Instr->getOperand(6); } + void set_matrixIntepretation(llvm::Value *val) { Instr->setOperand(6, val); } + llvm::Value *get_matrixM() const { return Instr->getOperand(7); } + void set_matrixM(llvm::Value *val) { Instr->setOperand(7, val); } + llvm::Value *get_matrixK() const { return Instr->getOperand(8); } + void set_matrixK(llvm::Value *val) { Instr->setOperand(8, val); } + llvm::Value *get_matrixLayout() const { return Instr->getOperand(9); } + void set_matrixLayout(llvm::Value *val) { Instr->setOperand(9, val); } + llvm::Value *get_matrixTranspose() const { return Instr->getOperand(10); } + void set_matrixTranspose(llvm::Value *val) { Instr->setOperand(10, val); } + 
llvm::Value *get_matrixStride() const { return Instr->getOperand(11); } + void set_matrixStride(llvm::Value *val) { Instr->setOperand(11, val); } + llvm::Value *get_biasBuffer() const { return Instr->getOperand(12); } + void set_biasBuffer(llvm::Value *val) { Instr->setOperand(12, val); } + llvm::Value *get_biasOffset() const { return Instr->getOperand(13); } + void set_biasOffset(llvm::Value *val) { Instr->setOperand(13, val); } + llvm::Value *get_biasIntepretation() const { return Instr->getOperand(14); } + void set_biasIntepretation(llvm::Value *val) { Instr->setOperand(14, val); } + llvm::Value *get_isOutputUnsigned() const { return Instr->getOperand(15); } + void set_isOutputUnsigned(llvm::Value *val) { Instr->setOperand(15, val); } +}; + +/// This instruction Computes the outer product between column vectors and an +/// MxN matrix is accumulated component-wise atomically (with device scope) in +/// memory +struct DxilInst_OuterProductAccumulate { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_OuterProductAccumulate(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst( + Instr, hlsl::OP::OpCode::OuterProductAccumulate); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (8 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_inputVector1 = 1, + arg_inputVector2 = 2, + arg_matrixBuffer = 3, + arg_matrixOffset = 4, + arg_matrixIntepretation = 5, + arg_matrixLayout = 6, + arg_matrixStride = 7, + }; + // Accessors + llvm::Value *get_inputVector1() const { return Instr->getOperand(1); } + void set_inputVector1(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_inputVector2() const { return Instr->getOperand(2); } + void set_inputVector2(llvm::Value *val) { 
Instr->setOperand(2, val); } + llvm::Value *get_matrixBuffer() const { return Instr->getOperand(3); } + void set_matrixBuffer(llvm::Value *val) { Instr->setOperand(3, val); } + llvm::Value *get_matrixOffset() const { return Instr->getOperand(4); } + void set_matrixOffset(llvm::Value *val) { Instr->setOperand(4, val); } + llvm::Value *get_matrixIntepretation() const { return Instr->getOperand(5); } + void set_matrixIntepretation(llvm::Value *val) { Instr->setOperand(5, val); } + int32_t get_matrixIntepretation_val() const { + return (int32_t)(llvm::dyn_cast(Instr->getOperand(5)) + ->getZExtValue()); + } + void set_matrixIntepretation_val(int32_t val) { + Instr->setOperand(5, llvm::Constant::getIntegerValue( + llvm::IntegerType::get(Instr->getContext(), 32), + llvm::APInt(32, (uint64_t)val))); + } + llvm::Value *get_matrixLayout() const { return Instr->getOperand(6); } + void set_matrixLayout(llvm::Value *val) { Instr->setOperand(6, val); } + int32_t get_matrixLayout_val() const { + return (int32_t)(llvm::dyn_cast(Instr->getOperand(6)) + ->getZExtValue()); + } + void set_matrixLayout_val(int32_t val) { + Instr->setOperand(6, llvm::Constant::getIntegerValue( + llvm::IntegerType::get(Instr->getContext(), 32), + llvm::APInt(32, (uint64_t)val))); + } + llvm::Value *get_matrixStride() const { return Instr->getOperand(7); } + void set_matrixStride(llvm::Value *val) { Instr->setOperand(7, val); } +}; + +/// This instruction Accumulates the components of a vector component-wise +/// atomically (with device scope) to the corresponding elements of an array in +/// memory +struct DxilInst_VectorAccumulate { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_VectorAccumulate(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst(Instr, + hlsl::OP::OpCode::VectorAccumulate); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (4 != 
llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_inputVector = 1, + arg_arrayBuffer = 2, + arg_arrayOffset = 3, + }; + // Accessors + llvm::Value *get_inputVector() const { return Instr->getOperand(1); } + void set_inputVector(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_arrayBuffer() const { return Instr->getOperand(2); } + void set_arrayBuffer(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_arrayOffset() const { return Instr->getOperand(3); } + void set_arrayOffset(llvm::Value *val) { Instr->setOperand(3, val); } +}; // INSTR-HELPER:END } // namespace hlsl diff --git a/include/dxc/DxilContainer/RDAT_LibraryTypes.inl b/include/dxc/DxilContainer/RDAT_LibraryTypes.inl index 4b58b406c2..902f2e9652 100644 --- a/include/dxc/DxilContainer/RDAT_LibraryTypes.inl +++ b/include/dxc/DxilContainer/RDAT_LibraryTypes.inl @@ -565,9 +565,13 @@ RDAT_DXIL_ENUM_START(hlsl::DXIL::ComponentType, uint32_t) RDAT_ENUM_VALUE_NODEF(UNormF64) RDAT_ENUM_VALUE_NODEF(PackedS8x32) RDAT_ENUM_VALUE_NODEF(PackedU8x32) + RDAT_ENUM_VALUE_NODEF(U8) + RDAT_ENUM_VALUE_NODEF(I8) + RDAT_ENUM_VALUE_NODEF(F8_E4M3) + RDAT_ENUM_VALUE_NODEF(F8_E5M2) RDAT_ENUM_VALUE_NODEF(LastEntry) #if DEF_RDAT_ENUMS == DEF_RDAT_DUMP_IMPL - static_assert((unsigned)hlsl::DXIL::ComponentType::LastEntry == 19, + static_assert((unsigned)hlsl::DXIL::ComponentType::LastEntry == 23, "otherwise, RDAT_DXIL_ENUM definition needs updating"); #endif RDAT_ENUM_END() diff --git a/include/dxc/HLSL/HLOperations.h b/include/dxc/HLSL/HLOperations.h index f87d324baf..41def3ba2c 100644 --- a/include/dxc/HLSL/HLOperations.h +++ b/include/dxc/HLSL/HLOperations.h @@ -433,6 +433,54 @@ const unsigned kNodeHandleToResCastOpIdx = 1; const unsigned kAnnotateNodeHandleNodePropIdx = 2; const unsigned kAnnotateNodeRecordHandleNodeRecordPropIdx = 2; +// Linear Algebra 
Operations + +// MatVecMul +const unsigned kMatVecMulOutputVectorIdx = 1; +const unsigned kMatVecMulIsOutputUnsignedIdx = 2; +const unsigned kMatVecMulInputVectorIdx = 3; +const unsigned kMatVecMulIsInputUnsignedIdx = 4; +const unsigned kMatVecMulInputInterpretationIdx = 5; +const unsigned kMatVecMulMatrixBufferIdx = 6; +const unsigned kMatVecMulMatrixOffsetIdx = 7; +const unsigned kMatVecMulMatrixInterpretationIdx = 8; +const unsigned kMatVecMulMatrixMIdx = 9; +const unsigned kMatVecMulMatrixKIdx = 10; +const unsigned kMatVecMulMatrixLayoutIdx = 11; +const unsigned kMatVecMulMatrixTransposeIdx = 12; +const unsigned kMatVecMulMatrixStrideIdx = 13; + +// MatVecMulAdd +const unsigned kMatVecMulAddOutputVectorIdx = 1; +const unsigned kMatVecMulAddIsOutputUnsignedIdx = 2; +const unsigned kMatVecMulAddInputVectorIdx = 3; +const unsigned kMatVecMulAddIsInputUnsignedIdx = 4; +const unsigned kMatVecMulAddInputInterpretationIdx = 5; +const unsigned kMatVecMulAddMatrixBufferIdx = 6; +const unsigned kMatVecMulAddMatrixOffsetIdx = 7; +const unsigned kMatVecMulAddMatrixInterpretationIdx = 8; +const unsigned kMatVecMulAddMatrixMIdx = 9; +const unsigned kMatVecMulAddMatrixKIdx = 10; +const unsigned kMatVecMulAddMatrixLayoutIdx = 11; +const unsigned kMatVecMulAddMatrixTransposeIdx = 12; +const unsigned kMatVecMulAddMatrixStrideIdx = 13; +const unsigned kMatVecMulAddBiasBufferIdx = 14; +const unsigned kMatVecMulAddBiasOffsetIdx = 15; +const unsigned kMatVecMulAddBiasInterpretationIdx = 16; + +// OuterProductAccumulate +const unsigned kOuterProdAccInputVec1Idx = 1; +const unsigned kOuterProdAccInputVec2Idx = 2; +const unsigned kOuterProdAccMatrixIdx = 3; +const unsigned kOuterProdAccMatrixOffsetIdx = 4; +const unsigned kOuterProdAccMatrixInterpretationIdx = 5; +const unsigned kOuterProdAccMatrixLayoutIdx = 6; +const unsigned kOuterProdAccMatrixStrideIdx = 7; + +// Vector Accumulate +const unsigned kVectorAccInputVecIdx = 1; +const unsigned kVectorAccMatrixIdx = 2; +const unsigned 
kVectorAccMatrixOffsetIdx = 3; } // namespace HLOperandIndex llvm::Function *GetOrCreateHLFunction(llvm::Module &M, diff --git a/include/dxc/HlslIntrinsicOp.h b/include/dxc/HlslIntrinsicOp.h index d37c27a38e..197bd3e1f5 100644 --- a/include/dxc/HlslIntrinsicOp.h +++ b/include/dxc/HlslIntrinsicOp.h @@ -107,6 +107,10 @@ enum class IntrinsicOp { IOP_WorldToObject = 99, IOP_WorldToObject3x4 = 100, IOP_WorldToObject4x3 = 101, + IOP___builtin_MatVecMul = 390, + IOP___builtin_MatVecMulAdd = 391, + IOP___builtin_OuterProductAccumulate = 392, + IOP___builtin_VectorAccumulate = 393, IOP_abort = 102, IOP_abs = 103, IOP_acos = 104, @@ -396,7 +400,7 @@ enum class IntrinsicOp { IOP_usign = 355, MOP_InterlockedUMax = 356, MOP_InterlockedUMin = 357, - Num_Intrinsics = 390, + Num_Intrinsics = 394, }; inline bool HasUnsignedIntrinsicOpcode(IntrinsicOp opcode) { switch (opcode) { diff --git a/lib/DXIL/DxilOperations.cpp b/lib/DXIL/DxilOperations.cpp index f614ba9d14..95e8dfaeba 100644 --- a/lib/DXIL/DxilOperations.cpp +++ b/lib/DXIL/DxilOperations.cpp @@ -2652,6 +2652,40 @@ const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = { 1, {{0x4e7}}, {{0xe7}}}, // Overloads: hfwidlgetNumParams() <= 1) return nullptr; return FT->getParamType(1); @@ -6291,6 +6382,19 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { StructType *ST = cast(Ty); return ST->getElementType(0); } + case OpCode::MatVecMul: + case OpCode::MatVecMulAdd: + if (FT->getNumParams() < 2) + return nullptr; + return llvm::StructType::get(Ctx, + {FT->getReturnType(), FT->getParamType(1)}); + + case OpCode::OuterProductAccumulate: + if (FT->getNumParams() < 3) + return nullptr; + return llvm::StructType::get(Ctx, + {FT->getParamType(1), FT->getParamType(2)}); + // OPCODE-OLOAD-TYPES:END default: return Ty; diff --git a/lib/DxilValidation/DxilValidation.cpp b/lib/DxilValidation/DxilValidation.cpp index 00a6b9ae14..0b2ccf5f95 100644 --- a/lib/DxilValidation/DxilValidation.cpp +++ 
b/lib/DxilValidation/DxilValidation.cpp @@ -970,6 +970,267 @@ static void ValidateImmOperandForMathDxilOp(CallInst *CI, DXIL::OpCode Opcode, } } +static bool CheckLinalgInterpretation(uint32_t Input, bool InRegister) { + using CT = DXIL::ComponentType; + switch (static_cast(Input)) { + case CT::I16: + case CT::U16: + case CT::I32: + case CT::U32: + case CT::F16: + case CT::F32: + case CT::U8: + case CT::I8: + case CT::F8_E4M3: + case CT::F8_E5M2: + return true; + case CT::PackedS8x32: + case CT::PackedU8x32: + return InRegister; + default: + return false; + } +} + +static bool CheckMatrixLayoutForMatVecMulOps(unsigned Layout) { + return Layout <= + static_cast(DXIL::LinalgMatrixLayout::OuterProductOptimal); +} + +std::string GetMatrixLayoutStr(unsigned Layout) { + switch (static_cast(Layout)) { + case DXIL::LinalgMatrixLayout::RowMajor: + return "RowMajor"; + case DXIL::LinalgMatrixLayout::ColumnMajor: + return "ColumnMajor"; + case DXIL::LinalgMatrixLayout::MulOptimal: + return "MulOptimal"; + case DXIL::LinalgMatrixLayout::OuterProductOptimal: + return "OuterProductOptimal"; + default: + DXASSERT_NOMSG(false); + return "Invalid"; + } +} + +static bool CheckTransposeForMatrixLayout(unsigned Layout, bool Transposed) { + switch (static_cast(Layout)) { + case DXIL::LinalgMatrixLayout::RowMajor: + case DXIL::LinalgMatrixLayout::ColumnMajor: + return !Transposed; + + default: + return true; + } +} + +static bool CheckUnsignedFlag(Type *VecTy, bool IsUnsigned) { + Type *ElemTy = VecTy->getScalarType(); + if (ElemTy->isFloatingPointTy()) + return !IsUnsigned; + + return true; +} + +static Value *GetMatVecOpIsOutputUnsigned(CallInst *CI, DXIL::OpCode OpCode) { + switch (OpCode) { + case DXIL::OpCode::MatVecMul: + return CI->getOperand(DXIL::OperandIndex::kMatVecMulIsOutputUnsignedIdx); + case DXIL::OpCode::MatVecMulAdd: + return CI->getOperand(DXIL::OperandIndex::kMatVecMulAddIsOutputUnsignedIdx); + + default: + DXASSERT_NOMSG(false); + return nullptr; + } +} + +static 
void ValidateImmOperandsForMatVecOps(CallInst *CI, DXIL::OpCode OpCode, + ValidationContext &ValCtx) { + + llvm::Value *IsInputUnsigned = + CI->getOperand(DXIL::OperandIndex::kMatVecMulIsInputUnsignedIdx); + ConstantInt *IsInputUnsignedConst = + dyn_cast(IsInputUnsigned); + if (!IsInputUnsignedConst) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst, + {"IsInputUnsigned"}); + return; + } + + llvm::Value *IsOutputUnsigned = GetMatVecOpIsOutputUnsigned(CI, OpCode); + ConstantInt *IsOutputUnsignedConst = + dyn_cast(IsOutputUnsigned); + if (!IsOutputUnsignedConst) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst, + {"IsOutputUnsigned"}); + return; + } + + llvm::Value *InputInterpretation = + CI->getOperand(DXIL::OperandIndex::kMatVecMulInputInterpretationIdx); + ConstantInt *II = dyn_cast(InputInterpretation); + if (!II) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgInterpretationParamAreConst, + {"InputInterpretation"}); + return; + } + uint64_t IIValue = II->getLimitedValue(); + if (!CheckLinalgInterpretation(IIValue, true)) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgInvalidRegisterInterpValue, + {std::to_string(IIValue), "Input"}); + return; + } + + llvm::Value *MatrixInterpretation = + CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixInterpretationIdx); + ConstantInt *MI = dyn_cast(MatrixInterpretation); + if (!MI) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgInterpretationParamAreConst, + {"MatrixInterpretation"}); + return; + } + uint64_t MIValue = MI->getLimitedValue(); + if (!CheckLinalgInterpretation(MIValue, false)) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue, + {std::to_string(MIValue), "Matrix"}); + return; + } + + llvm::Value *MatrixM = + CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixMIdx); + if (!llvm::isa(MatrixM)) { + ValCtx.EmitInstrFormatError( + CI, 
ValidationRule::InstrLinalgMatrixShapeParamsAreConst, + {"Matrix M dimension"}); + return; + } + + llvm::Value *MatrixK = + CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixKIdx); + if (!llvm::isa(MatrixK)) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgMatrixShapeParamsAreConst, + {"Matrix K dimension"}); + return; + } + + llvm::Value *MatrixLayout = + CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixLayoutIdx); + + ConstantInt *MatrixLayoutConst = dyn_cast(MatrixLayout); + if (!MatrixLayoutConst) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgMatrixShapeParamsAreConst, + {"Matrix Layout"}); + return; + } + uint64_t MLValue = MatrixLayoutConst->getLimitedValue(); + if (!CheckMatrixLayoutForMatVecMulOps(MLValue)) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgInvalidMatrixLayoutValueForMatVecOps, + {std::to_string(MLValue), + std::to_string( + static_cast(DXIL::LinalgMatrixLayout::RowMajor)), + std::to_string(static_cast( + DXIL::LinalgMatrixLayout::OuterProductOptimal))}); + return; + } + + llvm::Value *MatrixTranspose = + CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixTransposeIdx); + ConstantInt *MatrixTransposeConst = dyn_cast(MatrixTranspose); + if (!MatrixTransposeConst) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgMatrixShapeParamsAreConst, + {"MatrixTranspose"}); + return; + } + + if (!CheckTransposeForMatrixLayout(MLValue, + MatrixTransposeConst->getLimitedValue())) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgMatrixLayoutNotTransposable, + {GetMatrixLayoutStr(MLValue)}); + return; + } + + llvm::Value *InputVector = + CI->getOperand(DXIL::OperandIndex::kMatVecMulInputVectorIdx); + if (!CheckUnsignedFlag(InputVector->getType(), + IsInputUnsignedConst->getLimitedValue())) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgNotAnUnsignedType, {"Input"}); + return; + } + + if (!CheckUnsignedFlag(CI->getType(), + 
IsOutputUnsignedConst->getLimitedValue())) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgNotAnUnsignedType, {"Output"}); + return; + } + + switch (OpCode) { + case DXIL::OpCode::MatVecMulAdd: { + llvm::Value *BiasInterpretation = + CI->getOperand(DXIL::OperandIndex::kMatVecMulAddBiasInterpretation); + ConstantInt *BI = cast(BiasInterpretation); + if (!BI) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgInterpretationParamAreConst, + {"BiasInterpretation"}); + return; + } + uint64_t BIValue = BI->getLimitedValue(); + if (!CheckLinalgInterpretation(BIValue, false)) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue, + {std::to_string(BIValue), "Bias vector"}); + return; + } + } break; + default: + break; + } +} + +static void ValidateImmOperandsForOuterProdAcc(CallInst *CI, + ValidationContext &ValCtx) { + + llvm::Value *MatrixInterpretation = + CI->getOperand(DXIL::OperandIndex::kOuterProdAccMatrixInterpretation); + ConstantInt *MI = cast(MatrixInterpretation); + if (!MI) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgInterpretationParamAreConst, + {"MatrixInterpretation"}); + return; + } + uint64_t MIValue = MI->getLimitedValue(); + if (!CheckLinalgInterpretation(MIValue, false)) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue, + {std::to_string(MIValue), "Matrix"}); + return; + } + + llvm::Value *MatrixLayout = + CI->getOperand(DXIL::OperandIndex::kOuterProdAccMatrixLayout); + if (!llvm::isa(MatrixLayout)) { + ValCtx.EmitInstrFormatError( + CI, ValidationRule::InstrLinalgMatrixShapeParamsAreConst, + {"MatrixLayout"}); + return; + } +} + // Validate the type-defined mask compared to the store value mask which // indicates which parts were defined returns true if caller should continue // validation @@ -1994,6 +2255,16 @@ static void ValidateDxilOperationCallInProfile(CallInst *CI, GetLaunchTypeStr(NodeLaunchType)}); break; + 
case DXIL::OpCode::MatVecMul: + case DXIL::OpCode::MatVecMulAdd: + ValidateImmOperandsForMatVecOps(CI, Opcode, ValCtx); + break; + case DXIL::OpCode::OuterProductAccumulate: + ValidateImmOperandsForOuterProdAcc(CI, ValCtx); + break; + case DXIL::OpCode::VectorAccumulate: + + break; default: // TODO: make sure every Opcode is checked. diff --git a/lib/HLSL/HLOperationLower.cpp b/lib/HLSL/HLOperationLower.cpp index b5114fa34b..4f55cb377d 100644 --- a/lib/HLSL/HLOperationLower.cpp +++ b/lib/HLSL/HLOperationLower.cpp @@ -6321,6 +6321,200 @@ Value *TranslateSelect(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, return Builder.CreateSelect(cond, t, f); } + +Value *TranslateMatVecMul(CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode, + HLOperationLowerHelper &Helper, + HLObjectOperationLowerHelper *ObjHelper, + bool &Translated) { + + hlsl::OP *HlslOp = &Helper.hlslOP; + IRBuilder<> Builder(CI); + + Constant *OpArg = HlslOp->GetU32Const(static_cast(OpCode)); + + // Input parameters + Value *InputVector = + CI->getArgOperand(HLOperandIndex::kMatVecMulInputVectorIdx); + Value *InputIsUnsigned = + CI->getArgOperand(HLOperandIndex::kMatVecMulIsInputUnsignedIdx); + Value *InputInterpretation = + CI->getArgOperand(HLOperandIndex::kMatVecMulInputInterpretationIdx); + + // Matrix parameters + Value *MatrixBuffer = + CI->getArgOperand(HLOperandIndex::kMatVecMulMatrixBufferIdx); + Value *MatrixOffset = + CI->getArgOperand(HLOperandIndex::kMatVecMulMatrixOffsetIdx); + Value *MatrixInterpretation = + CI->getArgOperand(HLOperandIndex::kMatVecMulMatrixInterpretationIdx); + Value *MatrixM = CI->getArgOperand(HLOperandIndex::kMatVecMulMatrixMIdx); + Value *MatrixK = CI->getArgOperand(HLOperandIndex::kMatVecMulMatrixKIdx); + Value *MatrixLayout = + CI->getArgOperand(HLOperandIndex::kMatVecMulMatrixLayoutIdx); + Value *MatrixTranspose = + CI->getArgOperand(HLOperandIndex::kMatVecMulMatrixTransposeIdx); + Value *MatrixStride = + CI->getArgOperand(HLOperandIndex::kMatVecMulMatrixStrideIdx); 
+ + // Output parameters + Value *OutputIsUnsigned = + CI->getArgOperand(HLOperandIndex::kMatVecMulIsOutputUnsignedIdx); + + // Get the DXIL function for the operation + Function *DxilFunc = HlslOp->GetOpFunc( + OpCode, {CI->getArgOperand(HLOperandIndex::kMatVecMulOutputVectorIdx) + ->getType() + ->getPointerElementType(), + InputVector->getType()}); + + // Create a call to the DXIL function + Value *NewCI = Builder.CreateCall( + DxilFunc, + {OpArg, InputVector, InputIsUnsigned, InputInterpretation, MatrixBuffer, + MatrixOffset, MatrixInterpretation, MatrixM, MatrixK, MatrixLayout, + MatrixTranspose, MatrixStride, OutputIsUnsigned}); + + // Get the output parameter and store the result + Value *OutParam = + CI->getArgOperand(HLOperandIndex::kMatVecMulOutputVectorIdx); + + Builder.CreateStore(NewCI, OutParam); + + return nullptr; +} + +Value *TranslateMatVecMulAdd(CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode, + HLOperationLowerHelper &Helper, + HLObjectOperationLowerHelper *ObjHelper, + bool &Translated) { + + hlsl::OP *HlslOp = &Helper.hlslOP; + IRBuilder<> Builder(CI); + + Constant *OpArg = HlslOp->GetU32Const(static_cast(OpCode)); + + // Input vector parameters + Value *InputVector = + CI->getArgOperand(HLOperandIndex::kMatVecMulAddInputVectorIdx); + Value *InputIsUnsigned = + CI->getArgOperand(HLOperandIndex::kMatVecMulAddIsInputUnsignedIdx); + Value *InputInterpretation = + CI->getArgOperand(HLOperandIndex::kMatVecMulAddInputInterpretationIdx); + + // Matrix parameters + Value *MatrixBuffer = + CI->getArgOperand(HLOperandIndex::kMatVecMulAddMatrixBufferIdx); + Value *MatrixOffset = + CI->getArgOperand(HLOperandIndex::kMatVecMulAddMatrixOffsetIdx); + Value *MatrixInterpretation = + CI->getArgOperand(HLOperandIndex::kMatVecMulAddMatrixInterpretationIdx); + Value *MatrixM = CI->getArgOperand(HLOperandIndex::kMatVecMulAddMatrixMIdx); + Value *MatrixK = CI->getArgOperand(HLOperandIndex::kMatVecMulAddMatrixKIdx); + Value *MatrixLayout = + 
CI->getArgOperand(HLOperandIndex::kMatVecMulAddMatrixLayoutIdx); + Value *MatrixTranspose = + CI->getArgOperand(HLOperandIndex::kMatVecMulAddMatrixTransposeIdx); + Value *MatrixStride = + CI->getArgOperand(HLOperandIndex::kMatVecMulAddMatrixStrideIdx); + + // Bias parameters + Value *BiasBuffer = + CI->getArgOperand(HLOperandIndex::kMatVecMulAddBiasBufferIdx); + Value *BiasOffset = + CI->getArgOperand(HLOperandIndex::kMatVecMulAddBiasOffsetIdx); + Value *BiasInterpretation = + CI->getArgOperand(HLOperandIndex::kMatVecMulAddBiasInterpretationIdx); + + // Output parameters + Value *OutputIsUnsigned = + CI->getArgOperand(HLOperandIndex::kMatVecMulAddIsOutputUnsignedIdx); + + // Get the DXIL function for the operation + Function *DxilFunc = HlslOp->GetOpFunc( + OpCode, {CI->getArgOperand(HLOperandIndex::kMatVecMulAddOutputVectorIdx) + ->getType() + ->getPointerElementType(), + InputVector->getType()}); + + // Create a call to the DXIL function + Value *NewCI = Builder.CreateCall( + DxilFunc, {OpArg, InputVector, InputIsUnsigned, InputInterpretation, + MatrixBuffer, MatrixOffset, MatrixInterpretation, MatrixM, + MatrixK, MatrixLayout, MatrixTranspose, MatrixStride, + BiasBuffer, BiasOffset, BiasInterpretation, OutputIsUnsigned}); + + // Store the result in the output parameter + Value *OutParam = + CI->getArgOperand(HLOperandIndex::kMatVecMulAddOutputVectorIdx); + Builder.CreateStore(NewCI, OutParam); + + return nullptr; +} + +Value *TranslateOuterProductAccumulate(CallInst *CI, IntrinsicOp IOP, + OP::OpCode OpCode, + HLOperationLowerHelper &Helper, + HLObjectOperationLowerHelper *ObjHelper, + bool &Translated) { + + hlsl::OP *HlslOp = &Helper.hlslOP; + IRBuilder<> Builder(CI); + + Constant *OpArg = HlslOp->GetU32Const(static_cast(OpCode)); + + // Input vector parameters + Value *InputVector1 = + CI->getArgOperand(HLOperandIndex::kOuterProdAccInputVec1Idx); + Value *InputVector2 = + CI->getArgOperand(HLOperandIndex::kOuterProdAccInputVec2Idx); + + // Matrix parameters + 
Value *MatrixBuffer = + CI->getArgOperand(HLOperandIndex::kOuterProdAccMatrixIdx); + Value *MatrixOffset = + CI->getArgOperand(HLOperandIndex::kOuterProdAccMatrixOffsetIdx); + Value *MatrixInterpretation = + CI->getArgOperand(HLOperandIndex::kOuterProdAccMatrixInterpretationIdx); + Value *MatrixLayout = + CI->getArgOperand(HLOperandIndex::kOuterProdAccMatrixLayoutIdx); + Value *MatrixStride = + CI->getArgOperand(HLOperandIndex::kOuterProdAccMatrixStrideIdx); + + // Get the DXIL function for the operation + Function *DxilFunc = HlslOp->GetOpFunc( + OpCode, {InputVector1->getType(), InputVector2->getType()}); + + return Builder.CreateCall( + DxilFunc, {OpArg, InputVector1, InputVector2, MatrixBuffer, MatrixOffset, + MatrixInterpretation, MatrixLayout, MatrixStride}); +} + +Value *TranslateVectorAccumulate(CallInst *CI, IntrinsicOp IOP, + OP::OpCode OpCode, + HLOperationLowerHelper &Helper, + HLObjectOperationLowerHelper *ObjHelper, + bool &Translated) { + + hlsl::OP *HlslOp = &Helper.hlslOP; + IRBuilder<> Builder(CI); + + Constant *OpArg = HlslOp->GetU32Const(static_cast(OpCode)); + + // Input vector parameter + Value *InputVector = CI->getArgOperand(HLOperandIndex::kVectorAccInputVecIdx); + + // Matrix parameters + Value *MatrixBuffer = CI->getArgOperand(HLOperandIndex::kVectorAccMatrixIdx); + Value *MatrixOffset = + CI->getArgOperand(HLOperandIndex::kVectorAccMatrixOffsetIdx); + + // Get the DXIL function for the operation + Function *DxilFunc = HlslOp->GetOpFunc(OpCode, InputVector->getType()); + + return Builder.CreateCall(DxilFunc, + {OpArg, InputVector, MatrixBuffer, MatrixOffset}); +} + } // namespace // Lower table. 
@@ -7036,6 +7230,15 @@ IntrinsicLower gLowerTable[] = { DXIL::OpCode::HitObject_SetShaderTableIndex}, {IntrinsicOp::MOP_DxHitObject_TraceRay, TranslateHitObjectTraceRay, DXIL::OpCode::HitObject_TraceRay}, + + {IntrinsicOp::IOP___builtin_MatVecMul, TranslateMatVecMul, + DXIL::OpCode::MatVecMul}, + {IntrinsicOp::IOP___builtin_MatVecMulAdd, TranslateMatVecMulAdd, + DXIL::OpCode::MatVecMulAdd}, + {IntrinsicOp::IOP___builtin_OuterProductAccumulate, + TranslateOuterProductAccumulate, DXIL::OpCode::OuterProductAccumulate}, + {IntrinsicOp::IOP___builtin_VectorAccumulate, TranslateVectorAccumulate, + DXIL::OpCode::VectorAccumulate}, }; } // namespace static_assert( diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index 03a37b6dbc..c6d3c014d9 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -12071,6 +12071,18 @@ void Sema::DiagnoseReachableHLSLCall(CallExpr *CE, const hlsl::ShaderModel *SM, break; case hlsl::IntrinsicOp::IOP_DxMaybeReorderThread: DiagnoseReachableSERCall(*this, CE, EntrySK, EntryDecl, true); + break; + case hlsl::IntrinsicOp::IOP___builtin_MatVecMul: + case hlsl::IntrinsicOp::IOP___builtin_MatVecMulAdd: + case hlsl::IntrinsicOp::IOP___builtin_OuterProductAccumulate: + case hlsl::IntrinsicOp::IOP___builtin_VectorAccumulate: + if (!SM->IsSM69Plus()) { + Diags.Report(CE->getExprLoc(), + diag::warn_hlsl_intrinsic_in_wrong_shader_model) + << FD->getNameAsString() << EntryDecl->getNameAsString() << "6.9"; + return; + } + break; default: break; diff --git a/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/check-shader-stages.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/check-shader-stages.hlsl new file mode 100644 index 0000000000..74cb51260c --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/check-shader-stages.hlsl @@ -0,0 +1,135 @@ +// RUN: %dxc -T lib_6_9 %s | FileCheck %s + +ByteAddressBuffer matrix_buffer; 
+ByteAddressBuffer bias_buffer; +RWByteAddressBuffer rw_matrix_buffer; +ByteAddressBuffer input_vector_buffer; +RWByteAddressBuffer output_vector_buffer; + +void UseCoopVec() { + vector output_vector; + static const uint is_output_unsigned = 0; + + vector input_vector = input_vector_buffer.Load >(0); + const uint is_input_unsigned = 0; + const uint input_interpretation = 9; /*F32*/ + + const uint matrix_offset = 0; + const uint matrix_interpretation = 9; /*F32*/ + const uint matrix_dimM = 4; + const uint matrix_dimK = 4; + const uint matrix_layout = 0; /*RowMajor*/ + const bool matrix_is_transposed = false; + const uint matrix_stride = 64; + + __builtin_MatVecMul(output_vector, is_output_unsigned, input_vector, + is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset, + matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout, + matrix_is_transposed, matrix_stride); + output_vector_buffer.Store(0, output_vector); + + const uint bias_offset = 0; + const uint bias_interpretation = 9; /*F32*/ + + __builtin_MatVecMulAdd(output_vector, is_output_unsigned, input_vector, + is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset, + matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout, + matrix_is_transposed, matrix_stride, bias_buffer, bias_offset, + bias_interpretation); + output_vector_buffer.Store(1024, output_vector); + + vector input_vector1; + vector input_vector2; + const uint opa_matrix_offset = 0; + const uint opa_matrix_interpretation = 5; /*U32*/ + const uint opa_matrix_layout = 3; /*OuterProductOptimal*/ + const uint opa_matrix_stride = 64; + + __builtin_OuterProductAccumulate(input_vector1, input_vector2, + rw_matrix_buffer, opa_matrix_offset, opa_matrix_interpretation, + opa_matrix_layout, opa_matrix_stride); + + const uint va_matrix_offset = 0; + + __builtin_VectorAccumulate(input_vector1, rw_matrix_buffer, + va_matrix_offset); +} + +// CHECK: define void @ps_main() +// CHECK: call <4 x float> @dx.op.matVecMul +// 
CHECK: call <4 x float> @dx.op.matVecMulAdd +// CHECK: call void @dx.op.outerProductAccumulate +// CHECK: call void @dx.op.vectorAccumulate + +[Shader("pixel")] +void ps_main() +{ + UseCoopVec(); +} + +// CHECK: define void @cs_main() +// CHECK: call <4 x float> @dx.op.matVecMul +// CHECK: call <4 x float> @dx.op.matVecMulAdd +// CHECK: call void @dx.op.outerProductAccumulate +// CHECK: call void @dx.op.vectorAccumulate + +[Shader("compute")] +[NumThreads(1,1,1)] +void cs_main() +{ + UseCoopVec(); +} + +// CHECK: define void @vs_main() +// CHECK: call <4 x float> @dx.op.matVecMul +// CHECK: call <4 x float> @dx.op.matVecMulAdd +// CHECK: call void @dx.op.outerProductAccumulate +// CHECK: call void @dx.op.vectorAccumulate + +[Shader("vertex")] +void vs_main() +{ + UseCoopVec(); +} + +struct MyRecord{ + uint a; +}; + +// CHECK: define void @ns_main() +// CHECK: call <4 x float> @dx.op.matVecMul +// CHECK: call <4 x float> @dx.op.matVecMulAdd +// CHECK: call void @dx.op.outerProductAccumulate +// CHECK: call void @dx.op.vectorAccumulate + +[Shader("node")] +[NodeLaunch("thread")] +void ns_main(ThreadNodeInputRecord input) +{ + UseCoopVec(); +} + +// Vertex shader output structure +struct VS_OUT { + float3 Color : COLOR0; +}; + +// Geometry shader output structure +struct GS_OUT { + float3 Color : COLOR0; + float2 TexCoord : TEXCOORD0; +}; + +// CHECK: define void @gs_main() +// CHECK: call <4 x float> @dx.op.matVecMul +// CHECK: call <4 x float> @dx.op.matVecMulAdd +// CHECK: call void @dx.op.outerProductAccumulate +// CHECK: call void @dx.op.vectorAccumulate + +[shader("geometry")] +[maxvertexcount(3)] +void gs_main(point VS_OUT input[1], + inout TriangleStream OutputStream) +{ + UseCoopVec(); +} diff --git a/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/linalg-builtins.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/linalg-builtins.hlsl new file mode 100644 index 0000000000..c3b4a3a8d7 --- /dev/null +++ 
b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/linalg-builtins.hlsl @@ -0,0 +1,79 @@ +// RUN: %dxc -fcgl -T cs_6_9 -E cs_main %s | FileCheck %s + +ByteAddressBuffer input_vector_buffer; +ByteAddressBuffer opa_input_buffer; +ByteAddressBuffer matrix_buffer; +ByteAddressBuffer bias_buffer; +RWByteAddressBuffer rw_matrix_buffer; +RWByteAddressBuffer output_vector_buffer; + +[Shader("compute")] +[NumThreads(1,1,1)] +void cs_main() +{ + vector output_vector; + static const uint is_output_unsigned = 0; + + vector input_vector = input_vector_buffer.Load >(0); + const uint is_input_unsigned = 0; + const uint input_interpretation = 9; /*F32*/ + + const uint matrix_offset = 0; + const uint matrix_interpretation = 9; /*F32*/ + const uint matrix_dimM = 4; + const uint matrix_dimK = 4; + const uint matrix_layout = 0; /*RowMajor*/ + const bool matrix_is_transposed = false; + const uint matrix_stride = 64; + + // CHECK: %[[MLD0:[^ ]+]] = load %struct.ByteAddressBuffer, %struct.ByteAddressBuffer* @"\01?matrix_buffer@@3UByteAddressBuffer@@A" + // CHECK: %[[MCH0:[^ ]+]] = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.ByteAddressBuffer)"(i32 0, %struct.ByteAddressBuffer %[[MLD0]]) + // CHECK: %[[MAH0:[^ ]+]] = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.ByteAddressBuffer)"(i32 14, %dx.types.Handle %[[MCH0]], %dx.types.ResourceProperties { i32 11, i32 0 }, %struct.ByteAddressBuffer undef) + // CHECK: call void @"dx.hl.op..void (i32, <4 x float>*, i1, <4 x float>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32)"(i32 390, <4 x float>* %{{[^ ]+}}, i1 false, <4 x float> %{{[^ ]+}}, i1 false, i32 9, %dx.types.Handle %[[MAH0]], i32 0, i32 9, i32 4, i32 4, i32 0, i1 false, i32 64) + __builtin_MatVecMul(output_vector, is_output_unsigned, input_vector, + is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset, + matrix_interpretation, 
matrix_dimM, matrix_dimK, matrix_layout, + matrix_is_transposed, matrix_stride); + output_vector_buffer.Store(0, output_vector); + + const uint bias_offset = 0; + const uint bias_interpretation = 9; /*F32*/ + + // CHECK: %[[MLD1:[^ ]+]] = load %struct.ByteAddressBuffer, %struct.ByteAddressBuffer* @"\01?matrix_buffer@@3UByteAddressBuffer@@A" + // CHECK: %[[MCH1:[^ ]+]] = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.ByteAddressBuffer)"(i32 0, %struct.ByteAddressBuffer %[[MLD1]]) + // CHECK: %[[MAH1:[^ ]+]] = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.ByteAddressBuffer)"(i32 14, %dx.types.Handle %[[MCH1]], %dx.types.ResourceProperties { i32 11, i32 0 }, %struct.ByteAddressBuffer undef) + // CHECK-NEXT: %[[BLD1:[^ ]+]] = load %struct.ByteAddressBuffer, %struct.ByteAddressBuffer* @"\01?bias_buffer@@3UByteAddressBuffer@@A" + // CHECK-NEXT: %[[BCH1:[^ ]+]] = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.ByteAddressBuffer)"(i32 0, %struct.ByteAddressBuffer %[[BLD1]]) + // CHECK-NEXT: %[[BAH1:[^ ]+]] = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.ByteAddressBuffer)"(i32 14, %dx.types.Handle %[[BCH1]], %dx.types.ResourceProperties { i32 11, i32 0 }, %struct.ByteAddressBuffer undef) + // CHECK-NEXT: call void @"dx.hl.op..void (i32, <4 x float>*, i1, <4 x float>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32, %dx.types.Handle, i32, i32)"(i32 391, <4 x float>* %{{[^ ]+}}, i1 false, <4 x float> %{{[^ ]+}}, i1 false, i32 9, %dx.types.Handle %[[MAH1]], i32 0, i32 9, i32 4, i32 4, i32 0, i1 false, i32 64, %dx.types.Handle %[[BAH1]], i32 0, i32 9) + __builtin_MatVecMulAdd(output_vector, is_output_unsigned, input_vector, + is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset, + matrix_interpretation, matrix_dimM, matrix_dimK, 
matrix_layout, + matrix_is_transposed, matrix_stride, bias_buffer, bias_offset, + bias_interpretation); + output_vector_buffer.Store(1024, output_vector); + + vector input_vector1 = opa_input_buffer.Load >(0); + vector input_vector2 = opa_input_buffer.Load >(128); + const uint opa_matrix_offset = 0; + const uint opa_matrix_interpretation = 5; /*U32*/ + const uint opa_matrix_layout = 3; /*OuterProductOptimal*/ + const uint opa_matrix_stride = 64; + + // CHECK: %[[MLD2:[^ ]+]] = load %struct.RWByteAddressBuffer, %struct.RWByteAddressBuffer* @"\01?rw_matrix_buffer@@3URWByteAddressBuffer@@A" + // CHECK: %[[MCH2:[^ ]+]] = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.RWByteAddressBuffer)"(i32 0, %struct.RWByteAddressBuffer %[[MLD2]]) + // CHECK: %[[MAH2:[^ ]+]] = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.RWByteAddressBuffer)"(i32 14, %dx.types.Handle %[[MCH2]], %dx.types.ResourceProperties { i32 4107, i32 0 }, %struct.RWByteAddressBuffer undef) + // CHECK: call void @"dx.hl.op..void (i32, <8 x i32>, <8 x i32>, %dx.types.Handle, i32, i32, i32, i32)"(i32 392, <8 x i32> %{{[^ ]+}}, <8 x i32> %{{[^ ]+}}, %dx.types.Handle %[[MAH2]], i32 0, i32 5, i32 3, i32 64) + __builtin_OuterProductAccumulate(input_vector1, input_vector2, + rw_matrix_buffer, opa_matrix_offset, opa_matrix_interpretation, + opa_matrix_layout, opa_matrix_stride); + + const uint va_matrix_offset = 0; + + // CHECK: %[[MLD3:[^ ]+]] = load %struct.RWByteAddressBuffer, %struct.RWByteAddressBuffer* @"\01?rw_matrix_buffer@@3URWByteAddressBuffer@@A" + // CHECK: %[[MCH3:[^ ]+]] = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.RWByteAddressBuffer)"(i32 0, %struct.RWByteAddressBuffer %[[MLD3]]) + // CHECK: %[[MAH3:[^ ]+]] = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.RWByteAddressBuffer)"(i32 14, 
%dx.types.Handle %[[MCH3]], %dx.types.ResourceProperties { i32 4107, i32 0 }, %struct.RWByteAddressBuffer undef) + // CHECK: call void @"dx.hl.op..void (i32, <8 x i32>, %dx.types.Handle, i32)"(i32 393, <8 x i32> %{{[^ ]+}}, %dx.types.Handle %[[MAH3]], i32 0) + __builtin_VectorAccumulate(input_vector1, rw_matrix_buffer, + va_matrix_offset); +} \ No newline at end of file diff --git a/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/lit.local.cfg b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/lit.local.cfg new file mode 100644 index 0000000000..c2417a9e43 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/lit.local.cfg @@ -0,0 +1 @@ +config.unsupported = 'dxil-1-9' not in config.available_features \ No newline at end of file diff --git a/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/mat-vec-mul-add_multioverload.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/mat-vec-mul-add_multioverload.hlsl new file mode 100644 index 0000000000..98a568fa22 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/mat-vec-mul-add_multioverload.hlsl @@ -0,0 +1,108 @@ +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F16 -DMI=F16 -DML=RowMajor -DMT=0 -DBI=F16 | FileCheck %s --check-prefixes COMMON,DXIL-0 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F8_E4M3 -DMI=F8_E4M3 -DML=MulOptimal -DMT=0 -DBI=F16 | FileCheck %s --check-prefixes COMMON,DXIL-1 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F8_E5M2 -DMI=F8_E5M2 -DML=MulOptimal -DMT=1 -DBI=F16 | FileCheck %s --check-prefixes COMMON,DXIL-2 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=uint -DII=PackedS8x32 -DMI=I8 -DML=OuterProductOptimal -DMT=1 -DBI=I32 | FileCheck %s --check-prefixes COMMON,DXIL-3 +// RUN: %dxc -T cs_6_9 %s 
-enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=float -DII=I8 -DMI=I8 -DML=RowMajor -DMT=0 -DBI=I32 | FileCheck %s --check-prefixes COMMON,DXIL-4 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=1 -DOTY=uint -DIU=0 -DITY=float -DII=I8 -DMI=F16 -DML=RowMajor -DMT=0 -DBI=I8 | FileCheck %s --check-prefixes COMMON,DXIL-5 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=1 -DITY=uint -DII=U8 -DMI=I8 -DML=ColumnMajor -DMT=0 -DBI=I8 | FileCheck %s --check-prefixes COMMON,DXIL-6 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=int -DII=U8 -DMI=U8 -DML=MulOptimal -DMT=1 -DBI=I8 | FileCheck %s --check-prefixes COMMON,DXIL-7 + +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F16 -DMI=F16 -DML=RowMajor -DMT=0 -DBI=F16 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-0 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F8_E4M3 -DMI=F8_E4M3 -DML=MulOptimal -DMT=0 -DBI=F16 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-1 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F8_E5M2 -DMI=F8_E5M2 -DML=MulOptimal -DMT=1 -DBI=F16 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-2 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=uint -DII=PackedS8x32 -DMI=I8 -DML=OuterProductOptimal -DMT=1 -DBI=I32 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-3 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=float -DII=I8 -DMI=I8 -DML=RowMajor -DMT=0 -DBI=I32 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-4 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=1 -DOTY=uint -DIU=0 -DITY=float -DII=I8 -DMI=F16 -DML=RowMajor -DMT=0 -DBI=I8 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-5 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=1 -DITY=uint -DII=U8 -DMI=I8 -DML=ColumnMajor -DMT=0 -DBI=I8 -fcgl | FileCheck %s 
--check-prefixes COMMON,HLOP-6 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=int -DII=U8 -DMI=U8 -DML=MulOptimal -DMT=1 -DBI=I8 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-7 + + +// COMMON: define void @main() + +// Test minimum support set of combinations for matVecMulAdd +// HLOP-0: call void @"dx.hl.op..void (i32, <4 x half>*, i1, <8 x half>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32, %dx.types.Handle, i32, i32)"(i32 391, <4 x half>* %output_vector, i1 false, <8 x half> %{{[^ ]+}}, i1 false, i32 8, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 8, i32 8, i32 8, i32 0, i1 false, i32 64, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 8) +// DXIL-0: call <4 x half> @dx.op.matVecMulAdd.v4f16.v8f16(i32 306, <8 x half> {{[^ ]+}}, i1 false, i32 8, %dx.types.Handle {{[^ ]+}}, i32 0, i32 8, i32 8, i32 8, i32 0, i1 false, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 8, i1 false) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned) +// HLOP-1: call void @"dx.hl.op..void (i32, <4 x half>*, i1, <8 x half>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32, %dx.types.Handle, i32, i32)"(i32 391, <4 x half>* %output_vector, i1 false, <8 x half> %{{[^ ]+}}, i1 false, i32 21, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 21, i32 8, i32 8, i32 2, i1 false, i32 64, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 8) +// DXIL-1: call <4 x half> @dx.op.matVecMulAdd.v4f16.v8f16(i32 306, <8 x half> {{[^ ]+}}, i1 false, i32 21, %dx.types.Handle {{[^ ]+}}, i32 0, i32 21, i32 8, i32 8, i32 2, i1 false, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 8, i1 false) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned) +// 
HLOP-2: call void @"dx.hl.op..void (i32, <4 x half>*, i1, <8 x half>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32, %dx.types.Handle, i32, i32)"(i32 391, <4 x half>* %output_vector, i1 false, <8 x half> %{{[^ ]+}}, i1 false, i32 22, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 22, i32 8, i32 8, i32 2, i1 true, i32 64, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 8) +// DXIL-2: call <4 x half> @dx.op.matVecMulAdd.v4f16.v8f16(i32 306, <8 x half> {{[^ ]+}}, i1 false, i32 22, %dx.types.Handle {{[^ ]+}}, i32 0, i32 22, i32 8, i32 8, i32 2, i1 true, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 8, i1 false) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned) +// HLOP-3: call void @"dx.hl.op..void (i32, <4 x i32>*, i1, <8 x i32>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32, %dx.types.Handle, i32, i32)"(i32 391, <4 x i32>* %output_vector, i1 false, <8 x i32> %{{[^ ]+}}, i1 false, i32 17, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 3, i1 true, i32 64, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 4) +// DXIL-3: call <4 x i32> @dx.op.matVecMulAdd.v4i32.v8i32(i32 306, <8 x i32> {{[^ ]+}}, i1 false, i32 17, %dx.types.Handle {{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 3, i1 true, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 4, i1 false) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned) +// HLOP-4: call void @"dx.hl.op..void (i32, <4 x i32>*, i1, <8 x float>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32, %dx.types.Handle, i32, i32)"(i32 391, <4 x i32>* %output_vector, i1 false, <8 x float> %{{[^ ]+}}, i1 false, i32 20, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 0, i1 
false, i32 64, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 4) +// DXIL-4: call <4 x i32> @dx.op.matVecMulAdd.v4i32.v8f32(i32 306, <8 x float> {{[^ ]+}}, i1 false, i32 20, %dx.types.Handle {{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 0, i1 false, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 4, i1 false) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned) + +// Test unsigned variations +// HLOP-5: call void @"dx.hl.op..void (i32, <4 x i32>*, i1, <8 x float>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32, %dx.types.Handle, i32, i32)"(i32 391, <4 x i32>* %output_vector, i1 true, <8 x float> %{{[^ ]+}}, i1 false, i32 20, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 8, i32 8, i32 8, i32 0, i1 false, i32 64, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 20) +// DXIL-5: call <4 x i32> @dx.op.matVecMulAdd.v4i32.v8f32(i32 306, <8 x float> {{[^ ]+}}, i1 false, i32 20, %dx.types.Handle {{[^ ]+}}, i32 0, i32 8, i32 8, i32 8, i32 0, i1 false, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 20, i1 true) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned) +// HLOP-6: call void @"dx.hl.op..void (i32, <4 x i32>*, i1, <8 x i32>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32, %dx.types.Handle, i32, i32)"(i32 391, <4 x i32>* %output_vector, i1 false, <8 x i32> %{{[^ ]+}}, i1 true, i32 19, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 1, i1 false, i32 64, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 20) +// DXIL-6: call <4 x i32> @dx.op.matVecMulAdd.v4i32.v8i32(i32 306, <8 x i32> {{[^ ]+}}, i1 true, i32 19, %dx.types.Handle {{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 1, i1 false, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, 
i32 20, i1 false) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned) +// HLOP-7: call void @"dx.hl.op..void (i32, <4 x i32>*, i1, <8 x i32>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32, %dx.types.Handle, i32, i32)"(i32 391, <4 x i32>* %output_vector, i1 false, <8 x i32> %{{[^ ]+}}, i1 false, i32 19, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 19, i32 8, i32 8, i32 2, i1 true, i32 64, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 20) +// DXIL-7: call <4 x i32> @dx.op.matVecMulAdd.v4i32.v8i32(i32 306, <8 x i32> {{[^ ]+}}, i1 false, i32 19, %dx.types.Handle {{[^ ]+}}, i32 0, i32 19, i32 8, i32 8, i32 2, i1 true, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 20, i1 false) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned) + + +ByteAddressBuffer input_vector_buffer; +ByteAddressBuffer matrix_buffer; +ByteAddressBuffer bias_buffer; +RWByteAddressBuffer rw_matrix_buffer; +RWByteAddressBuffer output_vector_buffer; + +enum CompType { + Invalid = 0, + I1 = 1, + I16 = 2, + U16 = 3, + I32 = 4, + U32 = 5, + I64 = 6, + U64 = 7, + F16 = 8, + F32 = 9, + F64 = 10, + SNormF16 = 11, + UNormF16 = 12, + SNormF32 = 13, + UNormF32 = 14, + SNormF64 = 15, + UNormF64 = 16, + PackedS8x32 = 17, + PackedU8x32 = 18, + + // BEGIN NEW FOR SM 6.9 + U8 = 19, + I8 = 20, + F8_E4M3 = 21, + F8_E5M2 = 22, +}; + +enum MatLayout { + RowMajor = 0, + ColumnMajor = 1, + MulOptimal = 2, + OuterProductOptimal = 3, +}; + +[NumThreads(1,1,1)] +void main() +{ + vector output_vector; + static const uint is_output_unsigned = OU; + + vector input_vector = input_vector_buffer.Load >(0); + const uint is_input_unsigned = IU; + const uint input_interpretation = II; + + 
const uint matrix_offset = 0; + const uint matrix_interpretation = MI; + const uint matrix_dimM = 8; + const uint matrix_dimK = 8; + const uint matrix_layout = ML; + const bool matrix_is_transposed = (bool) MT; + const uint matrix_stride = 64; + + const uint bias_offset = 0; + const uint bias_interpretation = BI; + + __builtin_MatVecMulAdd(output_vector, is_output_unsigned, input_vector, is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset, matrix_interpretation, + matrix_dimM, matrix_dimK, matrix_layout, matrix_is_transposed, matrix_stride, bias_buffer, bias_offset, bias_interpretation); + output_vector_buffer.Store(0, output_vector); +} diff --git a/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/mat-vec-mul_multioverload.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/mat-vec-mul_multioverload.hlsl new file mode 100644 index 0000000000..2ca2648503 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/mat-vec-mul_multioverload.hlsl @@ -0,0 +1,104 @@ +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F16 -DMI=F16 -DML=RowMajor -DMT=0 | FileCheck %s --check-prefixes COMMON,DXIL-0 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F8_E4M3 -DMI=F8_E4M3 -DML=MulOptimal -DMT=0 | FileCheck %s --check-prefixes COMMON,DXIL-1 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F8_E5M2 -DMI=F8_E5M2 -DML=MulOptimal -DMT=1 | FileCheck %s --check-prefixes COMMON,DXIL-2 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=uint -DII=PackedS8x32 -DMI=I8 -DML=OuterProductOptimal -DMT=1 | FileCheck %s --check-prefixes COMMON,DXIL-3 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=float -DII=I8 -DMI=I8 -DML=RowMajor -DMT=0 | FileCheck %s --check-prefixes COMMON,DXIL-4 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types 
-DOU=1 -DOTY=uint -DIU=0 -DITY=float -DII=I8 -DMI=F16 -DML=RowMajor -DMT=0 | FileCheck %s --check-prefixes COMMON,DXIL-5 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=1 -DITY=uint -DII=U8 -DMI=I8 -DML=ColumnMajor -DMT=0 | FileCheck %s --check-prefixes COMMON,DXIL-6 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=int -DII=U8 -DMI=U8 -DML=MulOptimal -DMT=1 | FileCheck %s --check-prefixes COMMON,DXIL-7 + +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F16 -DMI=F16 -DML=RowMajor -DMT=0 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-0 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F8_E4M3 -DMI=F8_E4M3 -DML=MulOptimal -DMT=0 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-1 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F8_E5M2 -DMI=F8_E5M2 -DML=MulOptimal -DMT=1 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-2 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=uint -DII=PackedS8x32 -DMI=I8 -DML=OuterProductOptimal -DMT=1 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-3 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=float -DII=I8 -DMI=I8 -DML=RowMajor -DMT=0 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-4 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=1 -DOTY=uint -DIU=0 -DITY=float -DII=I8 -DMI=F16 -DML=RowMajor -DMT=0 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-5 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=1 -DITY=uint -DII=U8 -DMI=I8 -DML=ColumnMajor -DMT=0 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-6 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=int -DII=U8 -DMI=U8 -DML=MulOptimal -DMT=1 -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-7 + +// COMMON: define void @main() + +// Test minimum support set of combinations for 
matVecMul +// HLOP-0: call void @"dx.hl.op..void (i32, <4 x half>*, i1, <8 x half>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32)"(i32 390, <4 x half>* %output_vector, i1 false, <8 x half> %{{[^ ]+}}, i1 false, i32 8, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 8, i32 8, i32 8, i32 0, i1 false, i32 64) +// DXIL-0: call <4 x half> @dx.op.matVecMul.v4f16.v8f16(i32 305, <8 x half> {{[^ ]+}}, i1 false, i32 8, %dx.types.Handle {{[^ ]+}}, i32 0, i32 8, i32 8, i32 8, i32 0, i1 false, i32 64, i1 false) ; MatVecMul(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,isOutputUnsigned) +// HLOP-1: call void @"dx.hl.op..void (i32, <4 x half>*, i1, <8 x half>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32)"(i32 390, <4 x half>* %output_vector, i1 false, <8 x half> %{{[^ ]+}}, i1 false, i32 21, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 21, i32 8, i32 8, i32 2, i1 false, i32 64) +// DXIL-1: call <4 x half> @dx.op.matVecMul.v4f16.v8f16(i32 305, <8 x half> {{[^ ]+}}, i1 false, i32 21, %dx.types.Handle {{[^ ]+}}, i32 0, i32 21, i32 8, i32 8, i32 2, i1 false, i32 64, i1 false) ; MatVecMul(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,isOutputUnsigned) +// HLOP-2: call void @"dx.hl.op..void (i32, <4 x half>*, i1, <8 x half>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32)"(i32 390, <4 x half>* %output_vector, i1 false, <8 x half> %{{[^ ]+}}, i1 false, i32 22, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 22, i32 8, i32 8, i32 2, i1 true, i32 64) +// DXIL-2: call <4 x half> @dx.op.matVecMul.v4f16.v8f16(i32 305, <8 x half> {{[^ ]+}}, i1 false, i32 22, %dx.types.Handle {{[^ ]+}}, i32 0, i32 22, i32 8, i32 8, i32 2, i1 true, i32 64, i1 false) ; 
MatVecMul(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,isOutputUnsigned) +// HLOP-3: call void @"dx.hl.op..void (i32, <4 x i32>*, i1, <8 x i32>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32)"(i32 390, <4 x i32>* %output_vector, i1 false, <8 x i32> %{{[^ ]+}}, i1 false, i32 17, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 3, i1 true, i32 64) +// DXIL-3: call <4 x i32> @dx.op.matVecMul.v4i32.v8i32(i32 305, <8 x i32> {{[^ ]+}}, i1 false, i32 17, %dx.types.Handle {{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 3, i1 true, i32 64, i1 false) ; MatVecMul(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,isOutputUnsigned) +// HLOP-4: call void @"dx.hl.op..void (i32, <4 x i32>*, i1, <8 x float>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32)"(i32 390, <4 x i32>* %output_vector, i1 false, <8 x float> %{{[^ ]+}}, i1 false, i32 20, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 0, i1 false, i32 64) +// DXIL-4: call <4 x i32> @dx.op.matVecMul.v4i32.v8f32(i32 305, <8 x float> {{[^ ]+}}, i1 false, i32 20, %dx.types.Handle {{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 0, i1 false, i32 64, i1 false) ; MatVecMul(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,isOutputUnsigned) + +// Test unsigned variations +// HLOP-5: call void @"dx.hl.op..void (i32, <4 x i32>*, i1, <8 x float>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32)"(i32 390, <4 x i32>* %output_vector, i1 true, <8 x float> %{{[^ ]+}}, i1 false, i32 20, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 8, i32 8, i32 8, i32 0, i1 false, i32 64) +// DXIL-5: call <4 x i32> @dx.op.matVecMul.v4i32.v8f32(i32 305, <8 x float> {{[^ ]+}}, i1 false, i32 20, 
%dx.types.Handle {{[^ ]+}}, i32 0, i32 8, i32 8, i32 8, i32 0, i1 false, i32 64, i1 true) ; MatVecMul(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,isOutputUnsigned) +// HLOP-6: call void @"dx.hl.op..void (i32, <4 x i32>*, i1, <8 x i32>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32)"(i32 390, <4 x i32>* %output_vector, i1 false, <8 x i32> %{{[^ ]+}}, i1 true, i32 19, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 1, i1 false, i32 64) +// DXIL-6: call <4 x i32> @dx.op.matVecMul.v4i32.v8i32(i32 305, <8 x i32> {{[^ ]+}}, i1 true, i32 19, %dx.types.Handle {{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 1, i1 false, i32 64, i1 false) ; MatVecMul(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,isOutputUnsigned) +// HLOP-7: call void @"dx.hl.op..void (i32, <4 x i32>*, i1, <8 x i32>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32)"(i32 390, <4 x i32>* %output_vector, i1 false, <8 x i32> %{{[^ ]+}}, i1 false, i32 19, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 19, i32 8, i32 8, i32 2, i1 true, i32 64) +// DXIL-7: call <4 x i32> @dx.op.matVecMul.v4i32.v8i32(i32 305, <8 x i32> {{[^ ]+}}, i1 false, i32 19, %dx.types.Handle {{[^ ]+}}, i32 0, i32 19, i32 8, i32 8, i32 2, i1 true, i32 64, i1 false) ; MatVecMul(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,isOutputUnsigned) + + +ByteAddressBuffer input_vector_buffer; +ByteAddressBuffer matrix_buffer; +ByteAddressBuffer bias_buffer; +RWByteAddressBuffer rw_matrix_buffer; +RWByteAddressBuffer output_vector_buffer; + +enum CompType { + Invalid = 0, + I1 = 1, + I16 = 2, + U16 = 3, + I32 = 4, + U32 = 5, + I64 = 6, + U64 = 7, + F16 = 8, + F32 = 9, + F64 = 10, + SNormF16 = 11, + UNormF16 = 12, 
+ SNormF32 = 13, + UNormF32 = 14, + SNormF64 = 15, + UNormF64 = 16, + PackedS8x32 = 17, + PackedU8x32 = 18, + + // BEGIN NEW FOR SM 6.9 + U8 = 19, + I8 = 20, + F8_E4M3 = 21, + F8_E5M2 = 22, +}; + +enum MatLayout { + RowMajor = 0, + ColumnMajor = 1, + MulOptimal = 2, + OuterProductOptimal = 3, +}; + +[NumThreads(1,1,1)] +void main() +{ + vector output_vector; + static const uint is_output_unsigned = OU; + + vector input_vector = input_vector_buffer.Load >(0); + const uint is_input_unsigned = IU; + const uint input_interpretation = II; + + const uint matrix_offset = 0; + const uint matrix_interpretation = MI; + const uint matrix_dimM = 8; + const uint matrix_dimK = 8; + const uint matrix_layout = ML; + const bool matrix_is_transposed = (bool) MT; + const uint matrix_stride = 64; + + __builtin_MatVecMul(output_vector, is_output_unsigned, input_vector, is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset, matrix_interpretation, + matrix_dimM, matrix_dimK, matrix_layout, matrix_is_transposed, matrix_stride); + output_vector_buffer.Store(0, output_vector); +} diff --git a/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/outer-product-accumulate-multioverload.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/outer-product-accumulate-multioverload.hlsl new file mode 100644 index 0000000000..40bbe62284 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/outer-product-accumulate-multioverload.hlsl @@ -0,0 +1,70 @@ +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DITY=float16_t -DMI=F16 -DML=RowMajor | FileCheck %s --check-prefixes COMMON,DXIL-0 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DITY=float16_t -DMI=F8_E4M3 -DML=OuterProductOptimal | FileCheck %s --check-prefixes COMMON,DXIL-1 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DITY=uint -DMI=U8 -DML=OuterProductOptimal | FileCheck %s --check-prefixes COMMON,DXIL-2 + +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DITY=float16_t -DMI=F16 
-DML=RowMajor -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-0 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DITY=float16_t -DMI=F8_E4M3 -DML=OuterProductOptimal -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-1 +// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DITY=uint -DMI=U8 -DML=OuterProductOptimal -fcgl | FileCheck %s --check-prefixes COMMON,HLOP-2 + +ByteAddressBuffer input_vector_buffer; +ByteAddressBuffer input_vector_buffer2; +RWByteAddressBuffer matrix_buffer; + +// COMMON: define void @main() +// DXIL-0: call void @dx.op.outerProductAccumulate.v8f16.v8f16(i32 307, <8 x half> %{{[^ ]+}}, <8 x half> %{{[^ ]+}}, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 8, i32 0, i32 64) ; OuterProductAccumulate(inputVector1,inputVector2,matrixBuffer,matrixOffset,matrixIntepretation,matrixLayout,matrixStride) +// HLOP-0: call void @"dx.hl.op..void (i32, <8 x half>, <8 x half>, %dx.types.Handle, i32, i32, i32, i32)"(i32 392, <8 x half> %{{[^ ]+}}, <8 x half> %{{[^ ]+}}, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 8, i32 0, i32 64) +// DXIL-1: call void @dx.op.outerProductAccumulate.v8f16.v8f16(i32 307, <8 x half> %{{[^ ]+}}, <8 x half> %{{[^ ]+}}, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 21, i32 3, i32 64) ; OuterProductAccumulate(inputVector1,inputVector2,matrixBuffer,matrixOffset,matrixIntepretation,matrixLayout,matrixStride) +// HLOP-1: call void @"dx.hl.op..void (i32, <8 x half>, <8 x half>, %dx.types.Handle, i32, i32, i32, i32)"(i32 392, <8 x half> %{{[^ ]+}}, <8 x half> %{{[^ ]+}}, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 21, i32 3, i32 64) +// DXIL-2: call void @dx.op.outerProductAccumulate.v8i32.v8i32(i32 307, <8 x i32> %{{[^ ]+}}, <8 x i32> %{{[^ ]+}}, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 19, i32 3, i32 64) ; OuterProductAccumulate(inputVector1,inputVector2,matrixBuffer,matrixOffset,matrixIntepretation,matrixLayout,matrixStride) +// HLOP-2: call void @"dx.hl.op..void (i32, <8 x i32>, <8 x i32>, %dx.types.Handle, i32, i32, i32, i32)"(i32 392, <8 x i32> %{{[^ ]+}}, 
<8 x i32> %{{[^ ]+}}, %dx.types.Handle %{{[^ ]+}}, i32 0, i32 19, i32 3, i32 64) + +enum CompType { + Invalid = 0, + I1 = 1, + I16 = 2, + U16 = 3, + I32 = 4, + U32 = 5, + I64 = 6, + U64 = 7, + F16 = 8, + F32 = 9, + F64 = 10, + SNormF16 = 11, + UNormF16 = 12, + SNormF32 = 13, + UNormF32 = 14, + SNormF64 = 15, + UNormF64 = 16, + PackedS8x32 = 17, + PackedU8x32 = 18, + + // BEGIN NEW FOR SM 6.9 + U8 = 19, + I8 = 20, + F8_E4M3 = 21, + F8_E5M2 = 22, +}; + +enum MatLayout { + RowMajor = 0, + ColumnMajor = 1, + MulOptimal = 2, + OuterProductOptimal = 3, +}; + + +[Numthreads(1,1,1)] +void main() +{ + vector input_vector1 = input_vector_buffer.Load >(0); + vector input_vector2 = input_vector_buffer2.Load >(0); + + const uint matrix_interpretation = MI; + const uint matrix_layout = ML; + const uint matrix_offset = 0; + const uint matrix_stride = 64; + + __builtin_OuterProductAccumulate(input_vector1, input_vector2, matrix_buffer, matrix_offset, matrix_interpretation, matrix_layout, matrix_stride); + +} diff --git a/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/vector-accumulate.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/vector-accumulate.hlsl new file mode 100644 index 0000000000..dc1bb6c563 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/intrinsics/linalg_builtins/vector-accumulate.hlsl @@ -0,0 +1,16 @@ +// RUN: %dxc -T cs_6_9 %s | FileCheck %s + +RWByteAddressBuffer matrix_buffer; + +// Test use of __builtin_VectorAccumulate in compute shader +// CHECK: define void @main() +// CHECK: call void @dx.op.vectorAccumulate.v2i32(i32 {{[0-9]+}}, <2 x i32> , %dx.types.Handle {{%[0-9]+}}, i32 0) + +[NumThreads(1,1,1)] +void main() +{ + vector input_vector1 = 5; + const uint matrix_offset = 0; + + __builtin_VectorAccumulate(input_vector1, matrix_buffer, matrix_offset); +} diff --git a/tools/clang/test/DXC/Passes/DxilGen/linalg-builtins.ll b/tools/clang/test/DXC/Passes/DxilGen/linalg-builtins.ll new file mode 100644 index 
0000000000..6623f63031 --- /dev/null +++ b/tools/clang/test/DXC/Passes/DxilGen/linalg-builtins.ll @@ -0,0 +1,189 @@ +; RUN: %dxopt %s -hlsl-passes-resume -dxilgen -S | FileCheck %s +; REQUIRES: dxil-1-9 + +target datalayout = "e-m:e-p:32:32-i1:32-i8:32-i16:32-i32:32-i64:64-f16:32-f32:32-f64:64-n8:16:32:64" +target triple = "dxil-ms-dx" + +%struct.ByteAddressBuffer = type { i32 } +%struct.RWByteAddressBuffer = type { i32 } +%dx.types.Handle = type { i8* } +%dx.types.ResourceProperties = type { i32, i32 } + +@"\01?input_vector_buffer@@3UByteAddressBuffer@@A" = external global %struct.ByteAddressBuffer, align 4 +@"\01?opa_input_buffer@@3UByteAddressBuffer@@A" = external global %struct.ByteAddressBuffer, align 4 +@"\01?matrix_buffer@@3UByteAddressBuffer@@A" = external global %struct.ByteAddressBuffer, align 4 +@"\01?bias_buffer@@3UByteAddressBuffer@@A" = external global %struct.ByteAddressBuffer, align 4 +@"\01?rw_matrix_buffer@@3URWByteAddressBuffer@@A" = external global %struct.RWByteAddressBuffer, align 4 +@"\01?output_vector_buffer@@3URWByteAddressBuffer@@A" = external global %struct.RWByteAddressBuffer, align 4 + +; Function Attrs: nounwind +define void @cs_main() #0 { +entry: + ;CHECK-DAG: %[[MLD:[^ ]+]] = load %struct.ByteAddressBuffer, %struct.ByteAddressBuffer* @"\01?matrix_buffer@@3UByteAddressBuffer@@A" + ;CHECK-DAG: %[[BLD:[^ ]+]] = load %struct.ByteAddressBuffer, %struct.ByteAddressBuffer* @"\01?bias_buffer@@3UByteAddressBuffer@@A" + ;CHECK-DAG: %[[RWMLD0:[^ ]+]] = load %struct.RWByteAddressBuffer, %struct.RWByteAddressBuffer* @"\01?rw_matrix_buffer@@3URWByteAddressBuffer@@A" + %output_vector = alloca <4 x float>, align 4 + %tmp = bitcast <4 x float>* %output_vector to i8*, !dbg !21 ; line:14 col:5 + call void @llvm.lifetime.start(i64 16, i8* %tmp) #0, !dbg !21 ; line:14 col:5 + %tmp1 = load %struct.ByteAddressBuffer, %struct.ByteAddressBuffer* @"\01?input_vector_buffer@@3UByteAddressBuffer@@A", !dbg !25 ; line:17 col:37 + %tmp2 = call %dx.types.Handle 
@"dx.hl.createhandle..%dx.types.Handle (i32, %struct.ByteAddressBuffer)"(i32 0, %struct.ByteAddressBuffer %tmp1), !dbg !25 ; line:17 col:37 + %tmp3 = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.ByteAddressBuffer)"(i32 14, %dx.types.Handle %tmp2, %dx.types.ResourceProperties { i32 11, i32 0 }, %struct.ByteAddressBuffer zeroinitializer), !dbg !25 ; line:17 col:37 + %tmp4 = call <4 x float> @"dx.hl.op.ro.<4 x float> (i32, %dx.types.Handle, i32)"(i32 231, %dx.types.Handle %tmp3, i32 0), !dbg !25 ; line:17 col:37 + %tmp5 = load %struct.ByteAddressBuffer, %struct.ByteAddressBuffer* @"\01?matrix_buffer@@3UByteAddressBuffer@@A", !dbg !26 ; line:33 col:5 + %tmp6 = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.ByteAddressBuffer)"(i32 0, %struct.ByteAddressBuffer %tmp5), !dbg !26 ; line:33 col:5 + %tmp7 = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.ByteAddressBuffer)"(i32 14, %dx.types.Handle %tmp6, %dx.types.ResourceProperties { i32 11, i32 0 }, %struct.ByteAddressBuffer zeroinitializer), !dbg !26 ; line:33 col:5 + + ;CHECK: %[[MCH0:[^ ]+]] = call %dx.types.Handle @dx.op.createHandleForLib.struct.ByteAddressBuffer(i32 160, %struct.ByteAddressBuffer %[[MLD]] + ;CHECK: %[[MAH0:[^ ]+]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle %[[MCH0]] + ;CHECK: call <4 x float> @dx.op.matVecMul.v4f32.v4f32(i32 305, <4 x float> %{{[^ ]+}}, i1 false, i32 9, %dx.types.Handle %[[MAH0]], i32 0, i32 9, i32 4, i32 4, i32 0, i1 false, i32 64, i1 false) + call void @"dx.hl.op..void (i32, <4 x float>*, i1, <4 x float>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32)"(i32 390, <4 x float>* %output_vector, i1 false, <4 x float> %tmp4, i1 false, i32 9, %dx.types.Handle %tmp7, i32 0, i32 9, i32 4, i32 4, i32 0, i1 false, i32 64), !dbg !26 ; line:33 col:5 + + %tmp8 = load <4 x 
float>, <4 x float>* %output_vector, align 4, !dbg !27, !tbaa !28 ; line:37 col:35 + %tmp9 = load %struct.RWByteAddressBuffer, %struct.RWByteAddressBuffer* @"\01?output_vector_buffer@@3URWByteAddressBuffer@@A", !dbg !31 ; line:37 col:5 + %tmp10 = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.RWByteAddressBuffer)"(i32 0, %struct.RWByteAddressBuffer %tmp9), !dbg !31 ; line:37 col:5 + %tmp11 = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.RWByteAddressBuffer)"(i32 14, %dx.types.Handle %tmp10, %dx.types.ResourceProperties { i32 4107, i32 0 }, %struct.RWByteAddressBuffer zeroinitializer), !dbg !31 ; line:37 col:5 + call void @"dx.hl.op..void (i32, %dx.types.Handle, i32, <4 x float>)"(i32 277, %dx.types.Handle %tmp11, i32 0, <4 x float> %tmp8), !dbg !31 ; line:37 col:5 + %tmp12 = load %struct.ByteAddressBuffer, %struct.ByteAddressBuffer* @"\01?matrix_buffer@@3UByteAddressBuffer@@A", !dbg !32 ; line:49 col:5 + %tmp13 = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.ByteAddressBuffer)"(i32 0, %struct.ByteAddressBuffer %tmp12), !dbg !32 ; line:49 col:5 + %tmp14 = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.ByteAddressBuffer)"(i32 14, %dx.types.Handle %tmp13, %dx.types.ResourceProperties { i32 11, i32 0 }, %struct.ByteAddressBuffer zeroinitializer), !dbg !32 ; line:49 col:5 + %tmp15 = load %struct.ByteAddressBuffer, %struct.ByteAddressBuffer* @"\01?bias_buffer@@3UByteAddressBuffer@@A", !dbg !32 ; line:49 col:5 + %tmp16 = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.ByteAddressBuffer)"(i32 0, %struct.ByteAddressBuffer %tmp15), !dbg !32 ; line:49 col:5 + %tmp17 = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.ByteAddressBuffer)"(i32 14, %dx.types.Handle 
%tmp16, %dx.types.ResourceProperties { i32 11, i32 0 }, %struct.ByteAddressBuffer zeroinitializer), !dbg !32 ; line:49 col:5 + + ;CHECK: %[[MCH1:[^ ]+]] = call %dx.types.Handle @dx.op.createHandleForLib.struct.ByteAddressBuffer(i32 160, %struct.ByteAddressBuffer %[[MLD]] + ;CHECK: %[[MAH1:[^ ]+]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle %[[MCH1]] + ;CHECK: %[[BCH1:[^ ]+]] = call %dx.types.Handle @dx.op.createHandleForLib.struct.ByteAddressBuffer(i32 160, %struct.ByteAddressBuffer %[[BLD]] + ;CHECK: %[[BAH1:[^ ]+]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle %[[BCH1]] + ;CHECK: call <4 x float> @dx.op.matVecMulAdd.v4f32.v4f32(i32 306, <4 x float> %{{[^ ]+}}, i1 false, i32 9, %dx.types.Handle %[[MAH1]], i32 0, i32 9, i32 4, i32 4, i32 0, i1 false, i32 64, %dx.types.Handle %[[BAH1]], i32 0, i32 9, i1 false) + call void @"dx.hl.op..void (i32, <4 x float>*, i1, <4 x float>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32, %dx.types.Handle, i32, i32)"(i32 391, <4 x float>* %output_vector, i1 false, <4 x float> %tmp4, i1 false, i32 9, %dx.types.Handle %tmp14, i32 0, i32 9, i32 4, i32 4, i32 0, i1 false, i32 64, %dx.types.Handle %tmp17, i32 0, i32 9), !dbg !32 ; line:49 col:5 + + %tmp18 = load <4 x float>, <4 x float>* %output_vector, align 4, !dbg !33, !tbaa !28 ; line:54 col:38 + %tmp19 = load %struct.RWByteAddressBuffer, %struct.RWByteAddressBuffer* @"\01?output_vector_buffer@@3URWByteAddressBuffer@@A", !dbg !34 ; line:54 col:5 + %tmp20 = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.RWByteAddressBuffer)"(i32 0, %struct.RWByteAddressBuffer %tmp19), !dbg !34 ; line:54 col:5 + %tmp21 = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.RWByteAddressBuffer)"(i32 14, %dx.types.Handle %tmp20, %dx.types.ResourceProperties { i32 4107, i32 0 }, %struct.RWByteAddressBuffer zeroinitializer), !dbg !34 ; 
line:54 col:5 + call void @"dx.hl.op..void (i32, %dx.types.Handle, i32, <4 x float>)"(i32 277, %dx.types.Handle %tmp21, i32 1024, <4 x float> %tmp18), !dbg !34 ; line:54 col:5 + %tmp22 = load %struct.ByteAddressBuffer, %struct.ByteAddressBuffer* @"\01?opa_input_buffer@@3UByteAddressBuffer@@A", !dbg !35 ; line:56 col:37 + %tmp23 = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.ByteAddressBuffer)"(i32 0, %struct.ByteAddressBuffer %tmp22), !dbg !35 ; line:56 col:37 + %tmp24 = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.ByteAddressBuffer)"(i32 14, %dx.types.Handle %tmp23, %dx.types.ResourceProperties { i32 11, i32 0 }, %struct.ByteAddressBuffer zeroinitializer), !dbg !35 ; line:56 col:37 + %tmp25 = call <8 x i32> @"dx.hl.op.ro.<8 x i32> (i32, %dx.types.Handle, i32)"(i32 231, %dx.types.Handle %tmp24, i32 0), !dbg !35 ; line:56 col:37 + %tmp26 = load %struct.ByteAddressBuffer, %struct.ByteAddressBuffer* @"\01?opa_input_buffer@@3UByteAddressBuffer@@A", !dbg !36 ; line:57 col:37 + %tmp27 = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.ByteAddressBuffer)"(i32 0, %struct.ByteAddressBuffer %tmp26), !dbg !36 ; line:57 col:37 + %tmp28 = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.ByteAddressBuffer)"(i32 14, %dx.types.Handle %tmp27, %dx.types.ResourceProperties { i32 11, i32 0 }, %struct.ByteAddressBuffer zeroinitializer), !dbg !36 ; line:57 col:37 + %tmp29 = call <8 x i32> @"dx.hl.op.ro.<8 x i32> (i32, %dx.types.Handle, i32)"(i32 231, %dx.types.Handle %tmp28, i32 128), !dbg !36 ; line:57 col:37 + %tmp30 = load %struct.RWByteAddressBuffer, %struct.RWByteAddressBuffer* @"\01?rw_matrix_buffer@@3URWByteAddressBuffer@@A", !dbg !37 ; line:67 col:5 + %tmp31 = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.RWByteAddressBuffer)"(i32 0, 
%struct.RWByteAddressBuffer %tmp30), !dbg !37 ; line:67 col:5 + %tmp32 = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.RWByteAddressBuffer)"(i32 14, %dx.types.Handle %tmp31, %dx.types.ResourceProperties { i32 4107, i32 0 }, %struct.RWByteAddressBuffer zeroinitializer), !dbg !37 ; line:67 col:5 + + ;CHECK: %[[RWMCH0:[^ ]+]] = call %dx.types.Handle @dx.op.createHandleForLib.struct.RWByteAddressBuffer(i32 160, %struct.RWByteAddressBuffer %[[RWMLD0]] + ;CHECK: %[[RWMAH0:[^ ]+]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle %[[RWMCH0]] + ;CHECK: call void @dx.op.outerProductAccumulate.v8i32.v8i32(i32 307, <8 x i32> %{{[^ ]+}}, <8 x i32> %{{[^ ]+}}, %dx.types.Handle %[[RWMAH0]], i32 0, i32 5, i32 3, i32 64) + call void @"dx.hl.op..void (i32, <8 x i32>, <8 x i32>, %dx.types.Handle, i32, i32, i32, i32)"(i32 392, <8 x i32> %tmp25, <8 x i32> %tmp29, %dx.types.Handle %tmp32, i32 0, i32 5, i32 3, i32 64), !dbg !37 ; line:67 col:5 + + + %tmp33 = load %struct.RWByteAddressBuffer, %struct.RWByteAddressBuffer* @"\01?rw_matrix_buffer@@3URWByteAddressBuffer@@A", !dbg !38 ; line:77 col:5 + %tmp34 = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.RWByteAddressBuffer)"(i32 0, %struct.RWByteAddressBuffer %tmp33), !dbg !38 ; line:77 col:5 + %tmp35 = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.RWByteAddressBuffer)"(i32 14, %dx.types.Handle %tmp34, %dx.types.ResourceProperties { i32 4107, i32 0 }, %struct.RWByteAddressBuffer zeroinitializer), !dbg !38 ; line:77 col:5 + + ;CHECK: %[[RWMCH1:[^ ]+]] = call %dx.types.Handle @dx.op.createHandleForLib.struct.RWByteAddressBuffer(i32 160, %struct.RWByteAddressBuffer %[[RWMLD0]] + ;CHECK: %[[RWMAH1:[^ ]+]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle %[[RWMCH1]] + ;CHECK: call void 
@dx.op.vectorAccumulate.v8i32(i32 308, <8 x i32> %{{[^ ]+}}, %dx.types.Handle %[[RWMAH1]], i32 0) + call void @"dx.hl.op..void (i32, <8 x i32>, %dx.types.Handle, i32)"(i32 393, <8 x i32> %tmp25, %dx.types.Handle %tmp35, i32 0), !dbg !38 ; line:77 col:5 + + %tmp36 = bitcast <4 x float>* %output_vector to i8*, !dbg !39 ; line:79 col:1 + call void @llvm.lifetime.end(i64 16, i8* %tmp36) #0, !dbg !39 ; line:79 col:1 + ret void, !dbg !39 ; line:79 col:1 +} + +; Function Attrs: nounwind +declare void @llvm.lifetime.start(i64, i8* nocapture) #0 + +; Function Attrs: nounwind +declare void @llvm.lifetime.end(i64, i8* nocapture) #0 + +; Function Attrs: nounwind readonly +declare <4 x float> @"dx.hl.op.ro.<4 x float> (i32, %dx.types.Handle, i32)"(i32, %dx.types.Handle, i32) #1 + +; Function Attrs: nounwind readnone +declare %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.ByteAddressBuffer)"(i32, %struct.ByteAddressBuffer) #2 + +; Function Attrs: nounwind readnone +declare %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.ByteAddressBuffer)"(i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.ByteAddressBuffer) #2 + +; Function Attrs: nounwind +declare void @"dx.hl.op..void (i32, <4 x float>*, i1, <4 x float>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32)"(i32, <4 x float>*, i1, <4 x float>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32) #0 + +; Function Attrs: nounwind +declare void @"dx.hl.op..void (i32, %dx.types.Handle, i32, <4 x float>)"(i32, %dx.types.Handle, i32, <4 x float>) #0 + +; Function Attrs: nounwind readnone +declare %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %struct.RWByteAddressBuffer)"(i32, %struct.RWByteAddressBuffer) #2 + +; Function Attrs: nounwind readnone +declare %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %struct.RWByteAddressBuffer)"(i32, 
%dx.types.Handle, %dx.types.ResourceProperties, %struct.RWByteAddressBuffer) #2 + +; Function Attrs: nounwind +declare void @"dx.hl.op..void (i32, <4 x float>*, i1, <4 x float>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32, %dx.types.Handle, i32, i32)"(i32, <4 x float>*, i1, <4 x float>, i1, i32, %dx.types.Handle, i32, i32, i32, i32, i32, i1, i32, %dx.types.Handle, i32, i32) #0 + +; Function Attrs: nounwind readonly +declare <8 x i32> @"dx.hl.op.ro.<8 x i32> (i32, %dx.types.Handle, i32)"(i32, %dx.types.Handle, i32) #1 + +; Function Attrs: nounwind +declare void @"dx.hl.op..void (i32, <8 x i32>, <8 x i32>, %dx.types.Handle, i32, i32, i32, i32)"(i32, <8 x i32>, <8 x i32>, %dx.types.Handle, i32, i32, i32, i32) #0 + +; Function Attrs: nounwind +declare void @"dx.hl.op..void (i32, <8 x i32>, %dx.types.Handle, i32)"(i32, <8 x i32>, %dx.types.Handle, i32) #0 + +attributes #0 = { nounwind } +attributes #1 = { nounwind readonly } +attributes #2 = { nounwind readnone } + +!llvm.module.flags = !{!0} +!pauseresume = !{!1} +!dx.version = !{!2} +!dx.valver = !{!2} +!dx.shaderModel = !{!3} +!dx.typeAnnotations = !{!4} +!dx.entryPoints = !{!8} +!dx.fnprops = !{!18} +!dx.options = !{!19, !20} + +!0 = !{i32 2, !"Debug Info Version", i32 3} +!1 = !{!"hlsl-hlemit", !"hlsl-hlensure"} +!2 = !{i32 1, i32 9} +!3 = !{!"cs", i32 6, i32 9} +!4 = !{i32 1, void ()* @cs_main, !5} +!5 = !{!6} +!6 = !{i32 1, !7, !7} +!7 = !{} +!8 = !{void ()* @cs_main, !"cs_main", null, !9, null} +!9 = !{!10, !15, null, null} +!10 = !{!11, !12, !13, !14} +!11 = !{i32 0, %struct.ByteAddressBuffer* @"\01?input_vector_buffer@@3UByteAddressBuffer@@A", !"input_vector_buffer", i32 -1, i32 -1, i32 1, i32 11, i32 0, null} +!12 = !{i32 1, %struct.ByteAddressBuffer* @"\01?opa_input_buffer@@3UByteAddressBuffer@@A", !"opa_input_buffer", i32 -1, i32 -1, i32 1, i32 11, i32 0, null} +!13 = !{i32 2, %struct.ByteAddressBuffer* @"\01?matrix_buffer@@3UByteAddressBuffer@@A", !"matrix_buffer", i32 -1, i32 -1, i32 1, 
i32 11, i32 0, null} +!14 = !{i32 3, %struct.ByteAddressBuffer* @"\01?bias_buffer@@3UByteAddressBuffer@@A", !"bias_buffer", i32 -1, i32 -1, i32 1, i32 11, i32 0, null} +!15 = !{!16, !17} +!16 = !{i32 0, %struct.RWByteAddressBuffer* @"\01?rw_matrix_buffer@@3URWByteAddressBuffer@@A", !"rw_matrix_buffer", i32 -1, i32 -1, i32 1, i32 11, i1 false, i1 false, i1 false, null} +!17 = !{i32 1, %struct.RWByteAddressBuffer* @"\01?output_vector_buffer@@3URWByteAddressBuffer@@A", !"output_vector_buffer", i32 -1, i32 -1, i32 1, i32 11, i1 false, i1 false, i1 false, null} +!18 = !{void ()* @cs_main, i32 5, i32 1, i32 1, i32 1} +!19 = !{i32 -2147483584} +!20 = !{i32 -1} +!21 = !DILocation(line: 14, column: 5, scope: !22) +!22 = !DISubprogram(name: "cs_main", scope: !23, file: !23, line: 12, type: !24, isLocal: false, isDefinition: true, scopeLine: 13, flags: DIFlagPrototyped, isOptimized: false, function: void ()* @cs_main) +!23 = !DIFile(filename: "DirectXShaderCompiler\5Ctools\5Cclang\5Ctest\5CCodeGenDXIL\5Chlsl\5Cintrinsics\5Clinalg_builtins\5Clinalg-builtins.hlsl", directory: "") +!24 = !DISubroutineType(types: !7) +!25 = !DILocation(line: 17, column: 37, scope: !22) +!26 = !DILocation(line: 33, column: 5, scope: !22) +!27 = !DILocation(line: 37, column: 35, scope: !22) +!28 = !{!29, !29, i64 0} +!29 = !{!"omnipotent char", !30, i64 0} +!30 = !{!"Simple C/C++ TBAA"} +!31 = !DILocation(line: 37, column: 5, scope: !22) +!32 = !DILocation(line: 49, column: 5, scope: !22) +!33 = !DILocation(line: 54, column: 38, scope: !22) +!34 = !DILocation(line: 54, column: 5, scope: !22) +!35 = !DILocation(line: 56, column: 37, scope: !22) +!36 = !DILocation(line: 57, column: 37, scope: !22) +!37 = !DILocation(line: 67, column: 5, scope: !22) +!38 = !DILocation(line: 77, column: 5, scope: !22) +!39 = !DILocation(line: 79, column: 1, scope: !22) diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/unavailable-pre-sm69.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/unavailable-pre-sm69.hlsl new file 
mode 100644 index 0000000000..d5e251ae8b --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/unavailable-pre-sm69.hlsl @@ -0,0 +1,59 @@ +// RUN: %dxc -T lib_6_8 %s -verify + +ByteAddressBuffer matrix_buffer; +ByteAddressBuffer bias_buffer; +RWByteAddressBuffer rw_matrix_buffer; + +[Shader("compute")] +[Numthreads(1,1,1)] +void cs_main() +{ + vector output_vector; + static const uint is_output_unsigned = 0; + + vector input_vector; + const uint is_input_unsigned = 0; + const uint input_interpretation = 9; /*F32*/ + + const uint matrix_offset = 0; + const uint matrix_interpretation = 9; /*F32*/ + const uint matrix_dimM = 4; + const uint matrix_dimK = 4; + const uint matrix_layout = 0; /*RowMajor*/ + const bool matrix_is_transposed = false; + const uint matrix_stride = 64; + + //expected-error@+1{{intrinsic __builtin_MatVecMul potentially used by 'cs_main' requires shader model 6.9 or greater}} + __builtin_MatVecMul(output_vector, is_output_unsigned, input_vector, + is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset, + matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout, + matrix_is_transposed, matrix_stride); + + const uint bias_offset = 0; + const uint bias_interpretation = 9; /*F32*/ + + //expected-error@+1{{intrinsic __builtin_MatVecMulAdd potentially used by 'cs_main' requires shader model 6.9 or greater}} + __builtin_MatVecMulAdd(output_vector, is_output_unsigned, input_vector, + is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset, + matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout, + matrix_is_transposed, matrix_stride, bias_buffer, bias_offset, + bias_interpretation); + + vector input_vector1; + vector input_vector2; + const uint opa_matrix_offset = 0; + const uint opa_matrix_interpretation = 5; /*U32*/ + const uint opa_matrix_layout = 3; /*OuterProductOptimal*/ + const uint opa_matrix_stride = 64; + + //expected-error@+1{{intrinsic __builtin_OuterProductAccumulate potentially used by 
'cs_main' requires shader model 6.9 or greater}} + __builtin_OuterProductAccumulate(input_vector1, input_vector2, + rw_matrix_buffer, opa_matrix_offset, opa_matrix_interpretation, + opa_matrix_layout, opa_matrix_stride); + + const uint va_matrix_offset = 0; + + //expected-error@+1{{intrinsic __builtin_VectorAccumulate potentially used by 'cs_main' requires shader model 6.9 or greater}} + __builtin_VectorAccumulate(input_vector1, rw_matrix_buffer, + va_matrix_offset); +} \ No newline at end of file diff --git a/utils/hct/gen_intrin_main.txt b/utils/hct/gen_intrin_main.txt index f1274fd308..c394611302 100644 --- a/utils/hct/gen_intrin_main.txt +++ b/utils/hct/gen_intrin_main.txt @@ -383,6 +383,14 @@ void [[]] Barrier(in NodeRecordOrUAV o, in uint SemanticFlags); uint [[]] GetRemainingRecursionLevels(); +void [[]] __builtin_MatVecMul(out numeric OutputVector, in bool OutputIsUnsigned, in numeric InputVector, in bool InputIsUnsigned, in uint InputInterpretation, in ByteAddressBuffer MatrixBuffer, in uint MatrixOffset, in uint MatrixInterpretation, in uint M, in uint K, in uint MatrixLayout, in bool MatrixIsTransposed, in uint MatrixStride); + +void [[]] __builtin_MatVecMulAdd(out numeric OutputVector, in bool OutputIsUnsigned, in numeric InputVector, in bool InputIsUnsigned, in uint InputInterpretation, in ByteAddressBuffer MatrixBuffer, in uint MatrixOffset, in uint MatrixInterpretation, in uint M, in uint K, in uint MatrixLayout, in bool MatrixIsTransposed, in uint MatrixStride, in ByteAddressBuffer BiasVector, in uint BiasOffset, in uint BiasInterpretation); + +void [[]] __builtin_OuterProductAccumulate(in numeric InputVector1, in numeric InputVector2, in RWByteAddressBuffer MatrixBuffer, in uint MatrixOffset, in uint MatrixInterpretation, in uint MatrixLayout, in uint MatrixStride); + +void [[]] __builtin_VectorAccumulate(in numeric InputVector, in RWByteAddressBuffer MatrixBuffer, in uint MatrixOffset); + } namespace diff --git a/utils/hct/hctdb.py 
b/utils/hct/hctdb.py index 6344fb5849..63af8c0b38 100644 --- a/utils/hct/hctdb.py +++ b/utils/hct/hctdb.py @@ -873,6 +873,11 @@ def populate_categories_and_models(self): "library", "raygeneration", ) + for i in ( + "MatVecMul,MatVecMulAdd,OuterProductAccumulate,VectorAccumulate" + ).split(","): + self.name_idx[i].category = "Linear Algebra Operations" + self.name_idx[i].shader_model = 6, 9 def populate_llvm_instructions(self): # Add instructions that map to LLVM instructions. @@ -6340,6 +6345,103 @@ def UFI(name, **mappings): ) next_op_idx += 1 + self.add_dxil_op( + "MatVecMul", + next_op_idx, + "MatVecMul", + "Multiplies a MxK dimension matrix and a K sized input vector", + " Date: Mon, 21 Apr 2025 17:15:09 -0700 Subject: [PATCH 09/31] Revert ADO pipelines to Ubuntu 22.04 temporarily (#7365) (#7366) DXC seems to be building incorrectly with GCC-13 and later, which is causing our pre-merge testing on 24.04 to fail. This will take some time to sort out, so in the meantime I'm reverting to 22.04 on our pipelines. (cherry picked from commit b4a3076caa92c4e9ed05761cbcd2141591fb3f89) Co-authored-by: Chris B --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 285fc4028a..8f07f59077 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -54,7 +54,7 @@ stages: variables: macOS: macOS-latest - linux: Ubuntu-latest + linux: Ubuntu-22.04 # FIXME: #7364, DXC does not build correctly with GCC 13+ strategy: matrix: From 86dd84d6a7d21f9a4b9d3c5653a040041c8b840f Mon Sep 17 00:00:00 2001 From: Alex Sepkowski <5620315+alsepkow@users.noreply.github.com> Date: Tue, 22 Apr 2025 12:18:34 -0700 Subject: [PATCH 10/31] Basic implementation of priority long vector exec tests. (#7320) This PR is a basic implementation of the priority long vector execution tests #7260. 
--- include/dxc/Test/HlslTestUtils.h | 50 +- tools/clang/unittests/HLSLExec/CMakeLists.txt | 7 + .../unittests/HLSLExec/ExecutionTest.cpp | 1216 +++++++++++++++++ .../unittests/HLSLExec/ShaderOpArith.xml | 76 ++ .../clang/unittests/HLSLExec/ShaderOpTest.cpp | 8 + 5 files changed, 1351 insertions(+), 6 deletions(-) diff --git a/include/dxc/Test/HlslTestUtils.h b/include/dxc/Test/HlslTestUtils.h index 0e37ccdcff..3b6f9d4ec4 100644 --- a/include/dxc/Test/HlslTestUtils.h +++ b/include/dxc/Test/HlslTestUtils.h @@ -258,6 +258,17 @@ inline void LogErrorFmt(const wchar_t *fmt, ...) { WEX::Logging::Log::Error(buf.data()); } +inline void LogErrorFmtThrow(const wchar_t *fmt, ...) { + va_list args; + va_start(args, fmt); + std::wstring buf(vFormatToWString(fmt, args)); + va_end(args); + WEX::Logging::Log::Error(buf.data()); + + // Throws an exception to abort the test. + VERIFY_FAIL(L"Test error"); +} + inline std::wstring GetPathToHlslDataFile(const wchar_t *relative, LPCWSTR paramName = HLSLDATAFILEPARAM, @@ -459,15 +470,17 @@ inline bool GetTestParamUseWARP(bool defaultVal) { #ifdef FP_SUBNORMAL -inline bool isdenorm(float f) { return FP_SUBNORMAL == std::fpclassify(f); } +template inline bool isdenorm(T f) { + return FP_SUBNORMAL == std::fpclassify(f); +} #else -inline bool isdenorm(float f) { - return (std::numeric_limits::denorm_min() <= f && - f < std::numeric_limits::min()) || - (-std::numeric_limits::min() < f && - f <= -std::numeric_limits::denorm_min()); +template inline bool isdenorm(T f) { + return (std::numeric_limits::denorm_min() <= f && + f < std::numeric_limits::min()) || + (-std::numeric_limits::min() < f && + f <= -std::numeric_limits::denorm_min()); } #endif // FP_SUBNORMAL @@ -515,6 +528,31 @@ inline bool isnanFloat16(uint16_t val) { uint16_t ConvertFloat32ToFloat16(float val) throw(); float ConvertFloat16ToFloat32(uint16_t val) throw(); +inline bool CompareDoubleULP( + const double &Src, const double &Ref, int64_t ULPTolerance, + 
hlsl::DXIL::Float32DenormMode Mode = hlsl::DXIL::Float32DenormMode::Any) { + if (Src == Ref) { + return true; + } + if (std::isnan(Src)) { + return std::isnan(Ref); + } + + if (Mode == hlsl::DXIL::Float32DenormMode::Any) { + // If denorm expected, output can be sign preserved zero. Otherwise output + // should pass the regular ulp testing. + if (isdenorm(Ref) && Src == 0 && std::signbit(Src) == std::signbit(Ref)) + return true; + } + + // For FTZ or Preserve mode, we should get the expected number within + // ULPTolerance for any operations. + int64_t Diff = *((const uint64_t *)&Src) - *((const uint64_t *)&Ref); + + uint64_t AbsoluteDiff = Diff < 0 ? -Diff : Diff; + return AbsoluteDiff <= (uint64_t)ULPTolerance; +} + inline bool CompareFloatULP( const float &fsrc, const float &fref, int ULPTolerance, hlsl::DXIL::Float32DenormMode mode = hlsl::DXIL::Float32DenormMode::Any) { diff --git a/tools/clang/unittests/HLSLExec/CMakeLists.txt b/tools/clang/unittests/HLSLExec/CMakeLists.txt index 3878fa3f34..c047a9be00 100644 --- a/tools/clang/unittests/HLSLExec/CMakeLists.txt +++ b/tools/clang/unittests/HLSLExec/CMakeLists.txt @@ -39,3 +39,10 @@ endif() file(TO_NATIVE_PATH "${CMAKE_CURRENT_SOURCE_DIR}" DOS_STYLE_SOURCE_DIR) file(TO_NATIVE_PATH "${TAEF_BIN_DIR}" DOS_TAEF_BIN_DIR) configure_file(ExecHLSLTests.vcxproj.user.txt ExecHLSLTests.vcxproj.user) + +# Copy the ShaderOpArith.xml file to the output directory. It's used by the exec +# tests and it's convenient to have it copied here if you want to easily copy +# the tests to another machine after building. 
+set(XML_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/ShaderOpArith.xml) +set(XML_DESTINATION ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}/bin) +file(COPY ${XML_SOURCE} DESTINATION ${XML_DESTINATION}) \ No newline at end of file diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 3aff8bcda8..ee67944950 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -35,6 +35,8 @@ #include #include #include +#include +#include #undef _read #include "dxc/Test/DxcTestUtils.h" @@ -67,6 +69,9 @@ #pragma comment(lib, "dxguid.lib") #pragma comment(lib, "version.lib") +// Float values for this were taken from Microsoft online documentation for the +// DirectX HALF data type. HALF is equivalent to IEEE 754 binary 16 format. + // A more recent Windows SDK than currently required is needed for these. typedef HRESULT(WINAPI *D3D12EnableExperimentalFeaturesFn)( UINT NumFeatures, __in_ecount(NumFeatures) const IID *pIIDs, @@ -274,6 +279,9 @@ typedef struct D3D12_FEATURE_DATA_D3D12_OPTIONS4 { // Virtual class to compute the expected result given a set of inputs struct TableParameter; +template struct LongVectorOpTestConfig; // Forward declaration +enum LongVectorOpType; // Forward declaration + class ExecutionTest { public: BEGIN_TEST_CLASS(ExecutionTest) @@ -501,6 +509,116 @@ class ExecutionTest { L"Table:ShaderOpArithTable.xml#PackUnpackOpTable") END_TEST_METHOD() + // bool binary ops + TEST_METHOD(LongVector_ScalarAdd_bool) + TEST_METHOD(LongVector_ScalarMultiply_bool) + TEST_METHOD(LongVector_Multiply_bool) + TEST_METHOD(LongVector_Add_bool) + TEST_METHOD(LongVector_Min_bool) + TEST_METHOD(LongVector_Max_bool) + // bool unary ops + // Note that clamp doesn't make sense for bools. 
+ TEST_METHOD(LongVector_Initialize_bool); + + // float16 (half) binary ops + TEST_METHOD(LongVector_ScalarAdd_float16) + TEST_METHOD(LongVector_ScalarMultiply_float16) + TEST_METHOD(LongVector_Multiply_float16) + TEST_METHOD(LongVector_Add_float16) + TEST_METHOD(LongVector_Min_float16) + TEST_METHOD(LongVector_Max_float16) + // float16 (half) unary ops + TEST_METHOD(LongVector_Clamp_float16); + TEST_METHOD(LongVector_Initialize_float16); + + // float32 binary ops + TEST_METHOD(LongVector_ScalarAdd_float32) + TEST_METHOD(LongVector_ScalarMultiply_float32) + TEST_METHOD(LongVector_Multiply_float32) + TEST_METHOD(LongVector_Add_float32) + TEST_METHOD(LongVector_Min_float32) + TEST_METHOD(LongVector_Max_float32) + // float32 unary ops + TEST_METHOD(LongVector_Clamp_float32); + TEST_METHOD(LongVector_Initialize_float32); + + // float64 binary ops + TEST_METHOD(LongVector_ScalarAdd_float64) + TEST_METHOD(LongVector_ScalarMultiply_float64) + TEST_METHOD(LongVector_Multiply_float64) + TEST_METHOD(LongVector_Add_float64) + TEST_METHOD(LongVector_Min_float64) + TEST_METHOD(LongVector_Max_float64) + // float64 unary ops + TEST_METHOD(LongVector_Clamp_float64); + TEST_METHOD(LongVector_Initialize_float64); + + // int16 binary ops + TEST_METHOD(LongVector_ScalarAdd_int16) + TEST_METHOD(LongVector_ScalarMultiply_int16) + TEST_METHOD(LongVector_Multiply_int16) + TEST_METHOD(LongVector_Add_int16) + TEST_METHOD(LongVector_Min_int16) + TEST_METHOD(LongVector_Max_int16) + // int16 unary ops + TEST_METHOD(LongVector_Clamp_int16); + TEST_METHOD(LongVector_Initialize_int16); + + // int32 binary ops + TEST_METHOD(LongVector_ScalarAdd_int32) + TEST_METHOD(LongVector_ScalarMultiply_int32) + TEST_METHOD(LongVector_Multiply_int32) + TEST_METHOD(LongVector_Add_int32) + TEST_METHOD(LongVector_Min_int32) + TEST_METHOD(LongVector_Max_int32) + // int32 unary ops + TEST_METHOD(LongVector_Clamp_int32); + TEST_METHOD(LongVector_Initialize_int32); + + // int64 binary ops + 
TEST_METHOD(LongVector_ScalarAdd_int64) + TEST_METHOD(LongVector_ScalarMultiply_int64) + TEST_METHOD(LongVector_Multiply_int64) + TEST_METHOD(LongVector_Add_int64) + TEST_METHOD(LongVector_Min_int64) + TEST_METHOD(LongVector_Max_int64) + // int64 unary ops + TEST_METHOD(LongVector_Clamp_int64); + TEST_METHOD(LongVector_Initialize_int64); + + // uint16 binary ops + TEST_METHOD(LongVector_ScalarAdd_uint16) + TEST_METHOD(LongVector_ScalarMultiply_uint16) + TEST_METHOD(LongVector_Multiply_uint16) + TEST_METHOD(LongVector_Add_uint16) + TEST_METHOD(LongVector_Min_uint16) + TEST_METHOD(LongVector_Max_uint16) + // uint16 unary ops + TEST_METHOD(LongVector_Clamp_uint16); + TEST_METHOD(LongVector_Initialize_uint16); + + // uint32 binary ops + TEST_METHOD(LongVector_ScalarAdd_uint32) + TEST_METHOD(LongVector_ScalarMultiply_uint32) + TEST_METHOD(LongVector_Multiply_uint32) + TEST_METHOD(LongVector_Add_uint32) + TEST_METHOD(LongVector_Min_uint32) + TEST_METHOD(LongVector_Max_uint32) + // uint32 unary ops + TEST_METHOD(LongVector_Clamp_uint32); + TEST_METHOD(LongVector_Initialize_uint32); + + // uint64 binary ops + TEST_METHOD(LongVector_ScalarAdd_uint64) + TEST_METHOD(LongVector_ScalarMultiply_uint64) + TEST_METHOD(LongVector_Multiply_uint64) + TEST_METHOD(LongVector_Add_uint64) + TEST_METHOD(LongVector_Min_uint64) + TEST_METHOD(LongVector_Max_uint64) + // uint64 unary ops + TEST_METHOD(LongVector_Clamp_uint64); + TEST_METHOD(LongVector_Initialize_uint64); + dxc::DxcDllSupport m_support; bool m_D3DInitCompleted = false; @@ -710,6 +828,10 @@ class ExecutionTest { const char *pShaderModelStr, const char *pShader, Ty *pInputDataPairs, unsigned inputDataCount); + template + void LongVectorOpTestBase(LongVectorOpTestConfig &TestConfig); + template void LongVectorOpTestBase(LongVectorOpType OpType); + template const wchar_t *BasicShaderModelTest_GetFormatString(); void CompileFromText(LPCSTR pText, LPCWSTR pEntryPoint, @@ -11096,6 +11218,1100 @@ TEST_F(ExecutionTest, PackUnpackTest) 
{ } } +// A helper struct because C++ bools are 1 byte and HLSL bools are 4 bytes. +// Take int32_t as a constuctor argument and convert it to bool when needed. +// Comparisons cast to a bool because we only care if the bool representation is +// true or false. +struct HLSLBool_t { + HLSLBool_t() : val(0) {} + HLSLBool_t(int32_t val) : val(val) {} + HLSLBool_t(bool val) : val(val) {} + HLSLBool_t(const HLSLBool_t &other) : val(other.val) {} + + bool operator==(const HLSLBool_t &other) const { + return static_cast(val) == static_cast(other.val); + } + + bool operator!=(const HLSLBool_t &other) const { + return static_cast(val) != static_cast(other.val); + } + + bool operator<(const HLSLBool_t &other) const { return val < other.val; } + + bool operator>(const HLSLBool_t &other) const { return val > other.val; } + + bool operator<=(const HLSLBool_t &other) const { return val <= other.val; } + + bool operator>=(const HLSLBool_t &other) const { return val >= other.val; } + + HLSLBool_t operator*(const HLSLBool_t &other) const { + return HLSLBool_t(val * other.val); + } + + HLSLBool_t operator+(const HLSLBool_t &other) const { + return HLSLBool_t(val + other.val); + } + + // So we can construct std::wstrings using std::wostream + friend std::wostream &operator<<(std::wostream &os, const HLSLBool_t &obj) { + os << static_cast(obj.val); + return os; + } + + // So we can construct std::strings using std::ostream + friend std::ostream &operator<<(std::ostream &os, const HLSLBool_t &obj) { + os << static_cast(obj.val); + return os; + } + + int32_t val = 0; +}; + +// No native float16 type in C++ until C++23 . So we use uint16_t to represent +// it. Simple little wrapping struct to help handle the right behavior. 
+struct HLSLHalf_t { + HLSLHalf_t() : val(0) {} + HLSLHalf_t(DirectX::PackedVector::HALF val) : val(val) {} + HLSLHalf_t(const HLSLHalf_t &other) : val(other.val) {} + + bool operator==(const HLSLHalf_t &other) const { return val == other.val; } + + bool operator<(const HLSLHalf_t &other) const { + return DirectX::PackedVector::XMConvertHalfToFloat(val) < + DirectX::PackedVector::XMConvertHalfToFloat(other.val); + } + + bool operator>(const HLSLHalf_t &other) const { + return DirectX::PackedVector::XMConvertHalfToFloat(val) > + DirectX::PackedVector::XMConvertHalfToFloat(other.val); + } + + // Used by tolerance checks in the tests. + bool operator>(float d) const { + float a = DirectX::PackedVector::XMConvertHalfToFloat(val); + return a > d; + } + + bool operator<(float d) const { + float a = DirectX::PackedVector::XMConvertHalfToFloat(val); + return a < d; + } + + bool operator<=(const HLSLHalf_t &other) const { + return DirectX::PackedVector::XMConvertHalfToFloat(val) <= + DirectX::PackedVector::XMConvertHalfToFloat(other.val); + } + + bool operator>=(const HLSLHalf_t &other) const { + return DirectX::PackedVector::XMConvertHalfToFloat(val) >= + DirectX::PackedVector::XMConvertHalfToFloat(other.val); + } + + bool operator!=(const HLSLHalf_t &other) const { return val != other.val; } + + HLSLHalf_t operator*(const HLSLHalf_t &other) const { + float a = DirectX::PackedVector::XMConvertHalfToFloat(val); + float b = DirectX::PackedVector::XMConvertHalfToFloat(other.val); + return HLSLHalf_t(DirectX::PackedVector::XMConvertFloatToHalf(a * b)); + } + + HLSLHalf_t operator+(const HLSLHalf_t &other) const { + float a = DirectX::PackedVector::XMConvertHalfToFloat(val); + float b = DirectX::PackedVector::XMConvertHalfToFloat(other.val); + return HLSLHalf_t(DirectX::PackedVector::XMConvertFloatToHalf(a + b)); + } + + HLSLHalf_t operator-(const HLSLHalf_t &other) const { + float a = DirectX::PackedVector::XMConvertHalfToFloat(val); + float b = 
DirectX::PackedVector::XMConvertHalfToFloat(other.val); + return HLSLHalf_t(DirectX::PackedVector::XMConvertFloatToHalf(a - b)); + } + + // So we can construct std::wstrings using std::wostream + friend std::wostream &operator<<(std::wostream &os, const HLSLHalf_t &obj) { + os << DirectX::PackedVector::XMConvertHalfToFloat(obj.val); + return os; + } + + // So we can construct std::wstrings using std::wostream + friend std::ostream &operator<<(std::ostream &os, const HLSLHalf_t &obj) { + os << DirectX::PackedVector::XMConvertHalfToFloat(obj.val); + return os; + } + + // HALF is an alias to uint16_t + DirectX::PackedVector::HALF val = 0; +}; + +// Helper to fill the shader buffer based on type. Convenient to be used when +// copying HLSL*_t types so we can copy the underlying type directly instead of +// the struct. +template +void FillShaderBufferFromLongVectorData(std::vector &ShaderBuffer, + std::array &TestData) { + + // Note: DataSize for HLSLHalf_t and HLSLBool_t may be larger than the + // underlying type in some cases. Thats fine. Resize just makes sure we have + // enough space. + const size_t DataSize = sizeof(T) * N; + ShaderBuffer.resize(DataSize); + + if constexpr (std::is_same_v) { + DirectX::PackedVector::HALF *ShaderBufferPtr = + reinterpret_cast(ShaderBuffer.data()); + for (size_t i = 0; i < N; ++i) { + ShaderBufferPtr[i] = TestData[i].val; + } + } else if constexpr (std::is_same_v) { + int32_t *ShaderBufferPtr = reinterpret_cast(ShaderBuffer.data()); + for (size_t i = 0; i < N; ++i) { + ShaderBufferPtr[i] = TestData[i].val; + } + } else { + T *ShaderBufferPtr = reinterpret_cast(ShaderBuffer.data()); + for (size_t i = 0; i < N; ++i) { + ShaderBufferPtr[i] = TestData[i]; + } + } +} + +// Helper to fill the test data from the shader buffer based on type. Convenient +// to be used when copying HLSL*_t types so we can use the underlying type. 
+template +void FillLongVectorDataFromShaderBuffer(MappedData &ShaderBuffer, + std::array &TestData) { + + if constexpr (std::is_same_v) { + DirectX::PackedVector::HALF *ShaderBufferPtr = + reinterpret_cast(ShaderBuffer.data()); + for (size_t i = 0; i < N; ++i) { + // HLSLHalf_t has a DirectX::PackedVector::HALF based constructor. + TestData[i] = ShaderBufferPtr[i]; + } + } else if constexpr (std::is_same_v) { + int32_t *ShaderBufferPtr = reinterpret_cast(ShaderBuffer.data()); + for (size_t i = 0; i < N; ++i) { + // HLSLBool_t has a int32_t based constructor. + TestData[i] = ShaderBufferPtr[i]; + } + } else { + T *ShaderBufferPtr = reinterpret_cast(ShaderBuffer.data()); + for (size_t i = 0; i < N; ++i) { + TestData[i] = ShaderBufferPtr[i]; + } + } +} + +enum LongVectorOpType { + LongVectorOpType_ScalarAdd, + LongVectorOpType_ScalarMultiply, + LongVectorOpType_Multiply, + LongVectorOpType_Add, + LongVectorOpType_Min, + LongVectorOpType_Max, + LongVectorOpType_Clamp, + LongVectorOpType_Initialize, + LongVectorOpType_UnInitialized +}; + +// Used to pass into LongVectorOpTestBase +template struct LongVectorOpTestConfig { + LongVectorOpTestConfig() = default; + + LongVectorOpTestConfig(LongVectorOpType OpType) : OpType(OpType) { + IntrinsicString = ""; + + if (IsFloatingPointType()) + Tolerance = 1; + + switch (OpType) { + case LongVectorOpType_ScalarAdd: + OperatorString = "+"; + IsScalarOp = true; + break; + case LongVectorOpType_ScalarMultiply: + OperatorString = "*"; + IsScalarOp = true; + break; + case LongVectorOpType_Multiply: + OperatorString = "*"; + break; + case LongVectorOpType_Add: + OperatorString = "+"; + break; + case LongVectorOpType_Min: + OperatorString = ","; + IntrinsicString = "min"; + break; + case LongVectorOpType_Max: + OperatorString = ","; + IntrinsicString = "max"; + break; + case LongVectorOpType_Clamp: + OperatorString = ","; + IntrinsicString = "TestClamp"; + IsBinaryOp = false; + break; + case LongVectorOpType_Initialize: + 
IntrinsicString = "TestInitialize"; + IsBinaryOp = false; + break; + default: + VERIFY_FAIL("Invalid LongVectorOpType"); + } + } + + bool IsFloatingPointType() const { + return std::is_same_v || std::is_same_v || + std::is_same_v; + } + + // A helper to get the hlsl type as a string for a given C++ type. + // Used in the long vector tests. + std::string GetHLSLTypeString() { + if (std::is_same_v) + return "bool"; + if (std::is_same_v) + return "half"; + if (std::is_same_v) + return "float"; + if (std::is_same_v) + return "double"; + if (std::is_same_v) + return "int16_t"; + if (std::is_same_v) + return "int"; + if (std::is_same_v) + return "int64_t"; + if (std::is_same_v) + return "uint16_t"; + if (std::is_same_v) + return "uint32_t"; + if (std::is_same_v) + return "uint64_t"; + + std::string ErrStr("GetHLSLTypeString() Unsupported type: "); + ErrStr.append(typeid(T).name()); + VERIFY_IS_TRUE(false, ErrStr.c_str()); + return "UnknownType"; + } + + // To be used for the value of -DOPERATOR + std::string OperatorString; + // To be used for the value of -DFUNC + std::string IntrinsicString; + // Optional, can be used to override shader code. + bool IsScalarOp = false; + bool IsBinaryOp = true; + float Tolerance = 0.0; + LongVectorOpType OpType = LongVectorOpType_UnInitialized; +}; + +template struct LongVectorTestTraits { + std::uniform_int_distribution UD = std::uniform_int_distribution( + std::numeric_limits::min(), std::numeric_limits::max()); +}; + +template <> struct LongVectorTestTraits { + // Float values for this were taken from Microsoft online documentation for + // the DirectX HALF data type. HALF is equivalent to IEEE 754 binary 16 + // format. 
+ std::uniform_int_distribution UD = + std::uniform_int_distribution( + DirectX::PackedVector::XMConvertFloatToHalf(float(6.10e-5f)), + DirectX::PackedVector::XMConvertFloatToHalf(float(65504.0f))); +}; + +template <> struct LongVectorTestTraits { + std::uniform_int_distribution UD = + std::uniform_int_distribution(0u, 1u); +}; + +template <> struct LongVectorTestTraits { + // The ranges for generation. A std::uniform_real_distribution can only + // have a range that is equal to the types largest value. This is due to + // precision issues. So instead we define some large values. + std::uniform_real_distribution UD = + std::uniform_real_distribution(-1e20f, 1e20f); +}; + +template <> struct LongVectorTestTraits { + // The ranges for generation. A std::uniform_real_distribution can only + // have a range that is equal to the types largest value. This is due to + // precision issues. So instead we define some large values. + std::uniform_real_distribution UD = + std::uniform_real_distribution(-1e100, 1e100); +}; + +template class DeterministicNumberGenerator { + // Mersenne Twister 'random' number generator. Generated numbers are based + // on the seed value and are deterministic for any given seed. + std::mt19937 Generator; + + LongVectorTestTraits UD; + +public: + DeterministicNumberGenerator(unsigned SeedValue) : Generator(SeedValue) {} + + T generate() { return UD.UD(Generator); } +}; + +template +bool DoArraysMatch(const std::array &ActualValues, + const std::array &ExpectedValues, float Tolerance) { + // Stash mismatched indexes for easy failure logging later + std::vector MismatchedIndexes; + for (size_t Index = 0; Index < N; ++Index) { + if constexpr (std::is_same_v) { + // Compiler was very picky and wanted an explicit case for any T that + // doesn't implement the operators in the below else. ( > and -). It + // wouldn't accept putting this constexpr as an or case with other + // statements. 
+ if (ActualValues[Index] != ExpectedValues[Index]) { + MismatchedIndexes.push_back(Index); + } + } else if constexpr (std::is_same_v) { + const DirectX::PackedVector::HALF a = ActualValues[Index].val; + const DirectX::PackedVector::HALF b = ExpectedValues[Index].val; + if (!CompareHalfULP(a, b, Tolerance)) { + MismatchedIndexes.push_back(Index); + } + } else if constexpr (std::is_same_v) { + const int IntTolerance = static_cast(Tolerance); + if (!CompareFloatULP(ActualValues[Index], ExpectedValues[Index], + IntTolerance)) { + MismatchedIndexes.push_back(Index); + } + } else if constexpr (std::is_same_v) { + const int64_t IntTolerance = static_cast(Tolerance); + if (!CompareDoubleULP(ActualValues[Index], ExpectedValues[Index], + IntTolerance)) { + MismatchedIndexes.push_back(Index); + } + } else if (Tolerance == 0 && ActualValues[Index] != ExpectedValues[Index]) { + MismatchedIndexes.push_back(Index); + } else { + T Diff = ActualValues[Index] > ExpectedValues[Index] + ? ActualValues[Index] - ExpectedValues[Index] + : ExpectedValues[Index] - ActualValues[Index]; + if (Diff > Tolerance) { + MismatchedIndexes.push_back(Index); + } + } + } + + if (MismatchedIndexes.empty()) + return true; + + if (!MismatchedIndexes.empty()) { + for (size_t Index : MismatchedIndexes) { + std::wstringstream Wss(L""); + Wss << L"Mismatch at Index: " << Index; + Wss << L" Actual Value:" << ActualValues[Index] << ","; + Wss << L" Expected Value:" << ExpectedValues[Index]; + WEX::Logging::Log::Error(Wss.str().c_str()); + } + } + + return false; +} + +TEST_F(ExecutionTest, LongVector_ScalarAdd_bool) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarAdd); +} + +TEST_F(ExecutionTest, LongVector_ScalarMultiply_bool) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + 
LongVectorOpTestBase(LongVectorOpType_ScalarMultiply); +} + +TEST_F(ExecutionTest, LongVector_Multiply_bool) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Multiply); +} + +TEST_F(ExecutionTest, LongVector_Add_bool) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Add); +} + +TEST_F(ExecutionTest, LongVector_Min_bool) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Min); +} + +TEST_F(ExecutionTest, LongVector_Max_bool) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Max); +} + +TEST_F(ExecutionTest, LongVector_ScalarAdd_float16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarAdd); +} + +TEST_F(ExecutionTest, LongVector_ScalarMultiply_float16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarMultiply); +} + +TEST_F(ExecutionTest, LongVector_Multiply_float16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Multiply); +} + +TEST_F(ExecutionTest, LongVector_Add_float16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Add); +} + +TEST_F(ExecutionTest, LongVector_Min_float16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + 
LongVectorOpTestBase(LongVectorOpType_Min); +} + +TEST_F(ExecutionTest, LongVector_Max_float16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Max); +} + +TEST_F(ExecutionTest, LongVector_ScalarAdd_float32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarAdd); +} + +TEST_F(ExecutionTest, LongVector_ScalarMultiply_float32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarMultiply); +} + +TEST_F(ExecutionTest, LongVector_Multiply_float32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Multiply); +} + +TEST_F(ExecutionTest, LongVector_Add_float32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Add); +} + +TEST_F(ExecutionTest, LongVector_Min_float32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Min); +} + +TEST_F(ExecutionTest, LongVector_Max_float32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Max); +} + +TEST_F(ExecutionTest, LongVector_ScalarAdd_float64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarAdd); +} + +TEST_F(ExecutionTest, LongVector_ScalarMultiply_float64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + 
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarMultiply); +} + +TEST_F(ExecutionTest, LongVector_Multiply_float64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Multiply); +} + +TEST_F(ExecutionTest, LongVector_Add_float64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Add); +} + +TEST_F(ExecutionTest, LongVector_Min_float64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Min); +} + +TEST_F(ExecutionTest, LongVector_Max_float64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Max); +} + +TEST_F(ExecutionTest, LongVector_ScalarAdd_int16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarAdd); +} + +TEST_F(ExecutionTest, LongVector_ScalarMultiply_int16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarMultiply); +} + +TEST_F(ExecutionTest, LongVector_Multiply_int16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Multiply); +} + +TEST_F(ExecutionTest, LongVector_Add_int16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Add); +} + +TEST_F(ExecutionTest, LongVector_Min_int16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + 
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Min); +} + +TEST_F(ExecutionTest, LongVector_Max_int16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Max); +} + +TEST_F(ExecutionTest, LongVector_ScalarAdd_int32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarAdd); +} + +TEST_F(ExecutionTest, LongVector_ScalarMultiply_int32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarMultiply); +} + +TEST_F(ExecutionTest, LongVector_Multiply_int32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Multiply); +} + +TEST_F(ExecutionTest, LongVector_Add_int32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Add); +} + +TEST_F(ExecutionTest, LongVector_Min_int32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Min); +} + +TEST_F(ExecutionTest, LongVector_Max_int32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Max); +} + +TEST_F(ExecutionTest, LongVector_ScalarAdd_int64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarAdd); +} + +TEST_F(ExecutionTest, LongVector_ScalarMultiply_int64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + 
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarMultiply); +} + +TEST_F(ExecutionTest, LongVector_Multiply_int64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Multiply); +} + +TEST_F(ExecutionTest, LongVector_Add_int64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Add); +} + +TEST_F(ExecutionTest, LongVector_Min_int64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Min); +} + +TEST_F(ExecutionTest, LongVector_Max_int64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Max); +} + +TEST_F(ExecutionTest, LongVector_ScalarAdd_uint16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarAdd); +} + +TEST_F(ExecutionTest, LongVector_ScalarMultiply_uint16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarMultiply); +} + +TEST_F(ExecutionTest, LongVector_Multiply_uint16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Multiply); +} + +TEST_F(ExecutionTest, LongVector_Add_uint16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Add); +} + +TEST_F(ExecutionTest, LongVector_Min_uint16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + 
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Min); +} + +TEST_F(ExecutionTest, LongVector_Max_uint16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Max); +} + +TEST_F(ExecutionTest, LongVector_ScalarAdd_uint32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarAdd); +} + +TEST_F(ExecutionTest, LongVector_ScalarMultiply_uint32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarMultiply); +} + +TEST_F(ExecutionTest, LongVector_Multiply_uint32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Multiply); +} + +TEST_F(ExecutionTest, LongVector_Add_uint32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Add); +} + +TEST_F(ExecutionTest, LongVector_Min_uint32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Min); +} + +TEST_F(ExecutionTest, LongVector_Max_uint32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Max); +} + +TEST_F(ExecutionTest, LongVector_ScalarAdd_uint64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarAdd); +} + +TEST_F(ExecutionTest, LongVector_ScalarMultiply_uint64) { + WEX::TestExecution::SetVerifyOutput verifySettings( 
+ WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_ScalarMultiply); +} + +TEST_F(ExecutionTest, LongVector_Multiply_uint64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Multiply); +} + +TEST_F(ExecutionTest, LongVector_Add_uint64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Add); +} + +TEST_F(ExecutionTest, LongVector_Min_uint64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Min); +} + +TEST_F(ExecutionTest, LongVector_Max_uint64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Max); +} + +TEST_F(ExecutionTest, LongVector_Initialize_bool) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Initialize); +} + +TEST_F(ExecutionTest, LongVector_Clamp_float16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Clamp); +} + +TEST_F(ExecutionTest, LongVector_Initialize_float16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Initialize); +} + +TEST_F(ExecutionTest, LongVector_Clamp_float32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Clamp); +} + +TEST_F(ExecutionTest, LongVector_Initialize_float32) { + WEX::TestExecution::SetVerifyOutput 
verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Initialize); +} + +TEST_F(ExecutionTest, LongVector_Clamp_float64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Clamp); +} + +TEST_F(ExecutionTest, LongVector_Initialize_float64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Initialize); +} + +TEST_F(ExecutionTest, LongVector_Clamp_int16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Clamp); +} + +TEST_F(ExecutionTest, LongVector_Initialize_int16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Initialize); +} + +TEST_F(ExecutionTest, LongVector_Clamp_int32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Clamp); +} + +TEST_F(ExecutionTest, LongVector_Initialize_int32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Initialize); +} + +TEST_F(ExecutionTest, LongVector_Clamp_int64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Clamp); +} + +TEST_F(ExecutionTest, LongVector_Initialize_int64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Initialize); +} + +TEST_F(ExecutionTest, LongVector_Clamp_uint16) { + 
WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Clamp); +} + +TEST_F(ExecutionTest, LongVector_Initialize_uint16) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Initialize); +} + +TEST_F(ExecutionTest, LongVector_Clamp_uint32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Clamp); +} + +TEST_F(ExecutionTest, LongVector_Initialize_uint32) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Initialize); +} + +TEST_F(ExecutionTest, LongVector_Clamp_uint64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Clamp); +} + +TEST_F(ExecutionTest, LongVector_Initialize_uint64) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + LongVectorOpTestBase(LongVectorOpType_Initialize); +} + +template +void ExecutionTest::LongVectorOpTestBase(LongVectorOpType opType) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + + LongVectorOpTestConfig TestConfig(opType); + + LongVectorOpTestBase(TestConfig); + LongVectorOpTestBase(TestConfig); + LongVectorOpTestBase(TestConfig); + LongVectorOpTestBase(TestConfig); + LongVectorOpTestBase(TestConfig); + LongVectorOpTestBase(TestConfig); + LongVectorOpTestBase(TestConfig); + LongVectorOpTestBase(TestConfig); + LongVectorOpTestBase(TestConfig); +} + +template +void ExecutionTest::LongVectorOpTestBase( + LongVectorOpTestConfig &TestConfig) { + WEX::TestExecution::SetVerifyOutput 
verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + + LogCommentFmt(L"Running LongVectorOpTestBase<%S, %zu>", typeid(T).name(), N); + + CComPtr D3DDevice; + if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { +#ifdef _HLK_CONF + LogErrorFmtThrow(L"Device does not support SM 6.9. Can't run these tests."); + } +#else + WEX::Logging::Log::Comment( + "Device does not support SM 6.9. Can't run these tests."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; +#endif +} + +DeterministicNumberGenerator NumberGenerator(1337); +std::array InputVector1; +std::array InputVector2; +std::array ScalarInput; +ScalarInput[0] = NumberGenerator.generate(); +const bool IsVectorBinaryOp = TestConfig.IsBinaryOp && !TestConfig.IsScalarOp; + +// Fill the vector inputs with values. +for (size_t Index = 0; Index < N; Index++) { + // Always generate input. + InputVector1[Index] = NumberGenerator.generate(); + + if (IsVectorBinaryOp) + InputVector2[Index] = NumberGenerator.generate(); +} + +// We pass these values into the shader and they're required to compile. So +// they need to be set to something. +T ClampArgMin = 0; +T ClampArgMax = 0; +if (TestConfig.OpType == LongVectorOpType_Clamp) { + if constexpr (std::is_same_v) { + // Attempting to generate a clamp value for HLSLBool_t will result in an + // infinite loop in the below while. We don't have a test case for clamp + // with bools anyways. But adding this check to prevent the mistake. + LogErrorFmtThrow(L"Clamp is not supported for HLSLBool_t."); + } + + ClampArgMin = NumberGenerator.generate(); + ClampArgMax = NumberGenerator.generate(); + while (ClampArgMin >= ClampArgMax) { + // Generate a new value for ClampArgMax. It needs to be strictly + // greater than ClampArgMin. 
+ ClampArgMax = NumberGenerator.generate(); + } +} + +std::array ExpectedVector; +for (size_t Index = 0; Index < N; Index++) { + if (TestConfig.IsBinaryOp) { + T Input1 = InputVector1[Index]; + T Input2 = TestConfig.IsScalarOp ? ScalarInput[0] : InputVector2[Index]; + if (TestConfig.OperatorString == "*") { + ExpectedVector[Index] = Input1 * Input2; + } else if (TestConfig.OperatorString == "+") { + ExpectedVector[Index] = Input1 + Input2; + } else if (TestConfig.OperatorString == ",") { + if (TestConfig.OpType == LongVectorOpType_Min) + ExpectedVector[Index] = std::min(Input1, Input2); + else if (TestConfig.OpType == LongVectorOpType_Max) + ExpectedVector[Index] = std::max(Input1, Input2); + else + LogErrorFmtThrow(L"Unrecognized Binary LongVectorOpType: %d", + TestConfig.OpType); + } else { + LogErrorFmtThrow( + L"Don't know how to compute expected value for operatorString: %s", + TestConfig.OperatorString.c_str()); + } + } else // Unary op logic + { + if (TestConfig.OpType == LongVectorOpType_Clamp) { + ExpectedVector[Index] = + std::clamp(InputVector1[Index], ClampArgMin, ClampArgMax); + } else if (TestConfig.OpType = LongVectorOpType_Initialize) { + ExpectedVector[Index] = InputVector1[Index]; + } else { + LogErrorFmtThrow(L"Unrecognized Unary LongVectorOpType: %d", + TestConfig.OpType); + } + } +} + +// Set up the compiler options string. +std::stringstream CompilerOptions(""); +std::string HLSLType = TestConfig.GetHLSLTypeString(); +CompilerOptions << "-DTYPE="; +CompilerOptions << HLSLType; +CompilerOptions << " -DNUM="; +CompilerOptions << N; +const bool Is16BitType = + (HLSLType == "int16_t" || HLSLType == "uint16_t" || HLSLType == "half"); +CompilerOptions << (Is16BitType ? " -enable-16bit-types" : ""); +CompilerOptions << " -DOPERATOR="; +CompilerOptions << TestConfig.OperatorString; +if (TestConfig.IsBinaryOp) { + CompilerOptions << " -DOPERAND2="; + CompilerOptions << (TestConfig.IsScalarOp ? 
"InputScalar" : "InputVector2"); + + if (TestConfig.IsScalarOp) { + CompilerOptions << " -DIS_SCALAR_OP=1"; + } else { + CompilerOptions << " -DIS_BINARY_VECTOR_OP=1"; + } + CompilerOptions << " -DFUNC="; + CompilerOptions << TestConfig.IntrinsicString; +} else { + CompilerOptions << " -DFUNC="; + CompilerOptions << TestConfig.IntrinsicString; + CompilerOptions << " -DOPERAND2="; + switch (TestConfig.OpType) { + case LongVectorOpType_Clamp: + CompilerOptions << "ClampArgMinMax"; + CompilerOptions << " -DFUNC_CLAMP=1"; + break; + case LongVectorOpType_Initialize: + CompilerOptions << " -DFUNC_INITIALIZE=1"; + break; + } +} + +// We have to construct the string outside of the lambda. Otherwise it's +// cleaned up when the lambda finishes executing but before the shader runs. +std::string CompilerOptionsString = CompilerOptions.str(); + +// ShaderOpArith.xml defines the input/output resources and the shader source. +CComPtr TestXML; +ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &TestXML); + +// RunShaderOpTest is a helper function that handles resource creation +// and setup. It also handles the shader compilation and execution. It takes a +// callback that is called when the shader is compiled, but before it is +// executed. +std::shared_ptr TestResult = RunShaderOpTest( + D3DDevice, m_support, TestXML, "LongVectorOp", + [&](LPCSTR Name, std::vector &ShaderData, st::ShaderOp *ShaderOp) { + LogCommentFmt(L"RunShaderOpTest CallBack. Resource Name: %S", Name); + + // This callback is called once for each resource defined for + // "LongVectorOp" in ShaderOpArith.xml. All callbacks are fired for each + // resource. We determine whether they are applicable to the test case + // when they run. + + // Process the callback for the OutputVector resource. + if (0 == _stricmp(Name, "OutputVector")) { + // We only need to set the compiler options string once. So this is a + // convenient place to do it. 
+ ShaderOp->Shaders.at(0).Arguments = CompilerOptionsString.c_str(); + + return; + } + + // Process the callback for the InputFuncArgs resource. + if (0 == _stricmp(Name, "InputFuncArgs")) { + if (TestConfig.IsScalarOp) { + FillShaderBufferFromLongVectorData(ShaderData, ScalarInput); + } else if (TestConfig.OpType == LongVectorOpType_Clamp) { + std::array ClampArgs = {ClampArgMin, ClampArgMax}; + FillShaderBufferFromLongVectorData(ShaderData, ClampArgs); + } + + return; + } + + // Process the callback for the InputVector1 resource. + if (0 == _stricmp(Name, "InputVector1")) { + FillShaderBufferFromLongVectorData(ShaderData, InputVector1); + return; + } + + // Process the callback for the InputVector2 resource. + if (0 == _stricmp(Name, "InputVector2")) { + if (IsVectorBinaryOp) { + FillShaderBufferFromLongVectorData(ShaderData, InputVector2); + } + return; + } + + LogErrorFmtThrow( + L"RunShaderOpTest CallBack. Unexpected Resource Name: %S", Name); + }); + +// Map the data from GPU to CPU memory so we can verify our expectations. +MappedData ShaderOutData; +TestResult->Test->GetReadBackData("OutputVector", &ShaderOutData); + +std::array OutputVector; +FillLongVectorDataFromShaderBuffer(ShaderOutData, OutputVector); + +VERIFY_SUCCEEDED(DoArraysMatch(OutputVector, ExpectedVector, + TestConfig.Tolerance)); +} + // This test expects a that retrieves a signal value from each of a // few resources that are initialized here. determines if it uses // the 6.6 Dynamic Resources feature. 
Values are read back from the result UAV diff --git a/tools/clang/unittests/HLSLExec/ShaderOpArith.xml b/tools/clang/unittests/HLSLExec/ShaderOpArith.xml index e768f205f1..5e95ad2502 100644 --- a/tools/clang/unittests/HLSLExec/ShaderOpArith.xml +++ b/tools/clang/unittests/HLSLExec/ShaderOpArith.xml @@ -3750,4 +3750,80 @@ void MSMain(uint GID : SV_GroupIndex, + + RootFlags(0), UAV(u0), UAV(u1), UAV(u2), + UAV(u3) + + + + + + + + + + + + + + + TestInitialize(vector Vector) + { + vector VectorCopy = Vector; + return VectorCopy; + } + #endif + + #ifdef FUNC_CLAMP + vector TestClamp(vector Vector, vector ClampArgMinMax) + { + TYPE ClampArgMin = ClampArgMinMax[0]; + TYPE ClampArgMax = ClampArgMinMax[1]; + return clamp(Vector, ClampArgMin, ClampArgMax); + } + #endif + + RWByteAddressBuffer g_InputFuncArgs : register(u0); + RWByteAddressBuffer g_InputVector1 : register(u1); + RWByteAddressBuffer g_InputVector2 : register(u2); + RWByteAddressBuffer g_OutputVector : register(u3); + [numthreads(1,1,1)] + void main(uint GI : SV_GroupIndex) { + + vector InputVector1 = g_InputVector1.Load< vector >(0); + + #ifdef IS_BINARY_VECTOR_OP + vector InputVector2 = g_InputVector2.Load< vector >(0); + #endif + + #ifdef IS_SCALAR_OP + TYPE InputScalar = g_InputFuncArgs.Load(0); + #endif + + #ifdef FUNC_CLAMP + TYPE Clamp_ArgMin = g_InputFuncArgs.Load(0); + TYPE Clamp_ArgMax = g_InputFuncArgs.Load(sizeof(TYPE)); + vector ClampArgMinMax = {Clamp_ArgMin, Clamp_ArgMax}; + #endif + + vector OutputVector = FUNC(InputVector1 OPERATOR OPERAND2); + + g_OutputVector.Store< vector >(0, OutputVector); + }; + ]]> + + diff --git a/tools/clang/unittests/HLSLExec/ShaderOpTest.cpp b/tools/clang/unittests/HLSLExec/ShaderOpTest.cpp index e6c9b10f6c..33858c508f 100644 --- a/tools/clang/unittests/HLSLExec/ShaderOpTest.cpp +++ b/tools/clang/unittests/HLSLExec/ShaderOpTest.cpp @@ -86,6 +86,9 @@ static void ShaderOpLogFmt(const wchar_t *fmt, ...) { // Check the specified HRESULT and return the success value. 
static HRESULT CHECK_HR_RET(HRESULT hr) { + if (FAILED(hr)) { + DebugBreak(); + } CHECK_HR(hr); return hr; } @@ -861,6 +864,11 @@ void ShaderOpTest::CreateShaders() { CHECK_HR(pLibrary->CreateBlobWithEncodingFromPinned( pText, (UINT32)strlen(pText), CP_UTF8, &pTextBlob)); CHECK_HR(m_pDxcSupport->CreateInstance(CLSID_DxcCompiler, &pCompiler)); + WEX::Logging::Log::Comment(L"Compiling shader:"); + ShaderOpLogFmt(L"\tTarget profile: %S", S.Target); + if (argumentsWList.size() > 0) { + ShaderOpLogFmt(L"\tArguments: %S", pArguments); + } CHECK_HR(pCompiler->Compile(pTextBlob, nameW, entryPointW, targetW, (LPCWSTR *)argumentsWList.data(), (UINT32)argumentsWList.size(), nullptr, 0, From 26ca0d513a12426ec3f7cf39a9f81f06755fac6f Mon Sep 17 00:00:00 2001 From: Alex Sepkowski <5620315+alsepkow@users.noreply.github.com> Date: Wed, 23 Apr 2025 19:26:00 -0700 Subject: [PATCH 11/31] Move most long vector preview test utility logic to its own file (#7375) Moving the long vector test utility functionality to its own file to help make subsequent reviews easier. No new code or logic updates. 
--- .../unittests/HLSLExec/ExecutionTest.cpp | 402 +---------------- tools/clang/unittests/HLSLExec/LongVectors.h | 414 ++++++++++++++++++ 2 files changed, 415 insertions(+), 401 deletions(-) create mode 100644 tools/clang/unittests/HLSLExec/LongVectors.h diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index ee67944950..1bef0b4f8d 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -62,6 +62,7 @@ #include "ShaderOpTest.h" #include #include +#include "LongVectors.h" // clang-format on #pragma comment(lib, "d3dcompiler.lib") @@ -279,9 +280,6 @@ typedef struct D3D12_FEATURE_DATA_D3D12_OPTIONS4 { // Virtual class to compute the expected result given a set of inputs struct TableParameter; -template struct LongVectorOpTestConfig; // Forward declaration -enum LongVectorOpType; // Forward declaration - class ExecutionTest { public: BEGIN_TEST_CLASS(ExecutionTest) @@ -11218,404 +11216,6 @@ TEST_F(ExecutionTest, PackUnpackTest) { } } -// A helper struct because C++ bools are 1 byte and HLSL bools are 4 bytes. -// Take int32_t as a constuctor argument and convert it to bool when needed. -// Comparisons cast to a bool because we only care if the bool representation is -// true or false. 
-struct HLSLBool_t { - HLSLBool_t() : val(0) {} - HLSLBool_t(int32_t val) : val(val) {} - HLSLBool_t(bool val) : val(val) {} - HLSLBool_t(const HLSLBool_t &other) : val(other.val) {} - - bool operator==(const HLSLBool_t &other) const { - return static_cast(val) == static_cast(other.val); - } - - bool operator!=(const HLSLBool_t &other) const { - return static_cast(val) != static_cast(other.val); - } - - bool operator<(const HLSLBool_t &other) const { return val < other.val; } - - bool operator>(const HLSLBool_t &other) const { return val > other.val; } - - bool operator<=(const HLSLBool_t &other) const { return val <= other.val; } - - bool operator>=(const HLSLBool_t &other) const { return val >= other.val; } - - HLSLBool_t operator*(const HLSLBool_t &other) const { - return HLSLBool_t(val * other.val); - } - - HLSLBool_t operator+(const HLSLBool_t &other) const { - return HLSLBool_t(val + other.val); - } - - // So we can construct std::wstrings using std::wostream - friend std::wostream &operator<<(std::wostream &os, const HLSLBool_t &obj) { - os << static_cast(obj.val); - return os; - } - - // So we can construct std::strings using std::ostream - friend std::ostream &operator<<(std::ostream &os, const HLSLBool_t &obj) { - os << static_cast(obj.val); - return os; - } - - int32_t val = 0; -}; - -// No native float16 type in C++ until C++23 . So we use uint16_t to represent -// it. Simple little wrapping struct to help handle the right behavior. 
-struct HLSLHalf_t { - HLSLHalf_t() : val(0) {} - HLSLHalf_t(DirectX::PackedVector::HALF val) : val(val) {} - HLSLHalf_t(const HLSLHalf_t &other) : val(other.val) {} - - bool operator==(const HLSLHalf_t &other) const { return val == other.val; } - - bool operator<(const HLSLHalf_t &other) const { - return DirectX::PackedVector::XMConvertHalfToFloat(val) < - DirectX::PackedVector::XMConvertHalfToFloat(other.val); - } - - bool operator>(const HLSLHalf_t &other) const { - return DirectX::PackedVector::XMConvertHalfToFloat(val) > - DirectX::PackedVector::XMConvertHalfToFloat(other.val); - } - - // Used by tolerance checks in the tests. - bool operator>(float d) const { - float a = DirectX::PackedVector::XMConvertHalfToFloat(val); - return a > d; - } - - bool operator<(float d) const { - float a = DirectX::PackedVector::XMConvertHalfToFloat(val); - return a < d; - } - - bool operator<=(const HLSLHalf_t &other) const { - return DirectX::PackedVector::XMConvertHalfToFloat(val) <= - DirectX::PackedVector::XMConvertHalfToFloat(other.val); - } - - bool operator>=(const HLSLHalf_t &other) const { - return DirectX::PackedVector::XMConvertHalfToFloat(val) >= - DirectX::PackedVector::XMConvertHalfToFloat(other.val); - } - - bool operator!=(const HLSLHalf_t &other) const { return val != other.val; } - - HLSLHalf_t operator*(const HLSLHalf_t &other) const { - float a = DirectX::PackedVector::XMConvertHalfToFloat(val); - float b = DirectX::PackedVector::XMConvertHalfToFloat(other.val); - return HLSLHalf_t(DirectX::PackedVector::XMConvertFloatToHalf(a * b)); - } - - HLSLHalf_t operator+(const HLSLHalf_t &other) const { - float a = DirectX::PackedVector::XMConvertHalfToFloat(val); - float b = DirectX::PackedVector::XMConvertHalfToFloat(other.val); - return HLSLHalf_t(DirectX::PackedVector::XMConvertFloatToHalf(a + b)); - } - - HLSLHalf_t operator-(const HLSLHalf_t &other) const { - float a = DirectX::PackedVector::XMConvertHalfToFloat(val); - float b = 
DirectX::PackedVector::XMConvertHalfToFloat(other.val); - return HLSLHalf_t(DirectX::PackedVector::XMConvertFloatToHalf(a - b)); - } - - // So we can construct std::wstrings using std::wostream - friend std::wostream &operator<<(std::wostream &os, const HLSLHalf_t &obj) { - os << DirectX::PackedVector::XMConvertHalfToFloat(obj.val); - return os; - } - - // So we can construct std::wstrings using std::wostream - friend std::ostream &operator<<(std::ostream &os, const HLSLHalf_t &obj) { - os << DirectX::PackedVector::XMConvertHalfToFloat(obj.val); - return os; - } - - // HALF is an alias to uint16_t - DirectX::PackedVector::HALF val = 0; -}; - -// Helper to fill the shader buffer based on type. Convenient to be used when -// copying HLSL*_t types so we can copy the underlying type directly instead of -// the struct. -template -void FillShaderBufferFromLongVectorData(std::vector &ShaderBuffer, - std::array &TestData) { - - // Note: DataSize for HLSLHalf_t and HLSLBool_t may be larger than the - // underlying type in some cases. Thats fine. Resize just makes sure we have - // enough space. - const size_t DataSize = sizeof(T) * N; - ShaderBuffer.resize(DataSize); - - if constexpr (std::is_same_v) { - DirectX::PackedVector::HALF *ShaderBufferPtr = - reinterpret_cast(ShaderBuffer.data()); - for (size_t i = 0; i < N; ++i) { - ShaderBufferPtr[i] = TestData[i].val; - } - } else if constexpr (std::is_same_v) { - int32_t *ShaderBufferPtr = reinterpret_cast(ShaderBuffer.data()); - for (size_t i = 0; i < N; ++i) { - ShaderBufferPtr[i] = TestData[i].val; - } - } else { - T *ShaderBufferPtr = reinterpret_cast(ShaderBuffer.data()); - for (size_t i = 0; i < N; ++i) { - ShaderBufferPtr[i] = TestData[i]; - } - } -} - -// Helper to fill the test data from the shader buffer based on type. Convenient -// to be used when copying HLSL*_t types so we can use the underlying type. 
-template -void FillLongVectorDataFromShaderBuffer(MappedData &ShaderBuffer, - std::array &TestData) { - - if constexpr (std::is_same_v) { - DirectX::PackedVector::HALF *ShaderBufferPtr = - reinterpret_cast(ShaderBuffer.data()); - for (size_t i = 0; i < N; ++i) { - // HLSLHalf_t has a DirectX::PackedVector::HALF based constructor. - TestData[i] = ShaderBufferPtr[i]; - } - } else if constexpr (std::is_same_v) { - int32_t *ShaderBufferPtr = reinterpret_cast(ShaderBuffer.data()); - for (size_t i = 0; i < N; ++i) { - // HLSLBool_t has a int32_t based constructor. - TestData[i] = ShaderBufferPtr[i]; - } - } else { - T *ShaderBufferPtr = reinterpret_cast(ShaderBuffer.data()); - for (size_t i = 0; i < N; ++i) { - TestData[i] = ShaderBufferPtr[i]; - } - } -} - -enum LongVectorOpType { - LongVectorOpType_ScalarAdd, - LongVectorOpType_ScalarMultiply, - LongVectorOpType_Multiply, - LongVectorOpType_Add, - LongVectorOpType_Min, - LongVectorOpType_Max, - LongVectorOpType_Clamp, - LongVectorOpType_Initialize, - LongVectorOpType_UnInitialized -}; - -// Used to pass into LongVectorOpTestBase -template struct LongVectorOpTestConfig { - LongVectorOpTestConfig() = default; - - LongVectorOpTestConfig(LongVectorOpType OpType) : OpType(OpType) { - IntrinsicString = ""; - - if (IsFloatingPointType()) - Tolerance = 1; - - switch (OpType) { - case LongVectorOpType_ScalarAdd: - OperatorString = "+"; - IsScalarOp = true; - break; - case LongVectorOpType_ScalarMultiply: - OperatorString = "*"; - IsScalarOp = true; - break; - case LongVectorOpType_Multiply: - OperatorString = "*"; - break; - case LongVectorOpType_Add: - OperatorString = "+"; - break; - case LongVectorOpType_Min: - OperatorString = ","; - IntrinsicString = "min"; - break; - case LongVectorOpType_Max: - OperatorString = ","; - IntrinsicString = "max"; - break; - case LongVectorOpType_Clamp: - OperatorString = ","; - IntrinsicString = "TestClamp"; - IsBinaryOp = false; - break; - case LongVectorOpType_Initialize: - 
IntrinsicString = "TestInitialize"; - IsBinaryOp = false; - break; - default: - VERIFY_FAIL("Invalid LongVectorOpType"); - } - } - - bool IsFloatingPointType() const { - return std::is_same_v || std::is_same_v || - std::is_same_v; - } - - // A helper to get the hlsl type as a string for a given C++ type. - // Used in the long vector tests. - std::string GetHLSLTypeString() { - if (std::is_same_v) - return "bool"; - if (std::is_same_v) - return "half"; - if (std::is_same_v) - return "float"; - if (std::is_same_v) - return "double"; - if (std::is_same_v) - return "int16_t"; - if (std::is_same_v) - return "int"; - if (std::is_same_v) - return "int64_t"; - if (std::is_same_v) - return "uint16_t"; - if (std::is_same_v) - return "uint32_t"; - if (std::is_same_v) - return "uint64_t"; - - std::string ErrStr("GetHLSLTypeString() Unsupported type: "); - ErrStr.append(typeid(T).name()); - VERIFY_IS_TRUE(false, ErrStr.c_str()); - return "UnknownType"; - } - - // To be used for the value of -DOPERATOR - std::string OperatorString; - // To be used for the value of -DFUNC - std::string IntrinsicString; - // Optional, can be used to override shader code. - bool IsScalarOp = false; - bool IsBinaryOp = true; - float Tolerance = 0.0; - LongVectorOpType OpType = LongVectorOpType_UnInitialized; -}; - -template struct LongVectorTestTraits { - std::uniform_int_distribution UD = std::uniform_int_distribution( - std::numeric_limits::min(), std::numeric_limits::max()); -}; - -template <> struct LongVectorTestTraits { - // Float values for this were taken from Microsoft online documentation for - // the DirectX HALF data type. HALF is equivalent to IEEE 754 binary 16 - // format. 
- std::uniform_int_distribution UD = - std::uniform_int_distribution( - DirectX::PackedVector::XMConvertFloatToHalf(float(6.10e-5f)), - DirectX::PackedVector::XMConvertFloatToHalf(float(65504.0f))); -}; - -template <> struct LongVectorTestTraits { - std::uniform_int_distribution UD = - std::uniform_int_distribution(0u, 1u); -}; - -template <> struct LongVectorTestTraits { - // The ranges for generation. A std::uniform_real_distribution can only - // have a range that is equal to the types largest value. This is due to - // precision issues. So instead we define some large values. - std::uniform_real_distribution UD = - std::uniform_real_distribution(-1e20f, 1e20f); -}; - -template <> struct LongVectorTestTraits { - // The ranges for generation. A std::uniform_real_distribution can only - // have a range that is equal to the types largest value. This is due to - // precision issues. So instead we define some large values. - std::uniform_real_distribution UD = - std::uniform_real_distribution(-1e100, 1e100); -}; - -template class DeterministicNumberGenerator { - // Mersenne Twister 'random' number generator. Generated numbers are based - // on the seed value and are deterministic for any given seed. - std::mt19937 Generator; - - LongVectorTestTraits UD; - -public: - DeterministicNumberGenerator(unsigned SeedValue) : Generator(SeedValue) {} - - T generate() { return UD.UD(Generator); } -}; - -template -bool DoArraysMatch(const std::array &ActualValues, - const std::array &ExpectedValues, float Tolerance) { - // Stash mismatched indexes for easy failure logging later - std::vector MismatchedIndexes; - for (size_t Index = 0; Index < N; ++Index) { - if constexpr (std::is_same_v) { - // Compiler was very picky and wanted an explicit case for any T that - // doesn't implement the operators in the below else. ( > and -). It - // wouldn't accept putting this constexpr as an or case with other - // statements. 
- if (ActualValues[Index] != ExpectedValues[Index]) { - MismatchedIndexes.push_back(Index); - } - } else if constexpr (std::is_same_v) { - const DirectX::PackedVector::HALF a = ActualValues[Index].val; - const DirectX::PackedVector::HALF b = ExpectedValues[Index].val; - if (!CompareHalfULP(a, b, Tolerance)) { - MismatchedIndexes.push_back(Index); - } - } else if constexpr (std::is_same_v) { - const int IntTolerance = static_cast(Tolerance); - if (!CompareFloatULP(ActualValues[Index], ExpectedValues[Index], - IntTolerance)) { - MismatchedIndexes.push_back(Index); - } - } else if constexpr (std::is_same_v) { - const int64_t IntTolerance = static_cast(Tolerance); - if (!CompareDoubleULP(ActualValues[Index], ExpectedValues[Index], - IntTolerance)) { - MismatchedIndexes.push_back(Index); - } - } else if (Tolerance == 0 && ActualValues[Index] != ExpectedValues[Index]) { - MismatchedIndexes.push_back(Index); - } else { - T Diff = ActualValues[Index] > ExpectedValues[Index] - ? ActualValues[Index] - ExpectedValues[Index] - : ExpectedValues[Index] - ActualValues[Index]; - if (Diff > Tolerance) { - MismatchedIndexes.push_back(Index); - } - } - } - - if (MismatchedIndexes.empty()) - return true; - - if (!MismatchedIndexes.empty()) { - for (size_t Index : MismatchedIndexes) { - std::wstringstream Wss(L""); - Wss << L"Mismatch at Index: " << Index; - Wss << L" Actual Value:" << ActualValues[Index] << ","; - Wss << L" Expected Value:" << ExpectedValues[Index]; - WEX::Logging::Log::Error(Wss.str().c_str()); - } - } - - return false; -} - TEST_F(ExecutionTest, LongVector_ScalarAdd_bool) { WEX::TestExecution::SetVerifyOutput verifySettings( WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); diff --git a/tools/clang/unittests/HLSLExec/LongVectors.h b/tools/clang/unittests/HLSLExec/LongVectors.h new file mode 100644 index 0000000000..eb3b37a570 --- /dev/null +++ b/tools/clang/unittests/HLSLExec/LongVectors.h @@ -0,0 +1,414 @@ +#pragma once + +#include +#include +#include 
+#include +#include +#include + +#include +#include + +#include + +template struct LongVectorOpTestConfig; // Forward declaration +enum LongVectorOpType; // Forward declaration + +// A helper struct because C++ bools are 1 byte and HLSL bools are 4 bytes. +// Take int32_t as a constructor argument and convert it to bool when needed. +// Comparisons cast to a bool because we only care if the bool representation is +// true or false. +struct HLSLBool_t { + HLSLBool_t() : val(0) {} + HLSLBool_t(int32_t val) : val(val) {} + HLSLBool_t(bool val) : val(val) {} + HLSLBool_t(const HLSLBool_t &other) : val(other.val) {} + + bool operator==(const HLSLBool_t &other) const { + return static_cast(val) == static_cast(other.val); + } + + bool operator!=(const HLSLBool_t &other) const { + return static_cast(val) != static_cast(other.val); + } + + bool operator<(const HLSLBool_t &other) const { return val < other.val; } + + bool operator>(const HLSLBool_t &other) const { return val > other.val; } + + bool operator<=(const HLSLBool_t &other) const { return val <= other.val; } + + bool operator>=(const HLSLBool_t &other) const { return val >= other.val; } + + HLSLBool_t operator*(const HLSLBool_t &other) const { + return HLSLBool_t(val * other.val); + } + + HLSLBool_t operator+(const HLSLBool_t &other) const { + return HLSLBool_t(val + other.val); + } + + // So we can construct std::wstrings using std::wostream + friend std::wostream &operator<<(std::wostream &os, const HLSLBool_t &obj) { + os << static_cast(obj.val); + return os; + } + + // So we can construct std::strings using std::ostream + friend std::ostream &operator<<(std::ostream &os, const HLSLBool_t &obj) { + os << static_cast(obj.val); + return os; + } + + int32_t val = 0; +}; + +// No native float16 type in C++ until C++23. So we use uint16_t to represent +// it. Simple little wrapping struct to help handle the right behavior. 
+struct HLSLHalf_t { + HLSLHalf_t() : val(0) {} + HLSLHalf_t(DirectX::PackedVector::HALF val) : val(val) {} + HLSLHalf_t(const HLSLHalf_t &other) : val(other.val) {} + + bool operator==(const HLSLHalf_t &other) const { return val == other.val; } + + bool operator<(const HLSLHalf_t &other) const { + return DirectX::PackedVector::XMConvertHalfToFloat(val) < + DirectX::PackedVector::XMConvertHalfToFloat(other.val); + } + + bool operator>(const HLSLHalf_t &other) const { + return DirectX::PackedVector::XMConvertHalfToFloat(val) > + DirectX::PackedVector::XMConvertHalfToFloat(other.val); + } + + // Used by tolerance checks in the tests. + bool operator>(float d) const { + float a = DirectX::PackedVector::XMConvertHalfToFloat(val); + return a > d; + } + + bool operator<(float d) const { + float a = DirectX::PackedVector::XMConvertHalfToFloat(val); + return a < d; + } + + bool operator<=(const HLSLHalf_t &other) const { + return DirectX::PackedVector::XMConvertHalfToFloat(val) <= + DirectX::PackedVector::XMConvertHalfToFloat(other.val); + } + + bool operator>=(const HLSLHalf_t &other) const { + return DirectX::PackedVector::XMConvertHalfToFloat(val) >= + DirectX::PackedVector::XMConvertHalfToFloat(other.val); + } + + bool operator!=(const HLSLHalf_t &other) const { return val != other.val; } + + HLSLHalf_t operator*(const HLSLHalf_t &other) const { + float a = DirectX::PackedVector::XMConvertHalfToFloat(val); + float b = DirectX::PackedVector::XMConvertHalfToFloat(other.val); + return HLSLHalf_t(DirectX::PackedVector::XMConvertFloatToHalf(a * b)); + } + + HLSLHalf_t operator+(const HLSLHalf_t &other) const { + float a = DirectX::PackedVector::XMConvertHalfToFloat(val); + float b = DirectX::PackedVector::XMConvertHalfToFloat(other.val); + return HLSLHalf_t(DirectX::PackedVector::XMConvertFloatToHalf(a + b)); + } + + HLSLHalf_t operator-(const HLSLHalf_t &other) const { + float a = DirectX::PackedVector::XMConvertHalfToFloat(val); + float b = 
DirectX::PackedVector::XMConvertHalfToFloat(other.val); + return HLSLHalf_t(DirectX::PackedVector::XMConvertFloatToHalf(a - b)); + } + + // So we can construct std::wstrings using std::wostream + friend std::wostream &operator<<(std::wostream &os, const HLSLHalf_t &obj) { + os << DirectX::PackedVector::XMConvertHalfToFloat(obj.val); + return os; + } + + // So we can construct std::wstrings using std::wostream + friend std::ostream &operator<<(std::ostream &os, const HLSLHalf_t &obj) { + os << DirectX::PackedVector::XMConvertHalfToFloat(obj.val); + return os; + } + + // HALF is an alias to uint16_t + DirectX::PackedVector::HALF val = 0; +}; + +// Helper to fill the shader buffer based on type. Convenient to be used when +// copying HLSL*_t types so we can copy the underlying type directly instead of +// the struct. +template +void FillShaderBufferFromLongVectorData(std::vector &ShaderBuffer, + std::array &TestData) { + + // Note: DataSize for HLSLHalf_t and HLSLBool_t may be larger than the + // underlying type in some cases. Thats fine. Resize just makes sure we have + // enough space. + const size_t DataSize = sizeof(T) * N; + ShaderBuffer.resize(DataSize); + + if constexpr (std::is_same_v) { + DirectX::PackedVector::HALF *ShaderBufferPtr = + reinterpret_cast(ShaderBuffer.data()); + for (size_t i = 0; i < N; ++i) { + ShaderBufferPtr[i] = TestData[i].val; + } + } else if constexpr (std::is_same_v) { + int32_t *ShaderBufferPtr = reinterpret_cast(ShaderBuffer.data()); + for (size_t i = 0; i < N; ++i) { + ShaderBufferPtr[i] = TestData[i].val; + } + } else { + T *ShaderBufferPtr = reinterpret_cast(ShaderBuffer.data()); + for (size_t i = 0; i < N; ++i) { + ShaderBufferPtr[i] = TestData[i]; + } + } +} + +// Helper to fill the test data from the shader buffer based on type. Convenient +// to be used when copying HLSL*_t types so we can use the underlying type. 
+template +void FillLongVectorDataFromShaderBuffer(MappedData &ShaderBuffer, + std::array &TestData) { + + if constexpr (std::is_same_v) { + DirectX::PackedVector::HALF *ShaderBufferPtr = + reinterpret_cast(ShaderBuffer.data()); + for (size_t i = 0; i < N; ++i) { + // HLSLHalf_t has a DirectX::PackedVector::HALF based constructor. + TestData[i] = ShaderBufferPtr[i]; + } + } else if constexpr (std::is_same_v) { + int32_t *ShaderBufferPtr = reinterpret_cast(ShaderBuffer.data()); + for (size_t i = 0; i < N; ++i) { + // HLSLBool_t has a int32_t based constructor. + TestData[i] = ShaderBufferPtr[i]; + } + } else { + T *ShaderBufferPtr = reinterpret_cast(ShaderBuffer.data()); + for (size_t i = 0; i < N; ++i) { + TestData[i] = ShaderBufferPtr[i]; + } + } +} + +enum LongVectorOpType { + LongVectorOpType_ScalarAdd, + LongVectorOpType_ScalarMultiply, + LongVectorOpType_Multiply, + LongVectorOpType_Add, + LongVectorOpType_Min, + LongVectorOpType_Max, + LongVectorOpType_Clamp, + LongVectorOpType_Initialize, + LongVectorOpType_UnInitialized +}; + +// Used to pass into LongVectorOpTestBase +template struct LongVectorOpTestConfig { + LongVectorOpTestConfig() = default; + + LongVectorOpTestConfig(LongVectorOpType OpType) : OpType(OpType) { + IntrinsicString = ""; + + if (IsFloatingPointType()) + Tolerance = 1; + + switch (OpType) { + case LongVectorOpType_ScalarAdd: + OperatorString = "+"; + IsScalarOp = true; + break; + case LongVectorOpType_ScalarMultiply: + OperatorString = "*"; + IsScalarOp = true; + break; + case LongVectorOpType_Multiply: + OperatorString = "*"; + break; + case LongVectorOpType_Add: + OperatorString = "+"; + break; + case LongVectorOpType_Min: + OperatorString = ","; + IntrinsicString = "min"; + break; + case LongVectorOpType_Max: + OperatorString = ","; + IntrinsicString = "max"; + break; + case LongVectorOpType_Clamp: + OperatorString = ","; + IntrinsicString = "TestClamp"; + IsBinaryOp = false; + break; + case LongVectorOpType_Initialize: + 
IntrinsicString = "TestInitialize"; + IsBinaryOp = false; + break; + default: + VERIFY_FAIL("Invalid LongVectorOpType"); + } + } + + bool IsFloatingPointType() const { + return std::is_same_v || std::is_same_v || + std::is_same_v; + } + + // A helper to get the hlsl type as a string for a given C++ type. + // Used in the long vector tests. + std::string GetHLSLTypeString() { + if (std::is_same_v) + return "bool"; + if (std::is_same_v) + return "half"; + if (std::is_same_v) + return "float"; + if (std::is_same_v) + return "double"; + if (std::is_same_v) + return "int16_t"; + if (std::is_same_v) + return "int"; + if (std::is_same_v) + return "int64_t"; + if (std::is_same_v) + return "uint16_t"; + if (std::is_same_v) + return "uint32_t"; + if (std::is_same_v) + return "uint64_t"; + + std::string ErrStr("GetHLSLTypeString() Unsupported type: "); + ErrStr.append(typeid(T).name()); + VERIFY_IS_TRUE(false, ErrStr.c_str()); + return "UnknownType"; + } + + // To be used for the value of -DOPERATOR + std::string OperatorString; + // To be used for the value of -DFUNC + std::string IntrinsicString; + // Optional, can be used to override shader code. + bool IsScalarOp = false; + bool IsBinaryOp = true; + float Tolerance = 0.0; + LongVectorOpType OpType = LongVectorOpType_UnInitialized; +}; + +template struct LongVectorTestTraits { + std::uniform_int_distribution UD = std::uniform_int_distribution( + std::numeric_limits::min(), std::numeric_limits::max()); +}; + +template <> struct LongVectorTestTraits { + // Float values for this were taken from Microsoft online documentation for + // the DirectX HALF data type. HALF is equivalent to IEEE 754 binary 16 + // format. 
+ std::uniform_int_distribution UD = + std::uniform_int_distribution( + DirectX::PackedVector::XMConvertFloatToHalf(float(6.10e-5f)), + DirectX::PackedVector::XMConvertFloatToHalf(float(65504.0f))); +}; + +template <> struct LongVectorTestTraits { + std::uniform_int_distribution UD = + std::uniform_int_distribution(0u, 1u); +}; + +template <> struct LongVectorTestTraits { + // The ranges for generation. A std::uniform_real_distribution can only + // have a range that is equal to the types largest value. This is due to + // precision issues. So instead we define some large values. + std::uniform_real_distribution UD = + std::uniform_real_distribution(-1e20f, 1e20f); +}; + +template <> struct LongVectorTestTraits { + // The ranges for generation. A std::uniform_real_distribution can only + // have a range that is equal to the types largest value. This is due to + // precision issues. So instead we define some large values. + std::uniform_real_distribution UD = + std::uniform_real_distribution(-1e100, 1e100); +}; + +template class DeterministicNumberGenerator { + // Mersenne Twister 'random' number generator. Generated numbers are based + // on the seed value and are deterministic for any given seed. + std::mt19937 Generator; + + LongVectorTestTraits UD; + +public: + DeterministicNumberGenerator(unsigned SeedValue) : Generator(SeedValue) {} + + T generate() { return UD.UD(Generator); } +}; + +template +bool DoArraysMatch(const std::array &ActualValues, + const std::array &ExpectedValues, float Tolerance) { + // Stash mismatched indexes for easy failure logging later + std::vector MismatchedIndexes; + for (size_t Index = 0; Index < N; ++Index) { + if constexpr (std::is_same_v) { + // Compiler was very picky and wanted an explicit case for any T that + // doesn't implement the operators in the below else. ( > and -). It + // wouldn't accept putting this constexpr as an or case with other + // statements. 
+ if (ActualValues[Index] != ExpectedValues[Index]) { + MismatchedIndexes.push_back(Index); + } + } else if constexpr (std::is_same_v) { + const DirectX::PackedVector::HALF a = ActualValues[Index].val; + const DirectX::PackedVector::HALF b = ExpectedValues[Index].val; + if (!CompareHalfULP(a, b, Tolerance)) { + MismatchedIndexes.push_back(Index); + } + } else if constexpr (std::is_same_v) { + const int IntTolerance = static_cast(Tolerance); + if (!CompareFloatULP(ActualValues[Index], ExpectedValues[Index], + IntTolerance)) { + MismatchedIndexes.push_back(Index); + } + } else if constexpr (std::is_same_v) { + const int64_t IntTolerance = static_cast(Tolerance); + if (!CompareDoubleULP(ActualValues[Index], ExpectedValues[Index], + IntTolerance)) { + MismatchedIndexes.push_back(Index); + } + } else if (Tolerance == 0 && ActualValues[Index] != ExpectedValues[Index]) { + MismatchedIndexes.push_back(Index); + } else { + T Diff = ActualValues[Index] > ExpectedValues[Index] + ? ActualValues[Index] - ExpectedValues[Index] + : ExpectedValues[Index] - ActualValues[Index]; + if (Diff > Tolerance) { + MismatchedIndexes.push_back(Index); + } + } + } + + if (MismatchedIndexes.empty()) + return true; + + if (!MismatchedIndexes.empty()) { + for (size_t Index : MismatchedIndexes) { + std::wstringstream Wss(L""); + Wss << L"Mismatch at Index: " << Index; + Wss << L" Actual Value:" << ActualValues[Index] << ","; + Wss << L" Expected Value:" << ExpectedValues[Index]; + WEX::Logging::Log::Error(Wss.str().c_str()); + } + } + + return false; +} From 6cb6843870cd5dc071c618f1e203361cb5d186ae Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Fri, 25 Apr 2025 06:43:32 +0200 Subject: [PATCH 12/31] [SER] Basic execution tests - All trivial scalar/vector/matrix getters - HitObject::FromRayQuery with procedural hit - HitObject::GetAttributes with custom attributes and procedural hit SER implementation tracker: #7214 --- tools/clang/unittests/HLSLExec/DXRUtil.h | 226 +++ 
.../unittests/HLSLExec/ExecutionTest.cpp | 718 +++++++ .../unittests/HLSLExec/ExecutionTest_SER.h | 1703 +++++++++++++++++ 3 files changed, 2647 insertions(+) create mode 100644 tools/clang/unittests/HLSLExec/DXRUtil.h create mode 100644 tools/clang/unittests/HLSLExec/ExecutionTest_SER.h diff --git a/tools/clang/unittests/HLSLExec/DXRUtil.h b/tools/clang/unittests/HLSLExec/DXRUtil.h new file mode 100644 index 0000000000..1f008885cf --- /dev/null +++ b/tools/clang/unittests/HLSLExec/DXRUtil.h @@ -0,0 +1,226 @@ +//===------------ DXRUtil.h - DXR Utility Functions ------------*- C++ -*-===// +/////////////////////////////////////////////////////////////////////////////// +// // +// DXRUtil.h // +// Copyright (C) Nvidia Corporation. All rights reserved. // +// This file is distributed under the University of Illinois Open Source // +// License. See LICENSE.TXT for details. // +// // +// This file contains the utility functions for DXR execution tests. // +// // +/////////////////////////////////////////////////////////////////////////////// + +//= DXR Utility +//============================================================================ +#define SHADER_ID_SIZE_IN_BYTES 32 + +#ifndef ROUND_UP +#define ROUND_UP(v, powerOf2Alignment) \ + (((v) + (powerOf2Alignment)-1) & ~((powerOf2Alignment)-1)) +#endif +struct SceneConsts { + DirectX::XMFLOAT4 eye; + DirectX::XMFLOAT4 U; + DirectX::XMFLOAT4 V; + DirectX::XMFLOAT4 W; + float sceneScale; + unsigned windowSize[2]; + int rayFlags; +}; + +struct Instance { + D3D12_RAYTRACING_GEOMETRY_TYPE type; + DirectX::XMFLOAT4X4 matrix; + UINT geometryCount; + UINT bottomASIdx; + UINT instanceID; + UINT mask; + UINT flags; +}; + +class ShaderTable { +public: + void Init(ID3D12Device *device, int raygenCount, int missCount, + int hitGroupCount, int rayTypeCount, int rootTableDwords) { + m_rayTypeCount = rayTypeCount; + m_raygenCount = raygenCount; + m_missCount = missCount * rayTypeCount; + m_hitGroupCount = hitGroupCount * rayTypeCount; 
+ m_rootTableSizeInBytes = rootTableDwords * 4; + m_shaderRecordSizeInBytes = + ROUND_UP(m_rootTableSizeInBytes + SHADER_ID_SIZE_IN_BYTES, + D3D12_RAYTRACING_SHADER_RECORD_BYTE_ALIGNMENT); + m_missStartIdx = m_raygenCount; + m_hitGroupStartIdx = m_missStartIdx + m_missCount; + + const int m_totalSizeInBytes = + (m_raygenCount + m_missCount + m_hitGroupCount) * + m_shaderRecordSizeInBytes; + + D3D12_RESOURCE_DESC desc = CD3DX12_RESOURCE_DESC::Buffer( + m_totalSizeInBytes, D3D12_RESOURCE_FLAG_NONE, + std::max(D3D12_RAYTRACING_SHADER_RECORD_BYTE_ALIGNMENT, + D3D12_DEFAULT_RESOURCE_PLACEMENT_ALIGNMENT)); + CD3DX12_HEAP_PROPERTIES heap = + CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT); + VERIFY_SUCCEEDED(device->CreateCommittedResource( + &heap, D3D12_HEAP_FLAG_NONE, &desc, + D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE, nullptr, + IID_PPV_ARGS(&m_sbtResource))); + m_sbtResource->SetName(L"SBT Resource Heap"); + CD3DX12_HEAP_PROPERTIES upload = + CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD); + VERIFY_SUCCEEDED(device->CreateCommittedResource( + &upload, D3D12_HEAP_FLAG_NONE, &desc, D3D12_RESOURCE_STATE_GENERIC_READ, + nullptr, IID_PPV_ARGS(&m_sbtUploadResource))); + m_sbtUploadResource->SetName(L"SBT Upload Heap"); + + VERIFY_SUCCEEDED(m_sbtUploadResource->Map(0, nullptr, (void **)&m_hostPtr)); + } + + void Upload(ID3D12GraphicsCommandList *cmdlist) { + CD3DX12_RESOURCE_BARRIER barrier = CD3DX12_RESOURCE_BARRIER::Transition( + m_sbtResource, D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE, + D3D12_RESOURCE_STATE_COPY_DEST); + cmdlist->ResourceBarrier(1, &barrier); + cmdlist->CopyResource(m_sbtResource, m_sbtUploadResource); + CD3DX12_RESOURCE_BARRIER barrier2 = CD3DX12_RESOURCE_BARRIER::Transition( + m_sbtResource, D3D12_RESOURCE_STATE_COPY_DEST, + D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE); + cmdlist->ResourceBarrier(1, &barrier2); + } + + int GetShaderRecordSizeInBytes() { return m_shaderRecordSizeInBytes; } + + int GetRaygenShaderRecordIdx(int idx) { 
return idx; } + int GetMissShaderRecordIdx(int idx, int rayType) { + return m_missStartIdx + idx * m_rayTypeCount + rayType; + } + int GetHitGroupShaderRecordIdx(int idx, int rayType) { + return m_hitGroupStartIdx + idx * m_rayTypeCount + rayType; + } + + void *GetRaygenShaderIdPtr(int idx) { + return m_hostPtr + + GetRaygenShaderRecordIdx(idx) * m_shaderRecordSizeInBytes; + } + void *GetMissShaderIdPtr(int idx, int rayType) { + return m_hostPtr + + GetMissShaderRecordIdx(idx, rayType) * m_shaderRecordSizeInBytes; + } + void *GetHitGroupShaderIdPtr(int idx, int rayType) { + return m_hostPtr + + GetHitGroupShaderRecordIdx(idx, rayType) * m_shaderRecordSizeInBytes; + } + + void *GetRaygenRootTablePtr(int idx) { + return (char *)GetRaygenShaderIdPtr(idx) + SHADER_ID_SIZE_IN_BYTES; + } + void *GetMissRootTablePtr(int idx, int rayType) { + return (char *)GetMissShaderIdPtr(idx, rayType) + SHADER_ID_SIZE_IN_BYTES; + } + void *GetHitGroupRootTablePtr(int idx, int rayType) { + return (char *)GetHitGroupShaderIdPtr(idx, rayType) + + SHADER_ID_SIZE_IN_BYTES; + } + + int GetRaygenRangeInBytes() { + return m_raygenCount * m_shaderRecordSizeInBytes; + } + int GetMissRangeInBytes() { return m_missCount * m_shaderRecordSizeInBytes; } + int GetHitGroupRangeInBytes() { + return m_hitGroupCount * m_shaderRecordSizeInBytes; + } + + D3D12_GPU_VIRTUAL_ADDRESS GetRaygenStartGpuVA() { + return m_sbtResource->GetGPUVirtualAddress() + + GetRaygenShaderRecordIdx(0) * m_shaderRecordSizeInBytes; + } + D3D12_GPU_VIRTUAL_ADDRESS GetMissStartGpuVA() { + return m_sbtResource->GetGPUVirtualAddress() + + GetMissShaderRecordIdx(0, 0) * m_shaderRecordSizeInBytes; + } + D3D12_GPU_VIRTUAL_ADDRESS GetHitGroupStartGpuVA() { + return m_sbtResource->GetGPUVirtualAddress() + + GetHitGroupShaderRecordIdx(0, 0) * m_shaderRecordSizeInBytes; + } + +private: + CComPtr m_sbtResource; + CComPtr m_sbtUploadResource; + char *m_hostPtr = nullptr; + int m_rayTypeCount = 0; + int m_raygenCount = 0; + int m_missCount = 
0; + int m_hitGroupCount = 0; + int m_rootTableSizeInBytes = 0; + int m_shaderRecordSizeInBytes = 0; + int m_missStartIdx = 0; + int m_hitGroupStartIdx = 0; +}; + +//----------------------------------------------------------------------------- +void AllocateBuffer( + ID3D12Device *pDevice, UINT64 bufferSize, ID3D12Resource **ppResource, + bool allowUAV = false, + D3D12_RESOURCE_STATES initialResourceState = D3D12_RESOURCE_STATE_COMMON, + const wchar_t *resourceName = nullptr) { + auto uploadHeapProperties = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT); + auto bufferDesc = CD3DX12_RESOURCE_DESC::Buffer( + bufferSize, allowUAV ? D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS + : D3D12_RESOURCE_FLAG_NONE); + VERIFY_SUCCEEDED(pDevice->CreateCommittedResource( + &uploadHeapProperties, D3D12_HEAP_FLAG_NONE, &bufferDesc, + initialResourceState, nullptr, IID_PPV_ARGS(ppResource))); + if (resourceName) { + (*ppResource)->SetName(resourceName); + } +} + +//----------------------------------------------------------------------------- +void ReallocScratchResource(ID3D12Device *pDevice, ID3D12Resource **ppResource, + UINT64 nbytes) { + + if (!(*ppResource) || (*ppResource)->GetDesc().Width < nbytes) { + AllocateBuffer(pDevice, nbytes, ppResource, true, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, L"scratchResource"); + } +} + +//----------------------------------------------------------------------------- +void AllocateUploadBuffer(ID3D12Device *pDevice, const void *pData, + UINT64 datasize, ID3D12Resource **ppResource, + const wchar_t *resourceName = nullptr) { + auto uploadHeapProperties = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD); + auto bufferDesc = CD3DX12_RESOURCE_DESC::Buffer(datasize); + VERIFY_SUCCEEDED(pDevice->CreateCommittedResource( + &uploadHeapProperties, D3D12_HEAP_FLAG_NONE, &bufferDesc, + D3D12_RESOURCE_STATE_GENERIC_READ, nullptr, IID_PPV_ARGS(ppResource))); + if (resourceName) { + (*ppResource)->SetName(resourceName); + } + void *pMappedData; + 
VERIFY_SUCCEEDED((*ppResource)->Map(0, nullptr, &pMappedData)); + memcpy(pMappedData, pData, datasize); + (*ppResource)->Unmap(0, nullptr); +} + +//----------------------------------------------------------------------------- +void AllocateBufferFromUpload(ID3D12Device *pDevice, + ID3D12GraphicsCommandList *pCommandList, + ID3D12Resource *uploadSource, + ID3D12Resource **ppResource, + D3D12_RESOURCE_STATES targetResourceState, + const wchar_t *resourceName = nullptr) { + const bool allowUAV = + targetResourceState == D3D12_RESOURCE_STATE_UNORDERED_ACCESS; + AllocateBuffer(pDevice, uploadSource->GetDesc().Width, ppResource, allowUAV, + D3D12_RESOURCE_STATE_COPY_DEST, resourceName); + pCommandList->CopyResource(*ppResource, uploadSource); + CD3DX12_RESOURCE_BARRIER barrier = CD3DX12_RESOURCE_BARRIER::Transition( + *ppResource, D3D12_RESOURCE_STATE_COPY_DEST, targetResourceState); + pCommandList->ResourceBarrier(1, (const D3D12_RESOURCE_BARRIER *)&barrier); +} + +//= DXR Utility +//============================================================================ \ No newline at end of file diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 6db27d7a41..403b261df1 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -60,6 +60,8 @@ #include "ShaderOpTest.h" #include #include + +#include "DXRUtil.h" // clang-format on #pragma comment(lib, "d3dcompiler.lib") @@ -292,6 +294,15 @@ class ExecutionTest { TEST_METHOD(SaturateTest); TEST_METHOD(SignTest); TEST_METHOD(Int64Test); + TEST_METHOD(SERBasicTest); + TEST_METHOD(SERScalarGetterTest); + TEST_METHOD(SERVectorGetterTest); + TEST_METHOD(SERMatrixGetterTest); + TEST_METHOD(SERRayQueryTest); + TEST_METHOD(SERIntersectionTest); + TEST_METHOD(SERGetAttributesTest); + TEST_METHOD(SERTraceHitMissNopTest); + TEST_METHOD(SERIsMissTest); TEST_METHOD(LifetimeIntrinsicTest) TEST_METHOD(WaveIntrinsicsTest); 
TEST_METHOD(WaveIntrinsicsDDITest); @@ -1917,6 +1928,12 @@ class ExecutionTest { CComPtr &pRootSignature, LPCWSTR pTargetProfile, LPCWSTR *pOptions, int numOptions); + CComPtr + RunDXRTest(ID3D12Device *pDevice0, LPCSTR shader, + D3D_SHADER_MODEL shaderModel, LPCWSTR *pOptions, int numOptions, + std::vector &testData, int windowWidth, int windowHeight, + bool useMesh, bool useProceduralGeometry, bool useIS, + int payloadCount = 1, int attributeCount = 2); void SetDescriptorHeap(ID3D12GraphicsCommandList *pCommandList, ID3D12DescriptorHeap *pHeap) { @@ -2078,6 +2095,707 @@ void ExecutionTest::RunRWByteBufferComputeTest(ID3D12Device *pDevice, WaitForSignal(pCommandQueue, FO); } +CComPtr ExecutionTest::RunDXRTest( + ID3D12Device *pDevice0, LPCSTR shader, D3D_SHADER_MODEL shaderModel, + LPCWSTR *pOptions, int numOptions, std::vector &testData, + int windowWidth, int windowHeight, bool useMesh, bool useProceduralGeometry, + bool useIS, int payloadCount, int attributeCount) { + CComPtr pDevice; + VERIFY_SUCCEEDED(pDevice0->QueryInterface(IID_PPV_ARGS(&pDevice))); + + LPCWSTR pTargetProfile; + switch (shaderModel) { + case D3D_SHADER_MODEL_6_9: + pTargetProfile = L"lib_6_9"; + break; + case D3D_SHADER_MODEL_6_8: + pTargetProfile = L"lib_6_8"; + break; + case D3D_SHADER_MODEL_6_7: + pTargetProfile = L"lib_6_7"; + break; + case D3D_SHADER_MODEL_6_6: + pTargetProfile = L"lib_6_6"; + break; + case D3D_SHADER_MODEL_6_5: + pTargetProfile = L"lib_6_5"; + break; + case D3D_SHADER_MODEL_6_4: + pTargetProfile = L"lib_6_4"; + break; + case D3D_SHADER_MODEL_6_3: + pTargetProfile = L"lib_6_3"; + break; + default: + // DXR capable shader model not found. 
+ LogErrorFmt(L"DXR capable shader model not found."); + return nullptr; + } + + FenceObj FO; + InitFenceObj(pDevice, &FO); + + // Setup Resources + CComPtr pTestBuffer; + CComPtr pTestBufferRead; + CComPtr pSceneConstantBuffer; + + // Descriptor heap + CComPtr pDescriptorHeap; + { + // + // UAV descriptor heap layout: + // 0 - test buffer UAV + // 1 - vertex buffer SRV + // 2 - index buffer SRV + // + D3D12_DESCRIPTOR_HEAP_DESC descriptorHeapDesc = {}; + descriptorHeapDesc.NumDescriptors = 3; + descriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + descriptorHeapDesc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + pDevice->CreateDescriptorHeap(&descriptorHeapDesc, + IID_PPV_ARGS(&pDescriptorHeap)); + pDescriptorHeap->SetName(L"Descriptor Heap"); + } + int descriptorSize = pDevice->GetDescriptorHandleIncrementSize( + D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV); + + // Testbuffer + { + auto resDesc = CD3DX12_RESOURCE_DESC::Buffer( + testData.size() * sizeof(int), + D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); + auto defaultHeapProperties = + CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT); + VERIFY_SUCCEEDED(pDevice->CreateCommittedResource( + &defaultHeapProperties, D3D12_HEAP_FLAG_NONE, &resDesc, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, nullptr, + IID_PPV_ARGS(&pTestBuffer))); + pTestBuffer->SetName(L"Test Buffer"); + + const int descriptorIndex = 0; + D3D12_CPU_DESCRIPTOR_HANDLE cpuDescriptorHandle = + CD3DX12_CPU_DESCRIPTOR_HANDLE( + pDescriptorHeap->GetCPUDescriptorHandleForHeapStart(), + descriptorIndex, descriptorSize); + D3D12_UNORDERED_ACCESS_VIEW_DESC UAVDesc = {}; + UAVDesc.Format = DXGI_FORMAT_UNKNOWN; + UAVDesc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER; + UAVDesc.Buffer.FirstElement = 0; + UAVDesc.Buffer.NumElements = (UINT)testData.size(); + UAVDesc.Buffer.StructureByteStride = sizeof(int); + UAVDesc.Buffer.CounterOffsetInBytes = 0; + UAVDesc.Buffer.Flags = D3D12_BUFFER_UAV_FLAG_NONE; + pDevice->CreateUnorderedAccessView(pTestBuffer, 
nullptr, &UAVDesc, + cpuDescriptorHandle); + } + + // Testbuffer Readback + { + CD3DX12_HEAP_PROPERTIES readHeap(D3D12_HEAP_TYPE_READBACK); + CD3DX12_RESOURCE_DESC readDesc( + CD3DX12_RESOURCE_DESC::Buffer(testData.size() * sizeof(int))); + pDevice->CreateCommittedResource(&readHeap, D3D12_HEAP_FLAG_NONE, &readDesc, + D3D12_RESOURCE_STATE_COPY_DEST, nullptr, + IID_PPV_ARGS(&pTestBufferRead)); + } + + // Create CBV resource (sceneConstantBuffer), index 1 + { + const int descriptorIndex = 1; + const UINT constantBufferSize = + (sizeof(SceneConsts) + + (D3D12_CONSTANT_BUFFER_DATA_PLACEMENT_ALIGNMENT - 1)) & + ~(D3D12_CONSTANT_BUFFER_DATA_PLACEMENT_ALIGNMENT - + 1); // must be a multiple 256 bytes + D3D12_CPU_DESCRIPTOR_HANDLE cpuDescriptorHandle = + CD3DX12_CPU_DESCRIPTOR_HANDLE( + pDescriptorHeap->GetCPUDescriptorHandleForHeapStart(), + descriptorIndex, descriptorSize); + auto resDesc = CD3DX12_RESOURCE_DESC::Buffer(constantBufferSize); + auto uploadHeapProperties = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD); + pDevice->CreateCommittedResource(&uploadHeapProperties, + D3D12_HEAP_FLAG_NONE, &resDesc, + D3D12_RESOURCE_STATE_GENERIC_READ, nullptr, + IID_PPV_ARGS(&pSceneConstantBuffer)); + + UINT8 *sceneConstantBufferWO; + CD3DX12_RANGE readRange( + 0, 0); // We do not intend to read from this resource on the CPU. 
+ pSceneConstantBuffer->Map( + 0, &readRange, reinterpret_cast(&sceneConstantBufferWO)); + + // Setup Scene Constants + SceneConsts sceneConsts = { + {25.f, -25.f, 700.f, 0.f}, + {536.f, 0.f, 0.f, 0.f}, + {0.f, 301.f, 0.f, 0.f}, + {0.f, 0., -699.f, 0.f}, + 100.f, + {(unsigned int)windowWidth, (unsigned int)windowHeight}, + 0x00}; + + memcpy(sceneConstantBufferWO, &sceneConsts, sizeof(SceneConsts)); + pSceneConstantBuffer->Unmap(0, nullptr); + + D3D12_CONSTANT_BUFFER_VIEW_DESC desc = {}; + desc.SizeInBytes = constantBufferSize; + desc.BufferLocation = pSceneConstantBuffer->GetGPUVirtualAddress(); + pDevice->CreateConstantBufferView(&desc, cpuDescriptorHandle); + } + + // Local (SBT) root signature + CComPtr pLocalRootSignature; + { + CD3DX12_DESCRIPTOR_RANGE bufferRanges[1]; + CD3DX12_ROOT_PARAMETER rootParameters[1]; + bufferRanges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 1, 0, + 2); // vertexBuffer(t1), indexBuffer(t2) + rootParameters[0].InitAsDescriptorTable( + _countof(bufferRanges), bufferRanges, D3D12_SHADER_VISIBILITY_ALL); + + CD3DX12_ROOT_SIGNATURE_DESC rootSignatureDesc; + rootSignatureDesc.Init(_countof(rootParameters), rootParameters, 0, nullptr, + D3D12_ROOT_SIGNATURE_FLAG_LOCAL_ROOT_SIGNATURE); + CComPtr signature; + CComPtr error; + VERIFY_SUCCEEDED(D3D12SerializeRootSignature( + &rootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1, &signature, &error)); + VERIFY_SUCCEEDED(pDevice->CreateRootSignature( + 0, signature->GetBufferPointer(), signature->GetBufferSize(), + IID_PPV_ARGS(&pLocalRootSignature))); + pLocalRootSignature->SetName(L"Local Root Signature"); + } + + // Global root signature + CComPtr pGlobalRootSignature; + { + CD3DX12_DESCRIPTOR_RANGE bufferRanges[1]; + CD3DX12_ROOT_PARAMETER rootParameters[3]; + bufferRanges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, + 0); // testBuffer(u0) + rootParameters[0].InitAsShaderResourceView( + 0, 0, D3D12_SHADER_VISIBILITY_ALL); // accelStruct(t0) + rootParameters[1].InitAsConstantBufferView(0); // 
sceneConstants(b0) + rootParameters[2].InitAsDescriptorTable( + _countof(bufferRanges), bufferRanges, D3D12_SHADER_VISIBILITY_ALL); + + CD3DX12_ROOT_SIGNATURE_DESC rootSignatureDesc; + rootSignatureDesc.Init(_countof(rootParameters), rootParameters, 0, nullptr, + D3D12_ROOT_SIGNATURE_FLAG_NONE); + CComPtr signature; + CComPtr error; + VERIFY_SUCCEEDED(D3D12SerializeRootSignature( + &rootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1, &signature, &error)); + VERIFY_SUCCEEDED(pDevice->CreateRootSignature( + 0, signature->GetBufferPointer(), signature->GetBufferSize(), + IID_PPV_ARGS(&pGlobalRootSignature))); + pGlobalRootSignature->SetName(L"Global Root Signature"); + } + + // Create command queue. + CComPtr pCommandQueue; + CreateCommandQueue(pDevice, L"RunDXRTest Command Queue", &pCommandQueue, + D3D12_COMMAND_LIST_TYPE_DIRECT); + + // Compile raygen shader. + CComPtr pShaderLib; + CompileFromText(shader, L"raygen", pTargetProfile, &pShaderLib, pOptions, + numOptions); + + // Describe and create the RT pipeline state object (RTPSO). + CD3DX12_STATE_OBJECT_DESC stateObjectDesc( + D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE); + auto lib = stateObjectDesc.CreateSubobject(); + CD3DX12_SHADER_BYTECODE byteCode(pShaderLib); + lib->SetDXILLibrary(&byteCode); + lib->DefineExport(L"raygen"); + lib->DefineExport(L"closesthit"); + lib->DefineExport(L"anyhit"); + lib->DefineExport(L"miss"); + if (useIS) { + lib->DefineExport(L"intersection"); + } + + const int maxRecursion = 1; + stateObjectDesc.CreateSubobject() + ->Config(payloadCount * sizeof(float), attributeCount * sizeof(float)); + stateObjectDesc + .CreateSubobject() + ->Config(maxRecursion); + + // Set Global Root Signature subobject. 
+ auto globalRootSigSubObj = + stateObjectDesc + .CreateSubobject(); + globalRootSigSubObj->SetRootSignature(pGlobalRootSignature); + auto exports = stateObjectDesc.CreateSubobject< + CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT>(); + exports->SetSubobjectToAssociate(*globalRootSigSubObj); + exports->AddExport(L"raygen"); + exports->AddExport(L"closesthit"); + exports->AddExport(L"anyhit"); + exports->AddExport(L"miss"); + if (useIS) { + exports->AddExport(L"intersection"); + } + + auto hitGroup = + stateObjectDesc.CreateSubobject(); + hitGroup->SetClosestHitShaderImport(L"closesthit"); + hitGroup->SetAnyHitShaderImport(L"anyhit"); + if (useIS) { + hitGroup->SetIntersectionShaderImport(L"intersection"); + hitGroup->SetHitGroupType(D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE); + } + hitGroup->SetHitGroupExport(L"HitGroup"); + + CComPtr pStateObject; + CComPtr pStateObjectProperties; + VERIFY_SUCCEEDED( + pDevice->CreateStateObject(stateObjectDesc, IID_PPV_ARGS(&pStateObject))); + VERIFY_SUCCEEDED(pStateObject->QueryInterface(&pStateObjectProperties)); + stateObjectDesc.CreateSubobject() + ->SetRootSignature(pLocalRootSignature); + stateObjectDesc.CreateSubobject() + ->SetRootSignature(pGlobalRootSignature); + + // Create SBT + ShaderTable shaderTable; + shaderTable.Init(pDevice, + 1, // raygen count + 1, // miss count + useMesh && useProceduralGeometry ? 
2 : 1, // hit group count + 1, // ray type count + 2 // dwords per root table + ); + + memcpy(shaderTable.GetRaygenShaderIdPtr(0), + pStateObjectProperties->GetShaderIdentifier(L"raygen"), + SHADER_ID_SIZE_IN_BYTES); + memcpy(shaderTable.GetMissShaderIdPtr(0, 0), + pStateObjectProperties->GetShaderIdentifier(L"miss"), + SHADER_ID_SIZE_IN_BYTES); + memcpy(shaderTable.GetHitGroupShaderIdPtr(0, 0), + pStateObjectProperties->GetShaderIdentifier(L"HitGroup"), + SHADER_ID_SIZE_IN_BYTES); + + auto tbl = pDescriptorHeap->GetGPUDescriptorHandleForHeapStart().ptr; + memcpy(shaderTable.GetHitGroupRootTablePtr(0, 0), &tbl, 8); + + // Create a command allocator and list. + CComPtr pCommandAllocator; + CComPtr pCommandList; + VERIFY_SUCCEEDED(pDevice->CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&pCommandAllocator))); + VERIFY_SUCCEEDED(pDevice->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, + pCommandAllocator, nullptr, + IID_PPV_ARGS(&pCommandList))); + pCommandList->SetName(L"ExecutionTest::RunDXRTest Command List"); + + pCommandList->Close(); + ExecuteCommandList(pCommandQueue, pCommandList); + WaitForSignal(pCommandQueue, FO); + + VERIFY_SUCCEEDED(pCommandAllocator->Reset()); + VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, nullptr)); + + // Create scene geometry. 
+ CComPtr tlasResource; + CComPtr blasMeshResource; + CComPtr blasProceduralGeometryResource; + CComPtr instanceDescs; + CComPtr scratchResource; + + if (useMesh) { + CComPtr vertexBuffer; + CComPtr vertexBufferUpload; + CComPtr indexBuffer; + CComPtr indexBufferUpload; + + // Define a Quad + const float verts[] = { + -50.5f, 50.5f, 0.5f, // top left + 50.5f, -50.5f, 0.5f, // bottom right + -50.5f, -50.5f, 0.5f, // bottom left + 50.5f, 50.5f, 0.5f // top right + }; + const int indices[] = { + 0, 1, 2, // first triangle + 0, 3, 1 // second triangle + }; + + const UINT64 vertexDataSize = sizeof(verts); + const UINT64 indexDataSize = sizeof(indices); + + AllocateUploadBuffer(pDevice, verts, vertexDataSize, &vertexBufferUpload, + L"vertexBufferUpload"); + AllocateUploadBuffer(pDevice, indices, indexDataSize, &indexBufferUpload, + L"indexBufferUpload"); + + AllocateBufferFromUpload( + pDevice, pCommandList, vertexBufferUpload, &vertexBuffer, + D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE, L"vertexBuffer"); + AllocateBufferFromUpload( + pDevice, pCommandList, indexBufferUpload, &indexBuffer, + D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE, L"indexBuffer"); + + { + const int descriptorIndex = 1; + D3D12_CPU_DESCRIPTOR_HANDLE cpuDescriptorHandle = + CD3DX12_CPU_DESCRIPTOR_HANDLE( + pDescriptorHeap->GetCPUDescriptorHandleForHeapStart(), + descriptorIndex, descriptorSize); + D3D12_SHADER_RESOURCE_VIEW_DESC srvDesc = {}; + srvDesc.ViewDimension = D3D12_SRV_DIMENSION_BUFFER; + srvDesc.Shader4ComponentMapping = + D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING; + srvDesc.Buffer.NumElements = + UINT(vertexDataSize / sizeof(DirectX::XMFLOAT3)); + srvDesc.Format = DXGI_FORMAT_UNKNOWN; + srvDesc.Buffer.Flags = D3D12_BUFFER_SRV_FLAG_NONE; + srvDesc.Buffer.StructureByteStride = sizeof(DirectX::XMFLOAT3); + pDevice->CreateShaderResourceView(vertexBuffer, &srvDesc, + cpuDescriptorHandle); + } + { + const int descriptorIndex = 2; + D3D12_CPU_DESCRIPTOR_HANDLE cpuDescriptorHandle = + 
CD3DX12_CPU_DESCRIPTOR_HANDLE( + pDescriptorHeap->GetCPUDescriptorHandleForHeapStart(), + descriptorIndex, descriptorSize); + D3D12_SHADER_RESOURCE_VIEW_DESC srvDesc = {}; + srvDesc.ViewDimension = D3D12_SRV_DIMENSION_BUFFER; + srvDesc.Shader4ComponentMapping = + D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING; + srvDesc.Buffer.NumElements = UINT(indexDataSize / sizeof(int)); + srvDesc.Format = DXGI_FORMAT_UNKNOWN; + srvDesc.Buffer.Flags = D3D12_BUFFER_SRV_FLAG_NONE; + srvDesc.Buffer.StructureByteStride = sizeof(int); + pDevice->CreateShaderResourceView(indexBuffer, &srvDesc, + cpuDescriptorHandle); + } + + pCommandList->Close(); + ExecuteCommandList(pCommandQueue, pCommandList); + WaitForSignal(pCommandQueue, FO); + + VERIFY_SUCCEEDED(pCommandAllocator->Reset()); + VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, nullptr)); + + if (!useIS) { + // Build BLAS. + { + D3D12_RAYTRACING_GEOMETRY_DESC geometryDesc = {}; + geometryDesc.Type = D3D12_RAYTRACING_GEOMETRY_TYPE_TRIANGLES; + geometryDesc.Triangles.IndexBuffer = + indexBuffer->GetGPUVirtualAddress(); + geometryDesc.Triangles.IndexCount = + static_cast(indexBuffer->GetDesc().Width) / sizeof(int); + geometryDesc.Triangles.IndexFormat = DXGI_FORMAT_R32_UINT; + geometryDesc.Triangles.Transform3x4 = 0; + geometryDesc.Triangles.VertexFormat = DXGI_FORMAT_R32G32B32_FLOAT; + geometryDesc.Triangles.VertexCount = + static_cast(vertexBuffer->GetDesc().Width) / + sizeof(DirectX::XMFLOAT3); + geometryDesc.Triangles.VertexBuffer.StartAddress = + vertexBuffer->GetGPUVirtualAddress(); + geometryDesc.Triangles.VertexBuffer.StrideInBytes = + sizeof(DirectX::XMFLOAT3); + geometryDesc.Flags = + D3D12_RAYTRACING_GEOMETRY_FLAG_NONE; // Non-opaque to trigger + // anyhit. 
+ + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS buildFlags = + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_TRACE; + + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS accelInputs = {}; + accelInputs.Type = + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL; + accelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; + accelInputs.pGeometryDescs = &geometryDesc; + accelInputs.NumDescs = 1; + accelInputs.Flags = buildFlags; + + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO prebuildInfo = {}; + pDevice->GetRaytracingAccelerationStructurePrebuildInfo(&accelInputs, + &prebuildInfo); + + ReallocScratchResource(pDevice, &scratchResource, + prebuildInfo.ScratchDataSizeInBytes); + AllocateBuffer(pDevice, prebuildInfo.ResultDataMaxSizeInBytes, + &blasMeshResource, true, + D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, + L"blasMesh"); + + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC buildDesc = {}; + buildDesc.Inputs = accelInputs; + buildDesc.ScratchAccelerationStructureData = + scratchResource->GetGPUVirtualAddress(); + buildDesc.DestAccelerationStructureData = + blasMeshResource->GetGPUVirtualAddress(); + + pCommandList->BuildRaytracingAccelerationStructure(&buildDesc, 0, + nullptr); + CD3DX12_RESOURCE_BARRIER barrier = + CD3DX12_RESOURCE_BARRIER::UAV(blasMeshResource); + pCommandList->ResourceBarrier(1, + (const D3D12_RESOURCE_BARRIER *)&barrier); + } + } + + pCommandList->Close(); + ExecuteCommandList(pCommandQueue, pCommandList); + WaitForSignal(pCommandQueue, FO); + + VERIFY_SUCCEEDED(pCommandAllocator->Reset()); + VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, nullptr)); + } + + if (useProceduralGeometry) { + // Define procedural geometry AABB for a plane + CComPtr aabbBuffer; + CComPtr aabbBufferUpload; + + // Define the AABB for the plane, matching the size of the quad defined by + // verts[] + const D3D12_RAYTRACING_AABB aabb = { + -150.5f, -500.5f, -1000.0f, // Min corner (x, y, z) + 150.5f, 
-150.5f, 1000.0f // Max corner (x, y, z) + }; + const UINT64 aabbDataSize = sizeof(aabb); + + // Create an upload buffer for the AABB + AllocateUploadBuffer(pDevice, &aabb, aabbDataSize, &aabbBufferUpload, + L"aabbBufferUpload"); + + // Create a GPU buffer for the AABB + AllocateBufferFromUpload( + pDevice, pCommandList, aabbBufferUpload, &aabbBuffer, + D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE, L"aabbBuffer"); + + // Describe the procedural geometry + D3D12_RAYTRACING_GEOMETRY_DESC procGeometryDesc = {}; + procGeometryDesc.Type = + D3D12_RAYTRACING_GEOMETRY_TYPE_PROCEDURAL_PRIMITIVE_AABBS; + procGeometryDesc.AABBs.AABBs.StartAddress = + aabbBuffer->GetGPUVirtualAddress(); + procGeometryDesc.AABBs.AABBs.StrideInBytes = sizeof(D3D12_RAYTRACING_AABB); + procGeometryDesc.AABBs.AABBCount = 1; + + // Build the BLAS for the procedural geometry + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS blasInputs = {}; + blasInputs.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL; + blasInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; + blasInputs.NumDescs = 1; + blasInputs.pGeometryDescs = &procGeometryDesc; + blasInputs.Flags = + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_TRACE; + + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO prebuildInfo = {}; + pDevice->GetRaytracingAccelerationStructurePrebuildInfo(&blasInputs, + &prebuildInfo); + + // Allocate scratch and result buffers for the BLAS + ReallocScratchResource(pDevice, &scratchResource, + prebuildInfo.ScratchDataSizeInBytes); + AllocateBuffer(pDevice, prebuildInfo.ResultDataMaxSizeInBytes, + &blasProceduralGeometryResource, true, + D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, + L"blasProceduralGeometry"); + + // Build the BLAS + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC blasDesc = {}; + blasDesc.Inputs = blasInputs; + blasDesc.ScratchAccelerationStructureData = + scratchResource->GetGPUVirtualAddress(); + blasDesc.DestAccelerationStructureData = + 
blasProceduralGeometryResource->GetGPUVirtualAddress(); + + pCommandList->BuildRaytracingAccelerationStructure(&blasDesc, 0, nullptr); + + // Add a UAV barrier to ensure the BLAS is built before using it + CD3DX12_RESOURCE_BARRIER barrier = + CD3DX12_RESOURCE_BARRIER::UAV(blasProceduralGeometryResource); + pCommandList->ResourceBarrier(1, &barrier); + + pCommandList->Close(); + ExecuteCommandList(pCommandQueue, pCommandList); + WaitForSignal(pCommandQueue, FO); + + VERIFY_SUCCEEDED(pCommandAllocator->Reset()); + VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, nullptr)); + } + + // Build TLAS. + { + if (useMesh) { + D3D12_RAYTRACING_INSTANCE_DESC instanceDesc = {}; + instanceDesc.Transform[0][0] = instanceDesc.Transform[1][1] = + instanceDesc.Transform[2][2] = 1; + instanceDesc.InstanceMask = 1; + instanceDesc.AccelerationStructure = + blasMeshResource->GetGPUVirtualAddress(); + + AllocateUploadBuffer(pDevice, &instanceDesc, sizeof(instanceDesc), + &instanceDescs, L"instanceDescs"); + + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS buildFlags = + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_BUILD; + + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS accelInputs = {}; + accelInputs.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL; + accelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; + accelInputs.NumDescs = 1; + accelInputs.Flags = buildFlags; + accelInputs.InstanceDescs = instanceDescs->GetGPUVirtualAddress(); + + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO prebuildInfo = {}; + pDevice->GetRaytracingAccelerationStructurePrebuildInfo(&accelInputs, + &prebuildInfo); + + AllocateBuffer( + pDevice, prebuildInfo.ResultDataMaxSizeInBytes, &tlasResource, true, + D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, L"TLAS"); + + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC buildDesc = {}; + buildDesc.Inputs = accelInputs; + buildDesc.ScratchAccelerationStructureData = + 
scratchResource->GetGPUVirtualAddress(); + buildDesc.DestAccelerationStructureData = + tlasResource->GetGPUVirtualAddress(); + + pCommandList->BuildRaytracingAccelerationStructure(&buildDesc, 0, 0); + } else { + D3D12_RAYTRACING_INSTANCE_DESC instanceDesc = {}; + instanceDesc.Transform[0][0] = instanceDesc.Transform[1][1] = + instanceDesc.Transform[2][2] = 1; + instanceDesc.InstanceMask = 1; + instanceDesc.AccelerationStructure = + blasProceduralGeometryResource->GetGPUVirtualAddress(); + + AllocateUploadBuffer(pDevice, &instanceDesc, sizeof(instanceDesc), + &instanceDescs, L"instanceDescs"); + + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS buildFlags = + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_BUILD; + + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS accelInputs = {}; + accelInputs.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL; + accelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; + accelInputs.NumDescs = 1; + accelInputs.Flags = buildFlags; + accelInputs.InstanceDescs = instanceDescs->GetGPUVirtualAddress(); + + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO prebuildInfo = {}; + pDevice->GetRaytracingAccelerationStructurePrebuildInfo(&accelInputs, + &prebuildInfo); + + AllocateBuffer( + pDevice, prebuildInfo.ResultDataMaxSizeInBytes, &tlasResource, true, + D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, L"TLAS"); + + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC buildDesc = {}; + buildDesc.Inputs = accelInputs; + buildDesc.ScratchAccelerationStructureData = + scratchResource->GetGPUVirtualAddress(); + buildDesc.DestAccelerationStructureData = + tlasResource->GetGPUVirtualAddress(); + + pCommandList->BuildRaytracingAccelerationStructure(&buildDesc, 0, 0); + } + + CD3DX12_RESOURCE_BARRIER barrier = + CD3DX12_RESOURCE_BARRIER::UAV(tlasResource); + pCommandList->ResourceBarrier(1, (const D3D12_RESOURCE_BARRIER *)&barrier); + } + + shaderTable.Upload(pCommandList); + + ID3D12DescriptorHeap 
*const pHeaps[1] = {pDescriptorHeap}; + pCommandList->SetDescriptorHeaps(1, pHeaps); + pCommandList->SetComputeRootSignature(pGlobalRootSignature); + pCommandList->SetComputeRootShaderResourceView( + 0, tlasResource->GetGPUVirtualAddress()); + pCommandList->SetComputeRootConstantBufferView( + 1, pSceneConstantBuffer->GetGPUVirtualAddress()); + pCommandList->SetComputeRootDescriptorTable( + 2, pDescriptorHeap->GetGPUDescriptorHandleForHeapStart()); + + D3D12_DISPATCH_RAYS_DESC dispatchDesc = {}; + dispatchDesc.RayGenerationShaderRecord.StartAddress = + shaderTable.GetRaygenStartGpuVA(); + dispatchDesc.RayGenerationShaderRecord.SizeInBytes = + shaderTable.GetRaygenRangeInBytes(); + dispatchDesc.MissShaderTable.StartAddress = shaderTable.GetMissStartGpuVA(); + dispatchDesc.MissShaderTable.SizeInBytes = shaderTable.GetMissRangeInBytes(); + dispatchDesc.MissShaderTable.StrideInBytes = + shaderTable.GetShaderRecordSizeInBytes(); + dispatchDesc.HitGroupTable.StartAddress = shaderTable.GetHitGroupStartGpuVA(); + dispatchDesc.HitGroupTable.SizeInBytes = + shaderTable.GetHitGroupRangeInBytes(); + dispatchDesc.HitGroupTable.StrideInBytes = + shaderTable.GetShaderRecordSizeInBytes(); + dispatchDesc.Width = windowWidth; + dispatchDesc.Height = windowHeight; + dispatchDesc.Depth = 1; + pCommandList->SetPipelineState1(pStateObject); + pCommandList->DispatchRays(&dispatchDesc); + + pCommandList->Close(); + ExecuteCommandList(pCommandQueue, pCommandList); + WaitForSignal(pCommandQueue, FO); + + VERIFY_SUCCEEDED(pCommandAllocator->Reset()); + VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, nullptr)); + + // Copy the testBuffer contents to CPU + D3D12_RESOURCE_BARRIER barriers[1]; + barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition( + pTestBuffer, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COPY_SOURCE); + pCommandList->ResourceBarrier(1, barriers); + pCommandList->CopyResource(pTestBufferRead, pTestBuffer); + barriers[0] = 
CD3DX12_RESOURCE_BARRIER::Transition( + pTestBuffer, D3D12_RESOURCE_STATE_COPY_SOURCE, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS); + pCommandList->ResourceBarrier(1, barriers); + + pCommandList->Close(); + ExecuteCommandList(pCommandQueue, pCommandList); + WaitForSignal(pCommandQueue, FO); + + // Copy the shader test data into 'testData'. + MappedData data(pTestBufferRead, (UINT32)testData.size() * sizeof(int)); + const int *pData = (int *)data.data(); + + for (int i = 0; i < testData.size(); i++) { + testData[i] = *pData++; + } + + // Cleanup resources + pTestBuffer.Release(); + pTestBufferRead.Release(); + pSceneConstantBuffer.Release(); + pDescriptorHeap.Release(); + pCommandQueue.Release(); + pCommandAllocator.Release(); + pCommandList.Release(); + pStateObject.Release(); + pStateObjectProperties.Release(); + tlasResource.Release(); + blasMeshResource.Release(); + blasProceduralGeometryResource.Release(); + instanceDescs.Release(); + scratchResource.Release(); + + return pTestBufferRead; +} + +// SER TESTS +#include "ExecutionTest_SER.h" +// + void ExecutionTest::RunLifetimeIntrinsicComputeTest( ID3D12Device *pDevice, LPCSTR pShader, CComPtr &pUavHeap, diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h new file mode 100644 index 0000000000..a99d55e79c --- /dev/null +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -0,0 +1,1703 @@ +//===--------- ExecutionTest_SER.h - SER Execution Tests -------*- C++ -*-===// +/////////////////////////////////////////////////////////////////////////////// +// // +// ExecutionTest_SER.h // +// Copyright (C) Nvidia Corporation. All rights reserved. // +// This file is distributed under the University of Illinois Open Source // +// License. See LICENSE.TXT for details. // +// // +// This file contains the execution tests for SER. 
//
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+TEST_F(ExecutionTest, SERScalarGetterTest) {
+  // SER: Test basic function of the scalar HitObject getters.
+  //
+  // For every launch index the raygen shader writes two adjacent values to
+  // testBuffer: slot [2*i] is the reference produced by a regular TraceRay
+  // (the miss/closesthit shaders store MISS_GET_SCALAR()/HIT_GET_SCALAR() in
+  // the payload) and slot [2*i + 1] is hitObject.SER_GET_SCALAR() for the same
+  // ray traced through dx::HitObject::TraceRay. Each case below picks the
+  // intrinsics via -D defines and expects both slots to agree.
+  static const char *pShader = R"(
+struct SceneConstants
+{
+    float4 eye;
+    float4 U;
+    float4 V;
+    float4 W;
+    float sceneScale;
+    uint2 windowSize;
+    int rayFlags;
+};
+
+struct[raypayload] PerRayData
+{
+    VALTYPE value : read(anyhit,closesthit,miss,caller) : write(anyhit,miss,closesthit);
+};
+
+struct Attrs
+{
+    float2 barycentrics : BARYCENTRICS;
+};
+
+RWStructuredBuffer<VALTYPE> testBuffer : register(u0);
+RaytracingAccelerationStructure topObject : register(t0);
+ConstantBuffer<SceneConstants> sceneConstants : register(b0);
+
+RayDesc ComputeRay()
+{
+    uint2 launchIndex = DispatchRaysIndex().xy;
+    uint2 launchDim = DispatchRaysDimensions().xy;
+
+    float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f;
+    RayDesc ray;
+    ray.Origin = sceneConstants.eye.xyz;
+    ray.Direction = normalize(d.x*sceneConstants.U.xyz + d.y*sceneConstants.V.xyz + sceneConstants.W.xyz);
+    ray.TMin = 0;
+    ray.TMax = 1e18;
+
+    return ray;
+}
+
+[shader("raygeneration")]
+void raygen()
+{
+    uint2 launchIndex = DispatchRaysIndex().xy;
+    uint2 launchDim = DispatchRaysDimensions().xy;
+    int id = 2 * (launchIndex.x + launchIndex.y * launchDim.x);
+
+    RayDesc ray = ComputeRay();
+
+    // Fetch reference value
+    PerRayData refPayload;
+    TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, refPayload);
+    testBuffer[id] = refPayload.value;
+
+    PerRayData serPayload;
+    dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, serPayload);
+    dx::MaybeReorderThread(hitObject);
+    VALTYPE serVal = hitObject.SER_GET_SCALAR();
+    testBuffer[id + 1] = serVal;
+}
+
+float getFloatZero() { return 0.0f; }
+int getIntZero() { return 0; }
+
+[shader("miss")]
+void miss(inout PerRayData payload)
+{
+    payload.value = MISS_GET_SCALAR();
+}
+
+[shader("anyhit")]
+void anyhit(inout PerRayData payload, in Attrs attrs)
+{
+    // UNUSED
+}
+
+[shader("closesthit")]
+void closesthit(inout PerRayData payload, in Attrs attrs)
+{
+    payload.value = HIT_GET_SCALAR();
+}
+)";
+
+  CComPtr<ID3D12Device> pDevice;
+  if (!CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false)) {
+    WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no "
+                               L"supported device was found.");
+    WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
+    return;
+  }
+
+  // SM 6.9 support is guaranteed past this point, so only DXR support is
+  // left to check.
+  if (!DoesDeviceSupportRayTracing(pDevice)) {
+    WEX::Logging::Log::Comment(
+        L"SER tests skipped, device does not support DXR.");
+    return;
+  }
+
+  // Width and height of the ray dispatch; every case allocates two result
+  // slots (reference, SER) per launch index.
+  const int windowSize = 64;
+
+  WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER");
+
+  // RayTMin
+  {
+    std::vector<int> testData(windowSize * windowSize * 2, 0);
+    LPCWSTR args[] = {L"-HV 2021",
+                      L"-Vd",
+                      L"-DVALTYPE=float",
+                      L"-DHIT_GET_SCALAR=RayTMin",
+                      L"-DMISS_GET_SCALAR=RayTMin",
+                      L"-DSER_GET_SCALAR=GetRayTMin"};
+    RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args),
+               testData, windowSize, windowSize, true /*useMesh*/,
+               false /*useProceduralGeometry*/, false /*useIS*/);
+    for (size_t id = 0; id < testData.size(); id += 2) {
+      // The buffer holds raw float bits for this case (VALTYPE=float).
+      const float *resArray = (const float *)(testData.data() + id);
+      const float refVal = resArray[0];
+      const float serVal = resArray[1];
+      const bool passRayTMin = CompareFloatEpsilon(serVal, refVal, 0.0008f);
+      if (!passRayTMin) {
+        VERIFY_IS_TRUE(passRayTMin);
+        WEX::Logging::Log::Comment(L"HitObject::GetRayTMin() FAILED");
+        return;
+      }
+    }
+    WEX::Logging::Log::Comment(L"HitObject::GetRayTMin() PASSED");
+  }
+
+  // RayTCurrent
+  {
+    std::vector<int> testData(windowSize * windowSize * 2, 0);
+    LPCWSTR args[] = {L"-HV 2021",
+                      L"-Vd",
+                      L"-DVALTYPE=float",
+                      L"-DHIT_GET_SCALAR=RayTCurrent",
+                      L"-DMISS_GET_SCALAR=RayTCurrent",
+                      L"-DSER_GET_SCALAR=GetRayTCurrent"};
+    RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args),
+               testData, windowSize, windowSize, true /*useMesh*/,
+               false /*useProceduralGeometry*/, false /*useIS*/);
+    for (size_t id = 0; id < testData.size(); id += 2) {
+      // The buffer holds raw float bits for this case (VALTYPE=float).
+      const float *resArray = (const float *)(testData.data() + id);
+      const float refVal = resArray[0];
+      const float serVal = resArray[1];
+      const bool passRayTCurrent = CompareFloatEpsilon(serVal, refVal, 0.0008f);
+      if (!passRayTCurrent) {
+        VERIFY_IS_TRUE(passRayTCurrent);
+        WEX::Logging::Log::Comment(L"HitObject::GetRayTCurrent() FAILED");
+        return;
+      }
+    }
+    WEX::Logging::Log::Comment(L"HitObject::GetRayTCurrent() PASSED");
+  }
+
+  // RayFlags
+  {
+    std::vector<int> testData(windowSize * windowSize * 2, 0);
+    LPCWSTR args[] = {L"-HV 2021",
+                      L"-Vd",
+                      L"-DVALTYPE=uint",
+                      L"-DHIT_GET_SCALAR=RayFlags",
+                      L"-DMISS_GET_SCALAR=RayFlags",
+                      L"-DSER_GET_SCALAR=GetRayFlags"};
+    RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args),
+               testData, windowSize, windowSize, true /*useMesh*/,
+               false /*useProceduralGeometry*/, false /*useIS*/);
+    for (size_t id = 0; id < testData.size(); id += 2) {
+      const int refVal = testData[id];
+      const int serVal = testData[id + 1];
+      if (refVal != serVal) {
+        VERIFY_ARE_EQUAL(refVal, serVal);
+        WEX::Logging::Log::Comment(L"HitObject::GetRayFlags() FAILED");
+        return;
+      }
+    }
+    WEX::Logging::Log::Comment(L"HitObject::GetRayFlags() PASSED");
+  }
+
+  // HitKind (the miss shader has no hit to query, so it reports zero)
+  {
+    std::vector<int> testData(windowSize * windowSize * 2, 0);
+    LPCWSTR args[] = {L"-HV 2021",
+                      L"-Vd",
+                      L"-DVALTYPE=uint",
+                      L"-DHIT_GET_SCALAR=HitKind",
+                      L"-DMISS_GET_SCALAR=getIntZero",
+                      L"-DSER_GET_SCALAR=GetHitKind"};
+    RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args),
+               testData, windowSize, windowSize, true /*useMesh*/,
+               false /*useProceduralGeometry*/, false /*useIS*/);
+    for (size_t id = 0; id < testData.size(); id += 2) {
+      const int refVal = testData[id];
+      const int serVal = testData[id + 1];
+      if (refVal != serVal) {
+        VERIFY_ARE_EQUAL(refVal, serVal);
+        WEX::Logging::Log::Comment(L"HitObject::GetHitKind() FAILED");
+        return;
+      }
+    }
+    WEX::Logging::Log::Comment(L"HitObject::GetHitKind() PASSED");
+  }
+
+  // GeometryIndex
+  {
+    std::vector<int> testData(windowSize * windowSize * 2, 0);
+    LPCWSTR args[] = {L"-HV 2021",
+                      L"-Vd",
+                      L"-DVALTYPE=uint",
+                      L"-DHIT_GET_SCALAR=GeometryIndex",
+                      L"-DMISS_GET_SCALAR=getIntZero",
+                      L"-DSER_GET_SCALAR=GetGeometryIndex"};
+    RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args),
+               testData, windowSize, windowSize, true /*useMesh*/,
+               false /*useProceduralGeometry*/, false /*useIS*/);
+    for (size_t id = 0; id < testData.size(); id += 2) {
+      const int refVal = testData[id];
+      const int serVal = testData[id + 1];
+      if (refVal != serVal) {
+        VERIFY_ARE_EQUAL(refVal, serVal);
+        WEX::Logging::Log::Comment(L"HitObject::GetGeometryIndex() FAILED");
+        return;
+      }
+    }
+    WEX::Logging::Log::Comment(L"HitObject::GetGeometryIndex() PASSED");
+  }
+
+  // InstanceIndex
+  {
+    std::vector<int> testData(windowSize * windowSize * 2, 0);
+    LPCWSTR args[] = {L"-HV 2021",
+                      L"-Vd",
+                      L"-DVALTYPE=uint",
+                      L"-DHIT_GET_SCALAR=InstanceIndex",
+                      L"-DMISS_GET_SCALAR=getIntZero",
+                      L"-DSER_GET_SCALAR=GetInstanceIndex"};
+    RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args),
+               testData, windowSize, windowSize, true /*useMesh*/,
+               false /*useProceduralGeometry*/, false /*useIS*/);
+    for (size_t id = 0; id < testData.size(); id += 2) {
+      const int refVal = testData[id];
+      const int serVal = testData[id + 1];
+      if (refVal != serVal) {
+        VERIFY_ARE_EQUAL(refVal, serVal);
+        WEX::Logging::Log::Comment(L"HitObject::GetInstanceIndex() FAILED");
+        return;
+      }
+    }
+    WEX::Logging::Log::Comment(L"HitObject::GetInstanceIndex() PASSED");
+  }
+
+  // InstanceID
+  {
+    std::vector<int> testData(windowSize * windowSize * 2, 0);
+    LPCWSTR args[] = {L"-HV 2021",
+                      L"-Vd",
+                      L"-DVALTYPE=uint",
+                      L"-DHIT_GET_SCALAR=InstanceID",
+                      L"-DMISS_GET_SCALAR=getIntZero",
+                      L"-DSER_GET_SCALAR=GetInstanceID"};
+    RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args),
+               testData, windowSize, windowSize, true /*useMesh*/,
+               false /*useProceduralGeometry*/, false /*useIS*/);
+    for (size_t id = 0; id < testData.size(); id += 2) {
+      const int refVal = testData[id];
+      const int serVal = testData[id + 1];
+      if (refVal != serVal) {
+        VERIFY_ARE_EQUAL(refVal, serVal);
+        WEX::Logging::Log::Comment(L"HitObject::GetInstanceID() FAILED");
+        return;
+      }
+    }
+    WEX::Logging::Log::Comment(L"HitObject::GetInstanceID() PASSED");
+  }
+
+  // PrimitiveIndex
+  {
+    std::vector<int> testData(windowSize * windowSize * 2, 0);
+    LPCWSTR args[] = {L"-HV 2021",
+                      L"-Vd",
+                      L"-DVALTYPE=uint",
+                      L"-DHIT_GET_SCALAR=PrimitiveIndex",
+                      L"-DMISS_GET_SCALAR=getIntZero",
+                      L"-DSER_GET_SCALAR=GetPrimitiveIndex"};
+    RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args),
+               testData, windowSize, windowSize, true /*useMesh*/,
+               false /*useProceduralGeometry*/, false /*useIS*/);
+    for (size_t id = 0; id < testData.size(); id += 2) {
+      const int refVal = testData[id];
+      const int serVal = testData[id + 1];
+      if (refVal != serVal) {
+        VERIFY_ARE_EQUAL(refVal, serVal);
+        WEX::Logging::Log::Comment(L"HitObject::GetPrimitiveIndex() FAILED");
+        return;
+      }
+    }
+    WEX::Logging::Log::Comment(L"HitObject::GetPrimitiveIndex() PASSED");
+  }
+}
+
+TEST_F(ExecutionTest, SERVectorGetterTest) {
+  // SER: Test basic function of HitObject getters.
+  //
+  // For every launch index the raygen shader writes six values to testBuffer,
+  // interleaving reference and SER components: [id+0/2/4] are x/y/z of the
+  // payload filled by a regular TraceRay (miss/closesthit store
+  // MISS_GET_VECTOR()/HIT_GET_VECTOR()), [id+1/3/5] are x/y/z of
+  // hitObject.SER_GET_VECTOR() for the same ray. Each case below picks the
+  // intrinsics via -D defines and expects the pairs to agree within epsilon.
+  static const char *pShader = R"(
+struct SceneConstants
+{
+    float4 eye;
+    float4 U;
+    float4 V;
+    float4 W;
+    float sceneScale;
+    uint2 windowSize;
+    int rayFlags;
+};
+
+struct[raypayload] PerRayData
+{
+    float3 value : read(caller) : write(miss,closesthit);
+};
+
+struct Attrs
+{
+    float2 barycentrics : BARYCENTRICS;
+};
+
+RWStructuredBuffer<float> testBuffer : register(u0);
+RaytracingAccelerationStructure topObject : register(t0);
+ConstantBuffer<SceneConstants> sceneConstants : register(b0);
+
+RayDesc ComputeRay()
+{
+    uint2 launchIndex = DispatchRaysIndex().xy;
+    uint2 launchDim = DispatchRaysDimensions().xy;
+
+    float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f;
+    RayDesc ray;
+    ray.Origin = sceneConstants.eye.xyz;
+    ray.Direction = normalize(d.x*sceneConstants.U.xyz + d.y*sceneConstants.V.xyz + sceneConstants.W.xyz);
+    ray.TMin = 0;
+    ray.TMax = 1e18;
+
+    return ray;
+}
+
+[shader("raygeneration")]
+void raygen()
+{
+    uint2 launchIndex = DispatchRaysIndex().xy;
+    uint2 launchDim = DispatchRaysDimensions().xy;
+    int id = 6 * (launchIndex.x + launchIndex.y * launchDim.x);
+
+    RayDesc ray = ComputeRay();
+
+    // Fetch reference value
+    PerRayData refPayload;
+    TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, refPayload);
+    testBuffer[id] = refPayload.value.x;
+    testBuffer[id + 2] = refPayload.value.y;
+    testBuffer[id + 4] = refPayload.value.z;
+
+    PerRayData serPayload;
+    dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, serPayload);
+    dx::MaybeReorderThread(hitObject);
+    float3 serVal = hitObject.SER_GET_VECTOR();
+    testBuffer[id + 1] = serVal.x;
+    testBuffer[id + 3] = serVal.y;
+    testBuffer[id + 5] = serVal.z;
+}
+
+float3 getVecZero() { return 0.0f; }
+
+[shader("miss")]
+void miss(inout PerRayData payload)
+{
+    payload.value = MISS_GET_VECTOR();
+}
+
+[shader("anyhit")]
+void anyhit(inout PerRayData payload, in Attrs attrs)
+{
+    // UNUSED
+}
+
+[shader("closesthit")]
+void closesthit(inout PerRayData payload, in Attrs attrs)
+{
+    payload.value = HIT_GET_VECTOR();
+}
+)";
+
+  CComPtr<ID3D12Device> pDevice;
+  if (!CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false)) {
+    WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no "
+                               L"supported device was found.");
+    WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
+    return;
+  }
+
+  // SM 6.9 support is guaranteed past this point, so only DXR support is
+  // left to check.
+  if (!DoesDeviceSupportRayTracing(pDevice)) {
+    WEX::Logging::Log::Comment(
+        L"SER tests skipped, device does not support DXR.");
+    return;
+  }
+
+  // Width and height of the ray dispatch; every case allocates six result
+  // slots (three reference/SER pairs) per launch index.
+  const int windowSize = 64;
+
+  WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER");
+
+  // WorldRayOrigin
+  {
+    std::vector<int> testData(windowSize * windowSize * 6, 0);
+    LPCWSTR args[] = {L"-HV 2021", L"-Vd", L"-DHIT_GET_VECTOR=WorldRayOrigin",
+                      L"-DMISS_GET_VECTOR=WorldRayOrigin",
+                      L"-DSER_GET_VECTOR=GetWorldRayOrigin"};
+    RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args),
+               testData, windowSize, windowSize, true /*useMesh*/,
+               false /*useProceduralGeometry*/, false /*useIS*/,
+               3 /*payloadCount*/);
+    for (size_t id = 0; id < testData.size(); id += 6) {
+      const float *resArray = (const float *)(testData.data() + id);
+      const float refX = resArray[0];
+      const float serX = resArray[1];
+      const float refY = resArray[2];
+      const float serY = resArray[3];
+      const float refZ = resArray[4];
+      const float serZ = resArray[5];
+      const bool passX = CompareFloatEpsilon(serX, refX, 0.0008f);
+      const bool passY = CompareFloatEpsilon(serY, refY, 0.0008f);
+      const bool passZ = CompareFloatEpsilon(serZ, refZ, 0.0008f);
+      if (!passX || !passY || !passZ) {
+        // Report only the components that exceeded the tolerance; an
+        // unconditional exact comparison would also flag in-tolerance ones.
+        if (!passX)
+          VERIFY_ARE_EQUAL(serX, refX);
+        if (!passY)
+          VERIFY_ARE_EQUAL(serY, refY);
+        if (!passZ)
+          VERIFY_ARE_EQUAL(serZ, refZ);
+        WEX::Logging::Log::Comment(L"HitObject::GetWorldRayOrigin() FAILED");
+        // Leave the test so "PASSED" is not logged after a failure.
+        return;
+      }
+    }
+    WEX::Logging::Log::Comment(L"HitObject::GetWorldRayOrigin() PASSED");
+  }
+
+  // WorldRayDirection
+  {
+    std::vector<int> testData(windowSize * windowSize * 6, 0);
+    LPCWSTR args[] = {L"-HV 2021", L"-Vd",
+                      L"-DHIT_GET_VECTOR=WorldRayDirection",
+                      L"-DMISS_GET_VECTOR=WorldRayDirection",
+                      L"-DSER_GET_VECTOR=GetWorldRayDirection"};
+    RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args),
+               testData, windowSize, windowSize, true /*useMesh*/,
+               false /*useProceduralGeometry*/, false /*useIS*/,
+               3 /*payloadCount*/);
+    for (size_t id = 0; id < testData.size(); id += 6) {
+      const float *resArray = (const float *)(testData.data() + id);
+      const float refX = resArray[0];
+      const float serX = resArray[1];
+      const float refY = resArray[2];
+      const float serY = resArray[3];
+      const float refZ = resArray[4];
+      const float serZ = resArray[5];
+      const bool passX = CompareFloatEpsilon(serX, refX, 0.0008f);
+      const bool passY = CompareFloatEpsilon(serY, refY, 0.0008f);
+      const bool passZ = CompareFloatEpsilon(serZ, refZ, 0.0008f);
+      if (!passX || !passY || !passZ) {
+        if (!passX)
+          VERIFY_ARE_EQUAL(serX, refX);
+        if (!passY)
+          VERIFY_ARE_EQUAL(serY, refY);
+        if (!passZ)
+          VERIFY_ARE_EQUAL(serZ, refZ);
+        WEX::Logging::Log::Comment(L"HitObject::GetWorldRayDirection() FAILED");
+        return;
+      }
+    }
+    WEX::Logging::Log::Comment(L"HitObject::GetWorldRayDirection() PASSED");
+  }
+
+  // ObjectRayOrigin (the miss shader supplies WorldRayOrigin as reference)
+  {
+    std::vector<int> testData(windowSize * windowSize * 6, 0);
+    LPCWSTR args[] = {L"-HV 2021", L"-Vd", L"-DHIT_GET_VECTOR=ObjectRayOrigin",
+                      L"-DMISS_GET_VECTOR=WorldRayOrigin",
+                      L"-DSER_GET_VECTOR=GetObjectRayOrigin"};
+    RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args),
+               testData, windowSize, windowSize, true /*useMesh*/,
+               false /*useProceduralGeometry*/, false /*useIS*/,
+               3 /*payloadCount*/);
+    for (size_t id = 0; id < testData.size(); id += 6) {
+      const float *resArray = (const float *)(testData.data() + id);
+      const float refX = resArray[0];
+      const float serX = resArray[1];
+      const float refY = resArray[2];
+      const float serY = resArray[3];
+      const float refZ = resArray[4];
+      const float serZ = resArray[5];
+      const bool passX = CompareFloatEpsilon(serX, refX, 0.0008f);
+      const bool passY = CompareFloatEpsilon(serY, refY, 0.0008f);
+      const bool passZ = CompareFloatEpsilon(serZ, refZ, 0.0008f);
+      if (!passX || !passY || !passZ) {
+        if (!passX)
+          VERIFY_ARE_EQUAL(serX, refX);
+        if (!passY)
+          VERIFY_ARE_EQUAL(serY, refY);
+        if (!passZ)
+          VERIFY_ARE_EQUAL(serZ, refZ);
+        WEX::Logging::Log::Comment(L"HitObject::GetObjectRayOrigin() FAILED");
+        // Leave the test so "PASSED" is not logged after a failure.
+        return;
+      }
+    }
+    WEX::Logging::Log::Comment(L"HitObject::GetObjectRayOrigin() PASSED");
+  }
+
+  // ObjectRayDirection (the miss shader supplies WorldRayDirection as
+  // reference)
+  {
+    std::vector<int> testData(windowSize * windowSize * 6, 0);
+    LPCWSTR args[] = {L"-HV 2021", L"-Vd",
+                      L"-DHIT_GET_VECTOR=ObjectRayDirection",
+                      L"-DMISS_GET_VECTOR=WorldRayDirection",
+                      L"-DSER_GET_VECTOR=GetObjectRayDirection"};
+    RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args),
+               testData, windowSize, windowSize, true /*useMesh*/,
+               false /*useProceduralGeometry*/, false /*useIS*/,
+               3 /*payloadCount*/);
+    for (size_t id = 0; id < testData.size(); id += 6) {
+      const float *resArray = (const float *)(testData.data() + id);
+      const float refX = resArray[0];
+      const float serX = resArray[1];
+      const float refY = resArray[2];
+      const float serY = resArray[3];
+      const float refZ = resArray[4];
+      const float serZ = resArray[5];
+      const bool passX = CompareFloatEpsilon(serX, refX, 0.0008f);
+      const bool passY = CompareFloatEpsilon(serY, refY, 0.0008f);
+      const bool passZ = CompareFloatEpsilon(serZ, refZ, 0.0008f);
+      if (!passX || !passY || !passZ) {
+        if (!passX)
+          VERIFY_ARE_EQUAL(serX, refX);
+        if (!passY)
+          VERIFY_ARE_EQUAL(serY, refY);
+        if (!passZ)
+          VERIFY_ARE_EQUAL(serZ, refZ);
+        WEX::Logging::Log::Comment(
+            L"HitObject::GetObjectRayDirection() FAILED");
+        // Leave the test so "PASSED" is not logged after a failure.
+        return;
+      }
+    }
+    WEX::Logging::Log::Comment(L"HitObject::GetObjectRayDirection() PASSED");
+  }
+}
+
+TEST_F(ExecutionTest, SERMatrixGetterTest) {
+  // SER: Test basic function of HitObject getters.
+ static const char *pShader = R"( +struct SceneConstants +{ + float4 eye; + float4 U; + float4 V; + float4 W; + float sceneScale; + uint2 windowSize; + int rayFlags; +}; + +struct[raypayload] PerRayData +{ + matrix value : read(caller) : write(miss,closesthit); +}; + +struct Attrs +{ + float2 barycentrics : BARYCENTRICS; +}; + +RWStructuredBuffer testBuffer : register(u0); +RaytracingAccelerationStructure topObject : register(t0); +ConstantBuffer sceneConstants : register(b0); + +RayDesc ComputeRay() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; + RayDesc ray; + ray.Origin = sceneConstants.eye.xyz; + ray.Direction = normalize(d.x*sceneConstants.U.xyz + d.y*sceneConstants.V.xyz + sceneConstants.W.xyz); + ray.TMin = 0; + ray.TMax = 1e18; + + return ray; +} + +[shader("raygeneration")] +void raygen() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + int id = 2 * ROWS * COLS * (launchIndex.x + launchIndex.y * launchDim.x); + + RayDesc ray = ComputeRay(); + + // Fetch reference value + PerRayData refPayload; + TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, refPayload); + for (int r = 0; r < ROWS; r++) { + for (int c = 0; c < COLS; c++) { + testBuffer[id + 2 * (r * COLS + c)] = refPayload.value[r][c]; + } + } + + PerRayData serPayload; + dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, serPayload); + dx::MaybeReorderThread(hitObject); + matrix serVal = hitObject.SER_GET_MATRIX(); + for (int r = 0; r < ROWS; r++) { + for (int c = 0; c < COLS; c++) { + testBuffer[1 + id + 2 * (r * COLS + c)] = serVal[r][c]; + } + } +} + +matrix getMatIdentity() { + matrix mat = 0; + mat[0][0] = 1.f; + mat[1][1] = 1.f; + mat[2][2] = 1.f; + return mat; +} + +[shader("miss")] +void miss(inout PerRayData payload) +{ + payload.value = 
MISS_GET_MATRIX(); +} + +[shader("anyhit")] +void anyhit(inout PerRayData payload, in Attrs attrs) +{ + // UNUSED +} + +[shader("closesthit")] +void closesthit(inout PerRayData payload, in Attrs attrs) +{ + payload.value = HIT_GET_MATRIX(); +} +)"; + + CComPtr pDevice; + bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " + L"supported device was found."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + bool bDXRSupported = + bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); + + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support SM 6.9."); + } + if (!bDXRSupported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support DXR."); + } + + // Initialize test data. + const int windowSize = 64; + + if (!bDXRSupported) + return; + + WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); + + // WorldToObject3x4 + { + std::vector testData(windowSize * windowSize * 24, 0); + LPCWSTR args[] = {L"-HV 2021", + L"-Vd", + L"-DHIT_GET_MATRIX=WorldToObject3x4", + L"-DMISS_GET_MATRIX=getMatIdentity", + L"-DSER_GET_MATRIX=GetWorldToObject3x4", + L"-DROWS=3", + L"-DCOLS=4"}; + RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), + testData, windowSize, windowSize, true /*useMesh*/, + false /*useProceduralGeometry*/, false /*useIS*/, + 12 /*payloadCount*/); + const int ROWS = 3; + const int COLS = 4; + for (int id = 0; id < testData.size(); id += 24) { + float *resArray = (float *)(testData.data() + id); + for (int r = 0; r < ROWS; r++) { + for (int c = 0; c < COLS; c++) { + int refIdx = 2 * (r * COLS + c); + float ref = resArray[refIdx]; + float ser = resArray[1 + refIdx]; + if (!CompareFloatEpsilon(ser, ref, 0.0008f)) { + VERIFY_ARE_EQUAL(ser, ref); + } + } + } + } + 
WEX::Logging::Log::Comment(L"HitObject::GetWorldToObject3x4() PASSED"); + } + + // WorldToObject4x3 + { + const int ROWS = 4; + const int COLS = 3; + std::vector testData(windowSize * windowSize * 2 * ROWS * COLS, 0); + LPCWSTR args[] = {L"-HV 2021", + L"-Vd", + L"-DHIT_GET_MATRIX=WorldToObject4x3", + L"-DMISS_GET_MATRIX=getMatIdentity", + L"-DSER_GET_MATRIX=GetWorldToObject4x3", + L"-DROWS=4", + L"-DCOLS=3"}; + RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), + testData, windowSize, windowSize, true /*useMesh*/, + false /*useProceduralGeometry*/, false /*useIS*/, + 12 /*payloadCount*/); + for (int id = 0; id < testData.size(); id += 2 * ROWS * COLS) { + float *resArray = (float *)(testData.data() + id); + for (int r = 0; r < ROWS; r++) { + for (int c = 0; c < COLS; c++) { + int refIdx = 2 * (r * COLS + c); + float ref = resArray[refIdx]; + float ser = resArray[1 + refIdx]; + if (!CompareFloatEpsilon(ser, ref, 0.0008f)) { + VERIFY_ARE_EQUAL(ser, ref); + } + } + } + } + WEX::Logging::Log::Comment(L"HitObject::GetWorldToObject4x3() PASSED"); + } + + // ObjectToWorld3x4 + { + std::vector testData(windowSize * windowSize * 24, 0); + LPCWSTR args[] = {L"-HV 2021", + L"-Vd", + L"-DHIT_GET_MATRIX=ObjectToWorld3x4", + L"-DMISS_GET_MATRIX=getMatIdentity", + L"-DSER_GET_MATRIX=GetObjectToWorld3x4", + L"-DROWS=3", + L"-DCOLS=4"}; + RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), + testData, windowSize, windowSize, true /*useMesh*/, + false /*useProceduralGeometry*/, false /*useIS*/, + 12 /*payloadCount*/); + const int ROWS = 3; + const int COLS = 4; + for (int id = 0; id < testData.size(); id += 24) { + float *resArray = (float *)(testData.data() + id); + for (int r = 0; r < ROWS; r++) { + for (int c = 0; c < COLS; c++) { + int refIdx = 2 * (r * COLS + c); + float ref = resArray[refIdx]; + float ser = resArray[1 + refIdx]; + if (!CompareFloatEpsilon(ser, ref, 0.0008f)) { + VERIFY_ARE_EQUAL(ser, ref); + } + } + } + } + 
WEX::Logging::Log::Comment(L"HitObject::GetObjectToWorld3x4() PASSED"); + } + + // ObjectToWorld4x3 + { + std::vector testData(windowSize * windowSize * 24, 0); + LPCWSTR args[] = {L"-HV 2021", + L"-Vd", + L"-DHIT_GET_MATRIX=ObjectToWorld4x3", + L"-DMISS_GET_MATRIX=getMatIdentity", + L"-DSER_GET_MATRIX=GetObjectToWorld4x3", + L"-DROWS=4", + L"-DCOLS=3"}; + RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), + testData, windowSize, windowSize, true /*useMesh*/, + false /*useProceduralGeometry*/, false /*useIS*/, + 12 /*payloadCount*/); + const int ROWS = 4; + const int COLS = 3; + for (int id = 0; id < testData.size(); id += 24) { + float *resArray = (float *)(testData.data() + id); + for (int r = 0; r < ROWS; r++) { + for (int c = 0; c < COLS; c++) { + int refIdx = 2 * (r * COLS + c); + float ref = resArray[refIdx]; + float ser = resArray[1 + refIdx]; + if (!CompareFloatEpsilon(ser, ref, 0.0008f)) { + VERIFY_ARE_EQUAL(ser, ref); + WEX::Logging::Log::Comment( + L"HitObject::GetObjectToWorld4x3() FAILED"); + break; + } + } + } + } + WEX::Logging::Log::Comment(L"HitObject::GetObjectToWorld4x3() PASSED"); + } +} + +TEST_F(ExecutionTest, SERBasicTest) { + // SER: Test basic functionality. 
+ static const char *pShader = R"( +struct SceneConstants +{ + float4 eye; + float4 U; + float4 V; + float4 W; + float sceneScale; + uint2 windowSize; + int rayFlags; +}; + +struct[raypayload] PerRayData +{ + uint visited : read(anyhit,closesthit,miss,caller) : write(anyhit,miss,closesthit,caller); +}; + +struct Attrs +{ + float2 barycentrics : BARYCENTRICS; +}; + +RWStructuredBuffer testBuffer : register(u0); +RaytracingAccelerationStructure topObject : register(t0); +ConstantBuffer sceneConstants : register(b0); + +RayDesc ComputeRay() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; + RayDesc ray; + ray.Origin = sceneConstants.eye.xyz; + ray.Direction = normalize(d.x*sceneConstants.U.xyz + d.y*sceneConstants.V.xyz + sceneConstants.W.xyz); + ray.TMin = 0; + ray.TMax = 1e18; + + return ray; +} + +[shader("raygeneration")] +void raygen() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + RayDesc ray = ComputeRay(); + + PerRayData payload; + payload.visited = 0; + + // SER Test + dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); + dx::MaybeReorderThread(hitObject); + dx::HitObject::Invoke(hitObject, payload); + + int id = launchIndex.x + launchIndex.y * launchDim.x; + testBuffer[id] = payload.visited; +} + +[shader("miss")] +void miss(inout PerRayData payload) +{ + payload.visited |= 2U; +} + +[shader("anyhit")] +void anyhit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 1U; +} + +[shader("closesthit")] +void closesthit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 4U; +} + +)"; + + CComPtr pDevice; + bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no 
" + L"supported device was found."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + bool bDXRSupported = + bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); + + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support SM 6.9."); + } + if (!bDXRSupported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support DXR."); + } + + // Initialize test data. + const int windowSize = 64; + std::vector testData(windowSize * windowSize, 0); + LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; + + if (bDXRSupported) { + WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); + RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), + testData, windowSize, windowSize, true /*useMesh*/, + false /*useProceduralGeometry*/, false /*useIS*/); + std::map histo; + for (int val : testData) { + ++histo[val]; + } + VERIFY_ARE_EQUAL(histo.size(), 2); + VERIFY_ARE_EQUAL(histo[2], 4030); + VERIFY_ARE_EQUAL(histo[5], 66); + } +} + +TEST_F(ExecutionTest, SERRayQueryTest) { + // Test SER RayQuery + static const char *pShader = R"( +struct SceneConstants +{ + float4 eye; + float4 U; + float4 V; + float4 W; + float sceneScale; + uint2 windowSize; + int rayFlags; +}; + +struct[raypayload] PerRayData +{ + uint visited : read(anyhit,closesthit,miss,caller) : write(anyhit,miss,closesthit,caller); +}; + +struct Attrs +{ + float2 barycentrics : BARYCENTRICS; +}; + +RWStructuredBuffer testBuffer : register(u0); +RaytracingAccelerationStructure topObject : register(t0); +ConstantBuffer sceneConstants : register(b0); + +RayDesc ComputeRay() +{ + float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; + RayDesc ray; + ray.Origin = sceneConstants.eye.xyz; + ray.Direction = normalize(d.x*sceneConstants.U.xyz + d.y*sceneConstants.V.xyz + sceneConstants.W.xyz); + ray.TMin = 0; + ray.TMax = 1e18; + + return ray; +} + +[shader("raygeneration")] +void 
raygen() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + RayDesc ray = ComputeRay(); + + PerRayData payload; + payload.visited = 0; + + // Template parameter set at runtime before compilation + RayQuery rayQ; + + // Funtion parameter set at runtime before compilation + rayQ.TraceRayInline(topObject, RAY_FLAG_NONE, 0xFF, ray); + + // Storage for procedural primitive hit attributes + Attrs attrs; + attrs.barycentrics = float2(1, 1); + + while (rayQ.Proceed()) + { + switch (rayQ.CandidateType()) + { + + case CANDIDATE_NON_OPAQUE_TRIANGLE: + { + // The system has already determined that the candidate would be the closest + // hit so far in the ray extents + rayQ.CommitNonOpaqueTriangleHit(); + } + } + } + +#if 0 + switch (rayQ.CommittedStatus()) + { + case COMMITTED_TRIANGLE_HIT: + { + if (rayQ.CommittedTriangleFrontFace()) + { + // Hit + payload.visited |= 4U; + } + break; + } + case COMMITTED_PROCEDURAL_PRIMITIVE_HIT: + { + // Unused + break; + } + case COMMITTED_NOTHING: + { + // Miss + payload.visited |= 2U; + break; + } + } +#else + dx::HitObject hit; + if (rayQ.CommittedStatus() == COMMITTED_NOTHING) + { + hit = dx::HitObject::MakeMiss(RAY_FLAG_NONE, 0, ray); + } + else + { + hit = dx::HitObject::FromRayQuery(rayQ); + } + dx::MaybeReorderThread(hit); + dx::HitObject::Invoke(hit, payload); +#endif + + int id = launchIndex.x + launchIndex.y * launchDim.x; + testBuffer[id] = payload.visited; +} + +[shader("miss")] +void miss(inout PerRayData payload) +{ + payload.visited |= 2U; +} + +[shader("anyhit")] +void anyhit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 1U; + AcceptHitAndEndSearch(); +} + +[shader("closesthit")] +void closesthit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 4U; +} + +)"; + + CComPtr pDevice; + bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment(L"SERRayQueryTest 
requires shader model 6.9+ " + L"but no supported device was found."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + bool bDXRSupported = + bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); + + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment( + L"SERRayQueryTest skipped, device does not support SM 6.9."); + } + if (!bDXRSupported) { + WEX::Logging::Log::Comment( + L"SERRayQueryTest skipped, device does not support DXR."); + } + + // Initialize test data. + const int windowSize = 64; + std::vector testData(windowSize * windowSize, 0); + LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; + + if (bDXRSupported) { + WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); + RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), + testData, windowSize, windowSize, true /*useMesh*/, + false /*useProceduralGeometry*/, false /*useIS*/); + std::map histo; + for (int val : testData) { + ++histo[val]; + } + VERIFY_ARE_EQUAL(histo.size(), 2); + VERIFY_ARE_EQUAL(histo[0], 66); + VERIFY_ARE_EQUAL(histo[2], 4030); + } +} + +TEST_F(ExecutionTest, SERIntersectionTest) { + // Test SER with Intersection and procedural geometry + static const char *pShader = R"( +struct SceneConstants +{ + float4 eye; + float4 U; + float4 V; + float4 W; + float sceneScale; + uint2 windowSize; + int rayFlags; +}; + +struct Attrs +{ + float2 barycentrics : BARYCENTRICS; +}; + +struct[raypayload] PerRayData +{ + uint visited : read(anyhit, closesthit, miss, caller) : write(anyhit, miss, closesthit, caller); +}; + +RWStructuredBuffer testBuffer : register(u0); +RaytracingAccelerationStructure topObject : register(t0); +ConstantBuffer sceneConstants : register(b0); + +RayDesc ComputeRay() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; + RayDesc ray; + ray.Origin = sceneConstants.eye.xyz; + 
ray.Direction = normalize(d.x * sceneConstants.U.xyz + d.y * sceneConstants.V.xyz + sceneConstants.W.xyz); + ray.TMin = 0; + ray.TMax = 1e18; + + return ray; +} + +[shader("raygeneration")] +void raygen() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + RayDesc ray = ComputeRay(); + + PerRayData payload; + payload.visited = 0; + +#if 0 + dx::HitObject hitObject; + TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); +#else + dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); + dx::MaybeReorderThread(hitObject); + dx::HitObject::Invoke(hitObject, payload); +#endif + + int id = launchIndex.x + launchIndex.y * launchDim.x; + testBuffer[id] = payload.visited; +} + +[shader("miss")] +void miss(inout PerRayData payload) +{ + payload.visited |= 2U; +} + +[shader("anyhit")] +void anyhit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 1U; + AcceptHitAndEndSearch(); +} + +[shader("closesthit")] +void closesthit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 4U; +} + +[shader("intersection")] +void intersection() +{ + Attrs attrs; + + ReportHit(0.1, 0, attrs); +} + +)"; + + CComPtr pDevice; + bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " + L"supported device was found."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + bool bDXRSupported = + bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); + + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support SM 6.9."); + } + if (!bDXRSupported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support DXR."); + } + + // Initialize test data. 
+ const int windowSize = 64; + std::vector testData(windowSize * windowSize, 0); + LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; + + if (bDXRSupported) { + WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); + RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), + testData, windowSize, windowSize, false /*mesh*/, + true /*procedural geometry*/, true /*useIS*/); + std::map histo; + for (int val : testData) { + ++histo[val]; + } + VERIFY_ARE_EQUAL(histo.size(), 2); + VERIFY_ARE_EQUAL(histo[2], 3400); + VERIFY_ARE_EQUAL(histo[5], 696); + } +} + +TEST_F(ExecutionTest, SERGetAttributesTest) { + // Test SER with HitObject::GetAttributes + static const char *pShader = R"( +struct SceneConstants +{ + float4 eye; + float4 U; + float4 V; + float4 W; + float sceneScale; + uint2 windowSize; + int rayFlags; +}; + +struct CustomAttrs +{ + float dist; +}; + +struct[raypayload] PerRayData +{ + uint visited : read(anyhit, closesthit, miss, caller) : write(anyhit, miss, closesthit, caller); +}; + +// reordercoherent // Requires #7250 +RWStructuredBuffer testBuffer : register(u0); +RaytracingAccelerationStructure topObject : register(t0); +ConstantBuffer sceneConstants : register(b0); + +RayDesc ComputeRay() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; + RayDesc ray; + ray.Origin = sceneConstants.eye.xyz; + ray.Direction = normalize(d.x * sceneConstants.U.xyz + d.y * sceneConstants.V.xyz + sceneConstants.W.xyz); + ray.TMin = 0; + ray.TMax = 1e18; + + return ray; +} + +[shader("raygeneration")] +void raygen() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + RayDesc ray = ComputeRay(); + + PerRayData payload; + payload.visited = 0; + + dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); + 
dx::MaybeReorderThread(hitObject); + + // Check Attributes for hit detection. + CustomAttrs customAttrs = hitObject.GetAttributes(); + bool isHit = hitObject.IsHit(); + + int testVal = 0; + if (isHit) { + if (int(floor(customAttrs.dist)) % 2 == 0) + testVal = hitObject.GetHitKind(); + } + else + { + // Use 255 to keep outside the HitKind range [0, 127] we passthru for hits. + testVal = 255; + } + int id = launchIndex.x + launchIndex.y * launchDim.x; + testBuffer[id] = testVal; +} + +[shader("miss")] +void miss(inout PerRayData payload) +{ + // UNUSED +} + +[shader("anyhit")] +void anyhit(inout PerRayData payload, in CustomAttrs attrs) +{ + AcceptHitAndEndSearch(); +} + +[shader("closesthit")] +void closesthit(inout PerRayData payload, in CustomAttrs attrs) +{ + // UNUSED +} + +[shader("intersection")] +void intersection() +{ + // hitPos is intersection point with plane (base, n) + float3 base = {0.0f,0.0f,0.5f}; + float3 n = normalize(float3(0.0f,0.5f,0.5f)); + float t = dot(n, base - ObjectRayOrigin()) / dot(n, ObjectRayDirection()); + if (t > RayTCurrent() || t < RayTMin()) { + return; + } + float3 hitPos = ObjectRayOrigin() + t * ObjectRayDirection(); + float3 relHitPos = hitPos - base; + // Encode some hit information in hitKind + int hitKind = 0; + if (relHitPos.y >= 0.0f) + hitKind = 1; + hitKind *= 2; + if (relHitPos.x >= 0.0f) + hitKind += 1; + hitKind *= 2; + if (relHitPos.z >= 0.0f) + hitKind += 1; + + CustomAttrs attrs; + attrs.dist = length(relHitPos); + ReportHit(t, hitKind, attrs); +} + +)"; + + CComPtr pDevice; + bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " + L"supported device was found."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + bool bDXRSupported = + bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); + + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment( + 
L"SER tests skipped, device does not support SM 6.9."); + } + if (!bDXRSupported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support DXR."); + } + + // Initialize test data. + const int windowSize = 64; + std::vector testData(windowSize * windowSize, 0); + LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; + + if (bDXRSupported) { + WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); + RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), + testData, windowSize, windowSize, false /*mesh*/, + true /*procedural geometry*/, true /*useIS*/); + std::map histo; + for (int val : testData) { + ++histo[val]; + } + VERIFY_ARE_EQUAL(histo.size(), 4); + VERIFY_ARE_EQUAL(histo[0], 328); + VERIFY_ARE_EQUAL(histo[1], 186); + VERIFY_ARE_EQUAL(histo[3], 182); + VERIFY_ARE_EQUAL(histo[255], 3400); + } +} + +TEST_F(ExecutionTest, SERTraceHitMissNopTest) { + // Test SER with conditional HitObject::TraceRay, HitObject::IsHit, + // HitObject::IsMiss, HitObject::IsNop + static const char *pShader = R"( +struct SceneConstants +{ + float4 eye; + float4 U; + float4 V; + float4 W; + float sceneScale; + uint2 windowSize; + int rayFlags; +}; + +struct[raypayload] PerRayData +{ + uint visited : read(anyhit,closesthit,miss,caller) : write(anyhit,miss,closesthit,caller); +}; + +struct Attrs +{ + float2 barycentrics : BARYCENTRICS; +}; + +RWStructuredBuffer testBuffer : register(u0); +RaytracingAccelerationStructure topObject : register(t0); +ConstantBuffer sceneConstants : register(b0); + +RayDesc ComputeRay() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; + RayDesc ray; + ray.Origin = sceneConstants.eye.xyz; + ray.Direction = normalize(d.x*sceneConstants.U.xyz + d.y*sceneConstants.V.xyz + sceneConstants.W.xyz); + ray.TMin = 0; + ray.TMax = 1e18; + + return ray; +} + +[shader("raygeneration")] +void raygen() 
+{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + RayDesc ray = ComputeRay(); + + PerRayData payload; + payload.visited = 0; + + // SER Test + dx::HitObject hitObject; + if (launchIndex.x % 2 == 0) { + hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); + } + dx::MaybeReorderThread(hitObject); + + // Check hitObject for hit detection. + if (hitObject.IsHit()) + payload.visited |= 4U; + if (hitObject.IsMiss()) + payload.visited |= 2U; + if (hitObject.IsNop()) + payload.visited |= 1U; + + dx::HitObject::Invoke(hitObject, payload); + + int id = launchIndex.x + launchIndex.y * launchDim.x; + testBuffer[id] = payload.visited; +} + +[shader("miss")] +void miss(inout PerRayData payload) +{ + payload.visited |= 16U; +} + +[shader("anyhit")] +void anyhit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 8U; +} + +[shader("closesthit")] +void closesthit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 32U; +} + +)"; + + CComPtr pDevice; + bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " + L"supported device was found."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + bool bDXRSupported = + bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); + + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support SM 6.9."); + } + if (!bDXRSupported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support DXR."); + } + + // Initialize test data. 
+ const int windowSize = 64; + std::vector testData(windowSize * windowSize, 0); + LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; + + if (bDXRSupported) { + WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); + RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), + testData, windowSize, windowSize, true /*mesh*/, + false /*procedural geometry*/, false /*useIS*/); + std::map histo; + for (int val : testData) { + ++histo[val]; + } + VERIFY_ARE_EQUAL(histo.size(), 3); + VERIFY_ARE_EQUAL( + histo[1], + 2048); // isNop && !isMiss && !isHit && !anyhit && !closesthit && !miss + VERIFY_ARE_EQUAL( + histo[18], + 2015); // !isNop && isMiss && !isHit && !anyhit && !closesthit && miss + VERIFY_ARE_EQUAL( + histo[44], + 33); // !isNop && !isMiss && isHit && anyhit && closesthit && !miss + } +} + +TEST_F(ExecutionTest, SERIsMissTest) { + // Test SER with HitObject::IsMiss + static const char *pShader = R"( +struct SceneConstants +{ + float4 eye; + float4 U; + float4 V; + float4 W; + float sceneScale; + uint2 windowSize; + int rayFlags; +}; + +struct[raypayload] PerRayData +{ + uint visited : read(anyhit,closesthit,miss,caller) : write(anyhit,miss,closesthit,caller); +}; + +struct Attrs +{ + float2 barycentrics : BARYCENTRICS; +}; + +RWStructuredBuffer testBuffer : register(u0); +RaytracingAccelerationStructure topObject : register(t0); +ConstantBuffer sceneConstants : register(b0); + +RayDesc ComputeRay() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; + RayDesc ray; + ray.Origin = sceneConstants.eye.xyz; + ray.Direction = normalize(d.x*sceneConstants.U.xyz + d.y*sceneConstants.V.xyz + sceneConstants.W.xyz); + ray.TMin = 0; + ray.TMax = 1e18; + + return ray; +} + +[shader("raygeneration")] +void raygen() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + RayDesc 
ray = ComputeRay(); + + PerRayData payload; + payload.visited = 0; + + // SER Test + dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); + dx::MaybeReorderThread(hitObject); + dx::HitObject::Invoke(hitObject, payload); + + // Check hitObject for hit detection. + if (hitObject.IsMiss()) + payload.visited |= 2U; + + int id = launchIndex.x + launchIndex.y * launchDim.x; + testBuffer[id] = payload.visited; +} + +[shader("miss")] +void miss(inout PerRayData payload) +{ + // UNUSED +} + +[shader("anyhit")] +void anyhit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 1U; +} + +[shader("closesthit")] +void closesthit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 4U; +} + +)"; + + CComPtr pDevice; + bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " + L"supported device was found."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + bool bDXRSupported = + bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); + + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support SM 6.9."); + } + if (!bDXRSupported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support DXR."); + } + + // Initialize test data. 
+ const int windowSize = 64; + std::vector testData(windowSize * windowSize, 0); + LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; + + if (bDXRSupported) { + WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); + RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), + testData, windowSize, windowSize, true /*mesh*/, + false /*procedural geometry*/, false /*useIS*/); + std::map histo; + for (int val : testData) { + ++histo[val]; + } + VERIFY_ARE_EQUAL(histo.size(), 2); + VERIFY_ARE_EQUAL(histo[2], 4030); + VERIFY_ARE_EQUAL(histo[5], 66); + } +} From e6179100119462e1b4c38873e270668ac6dc5b42 Mon Sep 17 00:00:00 2001 From: Joshua Batista Date: Mon, 28 Apr 2025 09:50:11 -0700 Subject: [PATCH 13/31] [CoopVec] Add Linear Algebra common header with tests (#7350) This PR introduces the linear algebra header file, and places it in a location that is by default included in all HLSL compilation. The builtins in the API aren't yet defined, and depend on the #7290 PR merging first. The tests that have been added have temporary diagnostic messages while 7290 is in progress. They will need to be updated. Open to feedback on better / suggested error messages, or whether there shouldn't be any sema-level validation for these errors. 
Fixes [#7304](https://github.com/microsoft/DirectXShaderCompiler/issues/7304) --------- Co-authored-by: github-actions[bot] --- tools/clang/lib/Headers/hlsl/dx/linalg.h | 182 ++++++++++++++++++ .../CodeGenDXIL/hlsl/linalg/mat-vec-mul.hlsl | 40 ++++ .../hlsl/linalg/mat-vec-muladd.hlsl | 90 +++++++++ .../hlsl/linalg/outerproductaccumulate.hlsl | 16 ++ .../hlsl/linalg/vectoraccumulate.hlsl | 14 ++ .../hlsl/linalg/make-interp-vec-errors.hlsl | 33 ++++ .../hlsl/linalg/mat-vec-mul-errors.hlsl | 16 ++ .../linalg/mat-vec-mul-transpose-errors.hlsl | 30 +++ .../hlsl/linalg/mat-vec-muladd-errors.hlsl | 16 ++ .../linalg/outerproductaccumulate-errors.hlsl | 44 +++++ .../outerproductaccumulate-spirv-errors.hlsl | 19 ++ .../hlsl/linalg/vectoraccumulate-errors.hlsl | 16 ++ 12 files changed, 516 insertions(+) create mode 100644 tools/clang/lib/Headers/hlsl/dx/linalg.h create mode 100644 tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-mul.hlsl create mode 100644 tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-muladd.hlsl create mode 100644 tools/clang/test/CodeGenDXIL/hlsl/linalg/outerproductaccumulate.hlsl create mode 100644 tools/clang/test/CodeGenDXIL/hlsl/linalg/vectoraccumulate.hlsl create mode 100644 tools/clang/test/SemaHLSL/hlsl/linalg/make-interp-vec-errors.hlsl create mode 100644 tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-errors.hlsl create mode 100644 tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-transpose-errors.hlsl create mode 100644 tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-muladd-errors.hlsl create mode 100644 tools/clang/test/SemaHLSL/hlsl/linalg/outerproductaccumulate-errors.hlsl create mode 100644 tools/clang/test/SemaHLSL/hlsl/linalg/outerproductaccumulate-spirv-errors.hlsl create mode 100644 tools/clang/test/SemaHLSL/hlsl/linalg/vectoraccumulate-errors.hlsl diff --git a/tools/clang/lib/Headers/hlsl/dx/linalg.h b/tools/clang/lib/Headers/hlsl/dx/linalg.h new file mode 100644 index 0000000000..51e662bbc9 --- /dev/null +++ 
b/tools/clang/lib/Headers/hlsl/dx/linalg.h @@ -0,0 +1,182 @@ +// Header for linear algebra APIs. + +#if __spirv__ +#error "Cooperative vectors not (yet) supported for SPIRV" +#endif + +#if ((__SHADER_TARGET_MAJOR > 6) || \ + (__SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR >= 9)) && \ + (__HLSL_VERSION >= 2021) + +namespace dx { +namespace linalg { + +// NOTE: can't be an enum class because we get this error: +// error: non-type template argument of type 'dx::linalg::DataType' is not +// an integral constant expression +// +enum DataType { + DATA_TYPE_SINT16 = 2, // ComponentType::I16 + DATA_TYPE_UINT16 = 3, // ComponentType::U16 + DATA_TYPE_SINT32 = 4, // ComponentType::I32 + DATA_TYPE_UINT32 = 5, // ComponentType::U32 + DATA_TYPE_FLOAT16 = 8, // ComponentType::F16 + DATA_TYPE_FLOAT32 = 9, // ComponentType::F32 + DATA_TYPE_SINT8_T4_PACKED = 17, // ComponentType::PackedS8x32 + DATA_TYPE_UINT8_T4_PACKED = 18, // ComponentType::PackedU8x32 + DATA_TYPE_UINT8 = 19, // ComponentType::U8 + DATA_TYPE_SINT8 = 20, // ComponentType::I8 + DATA_TYPE_FLOAT8_E4M3 = 21, // ComponentType::F8_E4M3 + // (1 sign, 4 exp, 3 mantissa bits) + DATA_TYPE_FLOAT8_E5M2 = 22, // ComponentType::F8_E5M2 + // (1 sign, 5 exp, 2 mantissa bits) +}; + +enum MatrixLayout { + MATRIX_LAYOUT_ROW_MAJOR = 0, + MATRIX_LAYOUT_COLUMN_MAJOR = 1, + MATRIX_LAYOUT_MUL_OPTIMAL = 2, + MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL = 3 +}; + +// +// Helper for signedness +// +namespace details { +template bool IsUnsigned() { return false; } + +#ifdef __HLSL_ENABLE_16_BIT +template <> bool IsUnsigned() { return true; } +#endif + +template <> bool IsUnsigned() { return true; } +template <> bool IsUnsigned() { return true; } +} // namespace details + +// +// (RW)MatrixRef +// + +template +struct MatrixRefImpl { + BufferTy Buffer; + uint StartOffset; + uint Stride; +}; + +template +using MatrixRef = MatrixRefImpl; + +template +using RWMatrixRef = MatrixRefImpl; + +// +// (RW)VectorRef +// + +template struct VectorRefImpl { + 
BufferTy Buffer; + uint StartOffset; +}; + +template using VectorRef = VectorRefImpl; + +template +using RWVectorRef = VectorRefImpl; + +// +// Vector +// + +template struct InterpretedVector { + vector Data; +}; + +template +InterpretedVector MakeInterpretedVector(vector Vec) { + InterpretedVector IV = {Vec}; + return IV; +} + +// +// Mul +// + +template +vector +Mul(MatrixRefImpl + Matrix, + InterpretedVector InputVector) { + + vector OutputVector; + + __builtin_MatVecMul( + /*out*/ OutputVector, details::IsUnsigned(), InputVector.Data, + details::IsUnsigned(), InputDT, Matrix.Buffer, + Matrix.StartOffset, MatrixDT, MatrixM, MatrixK, MatrixLayout, + MatrixTranspose, Matrix.Stride); + + return OutputVector; +} + +// +// MulAdd +// + +template +vector +MulAdd(MatrixRefImpl + Matrix, + InterpretedVector InputVector, + VectorRefImpl BiasVector) { + + vector OutputVector; + + __builtin_MatVecMulAdd( + /*out*/ OutputVector, details::IsUnsigned(), InputVector.Data, + details::IsUnsigned(), InputDT, Matrix.Buffer, + Matrix.StartOffset, MatrixDT, MatrixM, MatrixK, MatrixLayout, + MatrixTranspose, Matrix.Stride, BiasVector.Buffer, BiasVector.StartOffset, + BiasVectorDT); + + return OutputVector; +} + +// +// OuterProductAccumulate +// + +template +void OuterProductAccumulate( + vector InputVector1, vector InputVector2, + RWMatrixRef Matrix) { + __builtin_OuterProductAccumulate(InputVector1, InputVector2, Matrix.Buffer, + Matrix.StartOffset, MatrixDT, MatrixLayout, + Matrix.Stride); +} + +// +// VectorAccumulate +// + +template +void VectorAccumulate(vector InputVector, + RWByteAddressBuffer Buffer, uint Offset) { + __builtin_VectorAccumulate(InputVector, Buffer, Offset); +} + +} // namespace linalg +} // namespace dx + +#endif // SM 6.9 check and HV version check diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-mul.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-mul.hlsl new file mode 100644 index 0000000000..141801c71c --- /dev/null +++ 
b/tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-mul.hlsl @@ -0,0 +1,40 @@ +// RUN: %dxc -I %hlsl_headers -T lib_6_9 -enable-16bit-types %s | FileCheck %s + +#include + +ByteAddressBuffer Buf; + +export float4 Test1(vector Input) { + using namespace dx::linalg; + + MatrixRef Matrix = { + Buf, 0, 0}; + + // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMul.v4f32.v4f32(i32 305, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle %{{.+}}, i32 0, i32 8, i32 4, i32 4, i32 2, i1 true, i32 0, i1 false) + return Mul( + Matrix, MakeInterpretedVector(Input)); +} + +export vector Test2(vector Input) { + using namespace dx::linalg; + + MatrixRef Matrix = { + Buf, 0, 0}; + + // note the stride argument is dropped. + // CHECK: %{{.+}} = call <8 x float> @dx.op.matVecMul.v8f32.v6f32(i32 305, <6 x float> %{{.+}}, i1 false, i32 18, %dx.types.Handle %{{.+}}, i32 0, i32 19, i32 8, i32 24, i32 2, i1 false, i32 0, i1 false) + return Mul(Matrix, + MakeInterpretedVector(Input)); +} + +// test that "stride" isn't ignored in non-optimal layouts +export vector Test3(vector Input) { + using namespace dx::linalg; + + MatrixRef Matrix = { + Buf, 0, 6 * 4 * 8}; + + // CHECK: %{{.+}} = call <8 x float> @dx.op.matVecMul.v8f32.v6f32(i32 305, <6 x float> %{{.+}}, i1 false, i32 18, %dx.types.Handle %{{.+}}, i32 0, i32 19, i32 8, i32 24, i32 0, i1 false, i32 192, i1 false) + return Mul(Matrix, + MakeInterpretedVector(Input)); +} diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-muladd.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-muladd.hlsl new file mode 100644 index 0000000000..c19e601904 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-muladd.hlsl @@ -0,0 +1,90 @@ +// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s | FileCheck %s + +#include + +ByteAddressBuffer Buf; + +export float4 Test1(float4 input) { + using namespace dx::linalg; + + MatrixRef matrix = {Buf, + 0, 0}; + VectorRef biasVector = {Buf, 256}; + + InterpretedVector theVector = {input}; + + // 
CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMulAdd.v4f32.v4f32(i32 306, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle [[RES:%.+]], i32 0, i32 8, i32 4, i32 4, i32 2, i1 false, i32 0, %dx.types.Handle [[RES]], i32 256, i32 8, i1 false) + return MulAdd( + matrix, theVector, + biasVector); +} + +export float4 Test2(float4 input) { + using namespace dx::linalg; + + MatrixRef matrix = { + Buf, 0, 0}; + VectorRef biasVector = {Buf, 256}; + + InterpretedVector theVector = {input}; + + // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMulAdd.v4f32.v4f32(i32 306, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle [[RES:%.+]], i32 0, i32 8, i32 4, i32 4, i32 2, i1 true, i32 0, %dx.types.Handle [[RES]], i32 256, i32 8, i1 false) + return MulAdd( + matrix, theVector, + biasVector); +} + +export float4 Test3(float4 input) { + using namespace dx::linalg; + + MatrixRef matrix = { + Buf, 0, 0}; + VectorRef biasVector = {Buf, 256}; + + // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMulAdd.v4f32.v4f32(i32 306, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle [[RES:%.+]], i32 0, i32 8, i32 4, i32 4, i32 2, i1 true, i32 0, %dx.types.Handle [[RES]], i32 256, i32 8, i1 false) + return MulAdd( + matrix, MakeInterpretedVector(input), + biasVector); +} + +namespace ProposalExample { + +ByteAddressBuffer model; + +vector ApplyNeuralMaterial(vector inputVector) { + using namespace dx::linalg; + + MatrixRef matrix0 = { + model, 0, 0}; + + VectorRef biasVector0 = {model, 1024}; + + MatrixRef matrix1 = + {model, 2048, 0}; + + VectorRef biasVector1 = {model, 3072}; + + MatrixRef matrix2 = { + model, 4096, 0}; + + VectorRef biasVector2 = {model, 5120}; + + vector layer0 = MulAdd( + matrix0, MakeInterpretedVector(inputVector), + biasVector0); + layer0 = max(layer0, 0); + + vector layer1 = MulAdd( + matrix1, MakeInterpretedVector(layer0), + biasVector1); + layer1 = max(layer1, 0); + + vector output = MulAdd( + matrix2, MakeInterpretedVector(layer1), + biasVector2); + output 
= exp(output); + + return output; +} + +} // namespace ProposalExample diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/outerproductaccumulate.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/outerproductaccumulate.hlsl new file mode 100644 index 0000000000..eda15c66f6 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/outerproductaccumulate.hlsl @@ -0,0 +1,16 @@ +// RUN: %dxc -I %hlsl_headers -T lib_6_9 -enable-16bit-types %s | FileCheck %s + +#include + +RWByteAddressBuffer RWBuf; + +export void Test4(vector Input1, vector Input2) { + using namespace dx::linalg; + + RWMatrixRef + matrix = {RWBuf, 0, 0}; + + // CHECK: call void @dx.op.outerProductAccumulate.v128f16.v64f16(i32 307, <128 x half> %{{.+}}, <64 x half> %{{.+}}, %dx.types.Handle %{{.+}}, i32 0, i32 8, i32 3, i32 0) + + OuterProductAccumulate(Input1, Input2, matrix); +} diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/vectoraccumulate.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/vectoraccumulate.hlsl new file mode 100644 index 0000000000..9157156f10 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/vectoraccumulate.hlsl @@ -0,0 +1,14 @@ +// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s | FileCheck %s + +#include + +RWByteAddressBuffer RWBuf; + +export void Test5(vector Input) { + using namespace dx::linalg; + + RWBuf.Store >(0, Input); + + // CHECK: call void @dx.op.vectorAccumulate.v128f32(i32 308, <128 x float> %{{.*}}, %dx.types.Handle %{{.*}}, i32 0) + VectorAccumulate(Input, RWBuf, 0); +} diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/make-interp-vec-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/make-interp-vec-errors.hlsl new file mode 100644 index 0000000000..9f2793d417 --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/make-interp-vec-errors.hlsl @@ -0,0 +1,33 @@ +// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s -verify + +#include +ByteAddressBuffer Buf; + +export float4 Test1(vector Input) { + using namespace dx::linalg; + + MatrixRef Matrix = { + Buf, 0, 
0}; + + // expected-error@+3{{no matching function for call to 'MakeInterpretedVector'}} + // expected-note@dx/linalg.h:97{{candidate template ignored: invalid explicitly-specified argument for template parameter 'DT'}} + return Mul( + Matrix, MakeInterpretedVector<2>(Input)); +} + +enum DataType { + DATA_TYPE_InvalidType = 40 +}; + +export float4 Test2(vector Input) { + using namespace dx::linalg; + + MatrixRef Matrix = { + Buf, 0, 0}; + + // expected-error@+3{{no matching function for call to 'MakeInterpretedVector'}} + // expected-note@dx/linalg.h:97{{candidate template ignored: invalid explicitly-specified argument for template parameter 'DT'}} + return Mul( + Matrix, MakeInterpretedVector(Input)); +} + diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-errors.hlsl new file mode 100644 index 0000000000..2d5a11e83e --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-errors.hlsl @@ -0,0 +1,16 @@ +// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s -verify + +#include + +ByteAddressBuffer Buf; + +vector MixUpVectorAndMatrixArguments(vector Input) { + using namespace dx::linalg; + + MatrixRef Matrix = { + Buf, 0, 0}; + + // expected-error@+2{{no matching function for call to 'Mul'}} + // expected-note@dx/linalg.h:111{{candidate template ignored: could not match 'MatrixRefImpl' against 'InterpretedVector'}} + return Mul(MakeInterpretedVector(Input), Matrix); +} diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-transpose-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-transpose-errors.hlsl new file mode 100644 index 0000000000..2018acafab --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-transpose-errors.hlsl @@ -0,0 +1,30 @@ +// XFAIL: * +// RUN: %dxc -I %hlsl_headers -T lib_6_9 -enable-16bit-types %s -verify + +#include + +ByteAddressBuffer Buf; + +export float4 Test1(vector Input) { + using namespace dx::linalg; + + MatrixRef 
Matrix = { + Buf, 0, 0}; + + // PREVIEW CHECK TODO: + // expected-error@+1{{something about transposing not supported for rowmajor / colmajor layouts}} + return Mul( + Matrix, MakeInterpretedVector(Input)); +} + +export vector Test2(vector Input) { + using namespace dx::linalg; + + MatrixRef Matrix = { + Buf, 0, 0}; + + // PREVIEW CHECK TODO: + // expected-error@+1{{something about transposing not supported for rowmajor / colmajor layouts}} + return Mul(Matrix, + MakeInterpretedVector(Input)); +} diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-muladd-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-muladd-errors.hlsl new file mode 100644 index 0000000000..f444f81c3a --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-muladd-errors.hlsl @@ -0,0 +1,16 @@ +// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s -verify + +#include + +ByteAddressBuffer Buf; + +vector MixUpVectorAndMatrixArguments(vector Input) { + using namespace dx::linalg; + + MatrixRef Matrix = { + Buf, 0, 0}; + + // expected-error@+2{{no matching function for call to 'MulAdd'}} + // expected-note@dx/linalg.h:137{{candidate template ignored: could not match 'MatrixRefImpl' against 'InterpretedVector'}} + return MulAdd(MakeInterpretedVector(Input), Matrix, MakeInterpretedVector(Input)); +} diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/outerproductaccumulate-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/outerproductaccumulate-errors.hlsl new file mode 100644 index 0000000000..6f503b367b --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/outerproductaccumulate-errors.hlsl @@ -0,0 +1,44 @@ +// RUN: %dxc -I %hlsl_headers -T lib_6_9 -enable-16bit-types %s -verify + +#include + +RWByteAddressBuffer RWBuf; + +// test for inputs of different size +export void Test4(vector Input1, vector Input2) { + using namespace dx::linalg; + + RWMatrixRef + matrix = {RWBuf, 0, 0}; + + // expected-error@+3{{no matching function for call to 'OuterProductAccumulate'}} + // 
expected-note@dx/linalg.h:161{{candidate template ignored: could not match 0 against 1}} + + OuterProductAccumulate(Input1, Input2, matrix); +} + +// now test for an error when element types differ +export void Test5(vector Input1, vector Input2) { + using namespace dx::linalg; + + RWMatrixRef + matrix = {RWBuf, 0, 0}; + + // expected-error@+3{{no matching function for call to 'OuterProductAccumulate'}} + // expected-note@dx/linalg.h:161{{candidate template ignored: could not match 0 against 1}} + + OuterProductAccumulate(Input1, Input2, matrix); +} + +// now test for an error when matrix transpose parameter is true +export void Test4(vector Input1, vector Input2) { + using namespace dx::linalg; + + RWMatrixRef + matrix = {RWBuf, 0, 0}; + + // expected-error@+3{{no matching function for call to 'OuterProductAccumulate'}} + // expected-note@dx/linalg.h:161{{candidate template ignored: deduced conflicting types for parameter 'ElTy' ('int' vs. 'unsigned int')}} + + OuterProductAccumulate(Input1, Input2, matrix); +} diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/outerproductaccumulate-spirv-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/outerproductaccumulate-spirv-errors.hlsl new file mode 100644 index 0000000000..0213103926 --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/outerproductaccumulate-spirv-errors.hlsl @@ -0,0 +1,19 @@ +// RUN: %dxc -I %hlsl_headers -T lib_6_9 -enable-16bit-types -spirv %s -verify + +// Tests that the header file cannot be included for spirv compilations +// This is a copy of \tools\clang\test\CodeGenDXIL\hlsl\linalg\outerproductaccumulate.hlsl +// except that spirv is targeted + +// expected-error@dx/linalg.h:4{{Cooperative vectors not (yet) supported for SPIRV}} +#include + +RWByteAddressBuffer RWBuf; + +export void Test4(vector Input1, vector Input2) { + using namespace dx::linalg; + + RWMatrixRef + matrix = {RWBuf, 0, 0}; + + OuterProductAccumulate(Input1, Input2, matrix); +} diff --git 
a/tools/clang/test/SemaHLSL/hlsl/linalg/vectoraccumulate-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/vectoraccumulate-errors.hlsl new file mode 100644 index 0000000000..4c8ae6f049 --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/vectoraccumulate-errors.hlsl @@ -0,0 +1,16 @@ +// XFAIL: * +// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s | FileCheck %s + +#include + +RWByteAddressBuffer RWBuf; + +export void Test5(vector Input) { + using namespace dx::linalg; + + RWBuf.Store >(0, Input); + + // PREVIEW CHECK TODO: + // CHECK: Something about an error due to illegal conversions + VectorAccumulate(Input, RWBuf, 0); +} From 6bcb1515d475448e0d93c2762ac059a2da521588 Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Mon, 28 Apr 2025 10:17:47 +0200 Subject: [PATCH 14/31] Add SetShaderTableIndex+LoadLocalRootConstant tests / host code for local constants / simplifications --- .../unittests/HLSLExec/ExecutionTest.cpp | 64 +++- .../unittests/HLSLExec/ExecutionTest_SER.h | 303 +++++++++++++++++- 2 files changed, 353 insertions(+), 14 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 403b261df1..d921c54489 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -303,6 +303,8 @@ class ExecutionTest { TEST_METHOD(SERGetAttributesTest); TEST_METHOD(SERTraceHitMissNopTest); TEST_METHOD(SERIsMissTest); + TEST_METHOD(SERShaderTableIndexTest); + TEST_METHOD(SERLoadLocalRootTableConstantTest); TEST_METHOD(LifetimeIntrinsicTest) TEST_METHOD(WaveIntrinsicsTest); TEST_METHOD(WaveIntrinsicsDDITest); @@ -2248,11 +2250,12 @@ CComPtr ExecutionTest::RunDXRTest( CComPtr pLocalRootSignature; { CD3DX12_DESCRIPTOR_RANGE bufferRanges[1]; - CD3DX12_ROOT_PARAMETER rootParameters[1]; + CD3DX12_ROOT_PARAMETER rootParameters[2]; bufferRanges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 1, 0, 2); // vertexBuffer(t1), indexBuffer(t2) 
rootParameters[0].InitAsDescriptorTable( _countof(bufferRanges), bufferRanges, D3D12_SHADER_VISIBILITY_ALL); + rootParameters[1].InitAsConstants(4, 1, 0, D3D12_SHADER_VISIBILITY_ALL); CD3DX12_ROOT_SIGNATURE_DESC rootSignatureDesc; rootSignatureDesc.Init(_countof(rootParameters), rootParameters, 0, nullptr, @@ -2316,6 +2319,9 @@ CComPtr ExecutionTest::RunDXRTest( if (useIS) { lib->DefineExport(L"intersection"); } + if (useMesh && useProceduralGeometry) { + lib->DefineExport(L"chAABB"); + } const int maxRecursion = 1; stateObjectDesc.CreateSubobject() @@ -2329,6 +2335,10 @@ CComPtr ExecutionTest::RunDXRTest( stateObjectDesc .CreateSubobject(); globalRootSigSubObj->SetRootSignature(pGlobalRootSignature); + // Set Local Root Signature subobject. + stateObjectDesc.CreateSubobject() + ->SetRootSignature(pLocalRootSignature); + auto exports = stateObjectDesc.CreateSubobject< CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT>(); exports->SetSubobjectToAssociate(*globalRootSigSubObj); @@ -2339,6 +2349,9 @@ CComPtr ExecutionTest::RunDXRTest( if (useIS) { exports->AddExport(L"intersection"); } + if (useMesh && useProceduralGeometry) { + exports->AddExport(L"chAABB"); + } auto hitGroup = stateObjectDesc.CreateSubobject(); @@ -2350,15 +2363,23 @@ CComPtr ExecutionTest::RunDXRTest( } hitGroup->SetHitGroupExport(L"HitGroup"); + if (useMesh && useProceduralGeometry) { + auto hitGroupAABB = + stateObjectDesc.CreateSubobject(); + hitGroupAABB->SetClosestHitShaderImport(L"chAABB"); + hitGroupAABB->SetAnyHitShaderImport(L"anyhit"); + if (useIS) { // configure the AABB group itself, not hitGroup + hitGroupAABB->SetIntersectionShaderImport(L"intersection"); + hitGroupAABB->SetHitGroupType(D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE); + } + hitGroupAABB->SetHitGroupExport(L"HitGroupAABB"); + } + CComPtr pStateObject; CComPtr pStateObjectProperties; VERIFY_SUCCEEDED( pDevice->CreateStateObject(stateObjectDesc, IID_PPV_ARGS(&pStateObject))); VERIFY_SUCCEEDED(pStateObject->QueryInterface(&pStateObjectProperties)); -
stateObjectDesc.CreateSubobject() - ->SetRootSignature(pLocalRootSignature); - stateObjectDesc.CreateSubobject() - ->SetRootSignature(pGlobalRootSignature); // Create SBT ShaderTable shaderTable; @@ -2367,21 +2388,33 @@ CComPtr ExecutionTest::RunDXRTest( 1, // miss count useMesh && useProceduralGeometry ? 2 : 1, // hit group count 1, // ray type count - 2 // dwords per root table + 4 // dwords per root table ); + int localRootConsts[4] = {12, 34, 56, 78}; memcpy(shaderTable.GetRaygenShaderIdPtr(0), pStateObjectProperties->GetShaderIdentifier(L"raygen"), SHADER_ID_SIZE_IN_BYTES); + memcpy(shaderTable.GetRaygenRootTablePtr(0), localRootConsts, + sizeof(localRootConsts)); memcpy(shaderTable.GetMissShaderIdPtr(0, 0), pStateObjectProperties->GetShaderIdentifier(L"miss"), SHADER_ID_SIZE_IN_BYTES); + memcpy(shaderTable.GetMissRootTablePtr(0, 0), localRootConsts, + sizeof(localRootConsts)); memcpy(shaderTable.GetHitGroupShaderIdPtr(0, 0), pStateObjectProperties->GetShaderIdentifier(L"HitGroup"), SHADER_ID_SIZE_IN_BYTES); + memcpy(shaderTable.GetHitGroupRootTablePtr(0, 0), localRootConsts, + sizeof(localRootConsts)); + if (useMesh && useProceduralGeometry) { + memcpy(shaderTable.GetHitGroupShaderIdPtr(0, 1), + pStateObjectProperties->GetShaderIdentifier(L"HitGroupAABB"), + SHADER_ID_SIZE_IN_BYTES); + } - auto tbl = pDescriptorHeap->GetGPUDescriptorHandleForHeapStart().ptr; - memcpy(shaderTable.GetHitGroupRootTablePtr(0, 0), &tbl, 8); + // auto tbl = pDescriptorHeap->GetGPUDescriptorHandleForHeapStart().ptr; + // memcpy(shaderTable.GetHitGroupRootTablePtr(0, 0), &tbl, 8); // Create a command allocator and list. 
CComPtr pCommandAllocator; @@ -2521,6 +2554,7 @@ CComPtr ExecutionTest::RunDXRTest( pDevice->GetRaytracingAccelerationStructurePrebuildInfo(&accelInputs, &prebuildInfo); + scratchResource.Release(); ReallocScratchResource(pDevice, &scratchResource, prebuildInfo.ScratchDataSizeInBytes); AllocateBuffer(pDevice, prebuildInfo.ResultDataMaxSizeInBytes, @@ -2597,6 +2631,7 @@ CComPtr ExecutionTest::RunDXRTest( &prebuildInfo); // Allocate scratch and result buffers for the BLAS + scratchResource.Release(); ReallocScratchResource(pDevice, &scratchResource, prebuildInfo.ScratchDataSizeInBytes); AllocateBuffer(pDevice, prebuildInfo.ResultDataMaxSizeInBytes, @@ -2654,6 +2689,9 @@ CComPtr ExecutionTest::RunDXRTest( pDevice->GetRaytracingAccelerationStructurePrebuildInfo(&accelInputs, &prebuildInfo); + scratchResource.Release(); + ReallocScratchResource(pDevice, &scratchResource, + prebuildInfo.ScratchDataSizeInBytes); AllocateBuffer( pDevice, prebuildInfo.ResultDataMaxSizeInBytes, &tlasResource, true, D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, L"TLAS"); @@ -2691,6 +2729,9 @@ CComPtr ExecutionTest::RunDXRTest( pDevice->GetRaytracingAccelerationStructurePrebuildInfo(&accelInputs, &prebuildInfo); + scratchResource.Release(); + ReallocScratchResource(pDevice, &scratchResource, + prebuildInfo.ScratchDataSizeInBytes); AllocateBuffer( pDevice, prebuildInfo.ResultDataMaxSizeInBytes, &tlasResource, true, D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, L"TLAS"); @@ -2710,6 +2751,13 @@ CComPtr ExecutionTest::RunDXRTest( pCommandList->ResourceBarrier(1, (const D3D12_RESOURCE_BARRIER *)&barrier); } + // Set the local root constants. 
+ pCommandList->SetComputeRootSignature(pLocalRootSignature); + pCommandList->SetComputeRoot32BitConstant(1, 12, 0); + pCommandList->SetComputeRoot32BitConstant(1, 34, 1); + pCommandList->SetComputeRoot32BitConstant(1, 56, 2); + pCommandList->SetComputeRoot32BitConstant(1, 78, 3); + shaderTable.Upload(pCommandList); ID3D12DescriptorHeap *const pHeaps[1] = {pDescriptorHeap}; diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index a99d55e79c..1c24795a0c 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -586,7 +586,7 @@ struct SceneConstants struct[raypayload] PerRayData { - matrix value : read(caller) : write(miss,closesthit); + float elems[ROWS*COLS] : read(caller) : write(miss,closesthit); }; struct Attrs @@ -627,7 +627,7 @@ void raygen() TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, refPayload); for (int r = 0; r < ROWS; r++) { for (int c = 0; c < COLS; c++) { - testBuffer[id + 2 * (r * COLS + c)] = refPayload.value[r][c]; + testBuffer[id + 2 * (r * COLS + c)] = refPayload.elems[r*COLS + c]; } } @@ -642,8 +642,8 @@ void raygen() } } -matrix getMatIdentity() { - matrix mat = 0; +matrix getMatIdentity() { + matrix mat = 0; mat[0][0] = 1.f; mat[1][1] = 1.f; mat[2][2] = 1.f; @@ -653,7 +653,12 @@ matrix getMatIdentity() { [shader("miss")] void miss(inout PerRayData payload) { - payload.value = MISS_GET_MATRIX(); + matrix mat = MISS_GET_MATRIX(); + for (int r = 0; r < ROWS; r++) { + for (int c = 0; c < COLS; c++) { + payload.elems[r*COLS + c] = mat[r][c]; + } + } } [shader("anyhit")] @@ -665,7 +670,12 @@ void anyhit(inout PerRayData payload, in Attrs attrs) [shader("closesthit")] void closesthit(inout PerRayData payload, in Attrs attrs) { - payload.value = HIT_GET_MATRIX(); + matrix mat = HIT_GET_MATRIX(); + for (int r = 0; r < ROWS; r++) { + for (int c = 0; c < COLS; c++) { + payload.elems[r*COLS + c] = mat[r][c]; + } + } 
} )"; @@ -952,6 +962,287 @@ void closesthit(inout PerRayData payload, in Attrs attrs) } } +TEST_F(ExecutionTest, SERShaderTableIndexTest) { + // Test SER with HitObject::SetShaderTableIndex and + // HitObject::GetShaderTableIndex + static const char *pShader = R"( +struct SceneConstants +{ + float4 eye; + float4 U; + float4 V; + float4 W; + float sceneScale; + uint2 windowSize; + int rayFlags; +}; + +struct[raypayload] PerRayData +{ + uint visited : read(anyhit,closesthit,miss,caller) : write(anyhit,miss,closesthit,caller); +}; + +struct Attrs +{ + float2 barycentrics : BARYCENTRICS; +}; + +RWStructuredBuffer testBuffer : register(u0); +RaytracingAccelerationStructure topObject : register(t0); +ConstantBuffer sceneConstants : register(b0); + +RayDesc ComputeRay() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; + RayDesc ray; + ray.Origin = sceneConstants.eye.xyz; + ray.Direction = normalize(d.x*sceneConstants.U.xyz + d.y*sceneConstants.V.xyz + sceneConstants.W.xyz); + ray.TMin = 0; + ray.TMax = 1e18; + + return ray; +} + +[shader("raygeneration")] +void raygen() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + RayDesc ray = ComputeRay(); + + PerRayData payload; + payload.visited = 0; + + // SER Test + dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); + dx::MaybeReorderThread(hitObject); + dx::HitObject::Invoke(hitObject, payload); + + if (hitObject.IsHit()) + { + // Alter the hit object to point to a new shader index to hit chAABB. + hitObject.SetShaderTableIndex( 1 ); + dx::HitObject::Invoke( hitObject, payload ); + // Poison the test data if GetShaderTableIndex does not match SetShaderTableIndex. 
+ if (hitObject.GetShaderTableIndex() != 1) + payload.visited = 0; + } + + int id = launchIndex.x + launchIndex.y * launchDim.x; + testBuffer[id] = payload.visited; +} + +[shader("miss")] +void miss(inout PerRayData payload) +{ + payload.visited |= 2U; +} + +[shader("anyhit")] +void anyhit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 1U; + AcceptHitAndEndSearch(); +} + +[shader("closesthit")] +void closesthit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 4U; +} + +[shader("closesthit")] +void chAABB(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 8U; +} + +)"; + + CComPtr pDevice; + bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " + L"supported device was found."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + bool bDXRSupported = + bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); + + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support SM 6.9."); + } + if (!bDXRSupported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support DXR."); + } + + // Initialize test data. 
+ const int windowSize = 64; + std::vector testData(windowSize * windowSize, 0); + LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; + + if (bDXRSupported) { + WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); + RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), + testData, windowSize, windowSize, true /*mesh*/, + true /*procedural geometry*/, false /*useIS*/); + std::map histo; + for (int val : testData) { + ++histo[val]; + } + VERIFY_ARE_EQUAL(histo.size(), 2); + VERIFY_ARE_EQUAL(histo[2], 4030); + VERIFY_ARE_EQUAL(histo[13], 66); + } +} + +TEST_F(ExecutionTest, SERLoadLocalRootTableConstantTest) { + // Test SER with HitObject::LoadLocalRootTableConstant + static const char *pShader = R"( +struct SceneConstants +{ + float4 eye; + float4 U; + float4 V; + float4 W; + float sceneScale; + uint2 windowSize; + int rayFlags; +}; + +struct[raypayload] PerRayData +{ + uint res : read(caller) : write(miss,closesthit,caller); +}; + +struct Attrs +{ + float2 barycentrics : BARYCENTRICS; +}; + +struct LocalConstants +{ + int c0; + int c1; + int c2; + int c3; +}; + +RWStructuredBuffer testBuffer : register(u0); +RaytracingAccelerationStructure topObject : register(t0); +ConstantBuffer sceneConstants : register(b0); +ConstantBuffer localConstants : register(b1); + +RayDesc ComputeRay() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; + RayDesc ray; + ray.Origin = sceneConstants.eye.xyz; + ray.Direction = normalize(d.x*sceneConstants.U.xyz + d.y*sceneConstants.V.xyz + sceneConstants.W.xyz); + ray.TMin = 0; + ray.TMax = 1e18; + + return ray; +} + +[shader("raygeneration")] +void raygen() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + RayDesc ray = ComputeRay(); + + PerRayData payload; + payload.res = 0; + + // SER Test +#if 1 + dx::HitObject hitObject = 
dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); + dx::MaybeReorderThread(hitObject); + int c0 = hitObject.LoadLocalRootTableConstant(0); + int c1 = hitObject.LoadLocalRootTableConstant(4); + int c2 = hitObject.LoadLocalRootTableConstant(8); + int c3 = hitObject.LoadLocalRootTableConstant(12); + int res = c0 | c1 | c2 | c3; +#else + TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); + int res = payload.res; +#endif + + int id = launchIndex.x + launchIndex.y * launchDim.x; + testBuffer[id] = res; +} + +[shader("miss")] +void miss(inout PerRayData payload) +{ + payload.res = localConstants.c0 | localConstants.c1 | localConstants.c2 | localConstants.c3; +} + +[shader("anyhit")] +void anyhit(inout PerRayData payload, in Attrs attrs) +{ + // UNUSED +} + +[shader("closesthit")] +void closesthit(inout PerRayData payload, in Attrs attrs) +{ + payload.res = localConstants.c0 | localConstants.c1 | localConstants.c2 | localConstants.c3; +} + +)"; + + CComPtr pDevice; + bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " + L"supported device was found."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + bool bDXRSupported = + bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); + + if (!bSM_6_9_Supported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support SM 6.9."); + } + if (!bDXRSupported) { + WEX::Logging::Log::Comment( + L"SER tests skipped, device does not support DXR."); + } + + // Initialize test data. 
+ const int windowSize = 64; + std::vector testData(windowSize * windowSize, 0); + LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; + + if (!bDXRSupported) + return; + WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); + RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), + testData, windowSize, windowSize, true /*useMesh*/, + false /*useProceduralGeometry*/, false /*useIS*/); + std::map histo; + for (int val : testData) { + ++histo[val]; + } + VERIFY_ARE_EQUAL(histo.size(), 1); + VERIFY_ARE_EQUAL(histo[126], 4096); +} + TEST_F(ExecutionTest, SERRayQueryTest) { // Test SER RayQuery static const char *pShader = R"( From 144a083d9f112f26cec31811f304d2971d2fb039 Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Wed, 30 Apr 2025 12:51:13 +0200 Subject: [PATCH 15/31] Uppercased vars for coding standards / added CreateDXRDevice helper / simplified tests/ added SERInvokeNoSBTTest --- tools/clang/unittests/HLSLExec/DXRUtil.h | 271 ++-- .../unittests/HLSLExec/ExecutionTest.cpp | 1069 ++++++++-------- .../unittests/HLSLExec/ExecutionTest_SER.h | 1108 ++++++++--------- 3 files changed, 1169 insertions(+), 1279 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/DXRUtil.h b/tools/clang/unittests/HLSLExec/DXRUtil.h index 1f008885cf..54828f4857 100644 --- a/tools/clang/unittests/HLSLExec/DXRUtil.h +++ b/tools/clang/unittests/HLSLExec/DXRUtil.h @@ -10,216 +10,213 @@ // // /////////////////////////////////////////////////////////////////////////////// +#pragma once + //= DXR Utility //============================================================================ #define SHADER_ID_SIZE_IN_BYTES 32 #ifndef ROUND_UP -#define ROUND_UP(v, powerOf2Alignment) \ - (((v) + (powerOf2Alignment)-1) & ~((powerOf2Alignment)-1)) +#define ROUND_UP(v, PowerOf2Alignment) \ + (((v) + (PowerOf2Alignment)-1) & ~((PowerOf2Alignment)-1)) #endif struct SceneConsts { - DirectX::XMFLOAT4 eye; + DirectX::XMFLOAT4 Eye; DirectX::XMFLOAT4 U; DirectX::XMFLOAT4 V; DirectX::XMFLOAT4 W; - float 
sceneScale; - unsigned windowSize[2]; - int rayFlags; + float SceneScale; + unsigned WindowSize[2]; + int RayFlags; }; struct Instance { - D3D12_RAYTRACING_GEOMETRY_TYPE type; - DirectX::XMFLOAT4X4 matrix; - UINT geometryCount; - UINT bottomASIdx; - UINT instanceID; - UINT mask; - UINT flags; + D3D12_RAYTRACING_GEOMETRY_TYPE Type; + DirectX::XMFLOAT4X4 Matrix; + UINT GeometryCount; + UINT BottomASIdx; + UINT InstanceID; + UINT Mask; + UINT Flags; }; class ShaderTable { public: - void Init(ID3D12Device *device, int raygenCount, int missCount, - int hitGroupCount, int rayTypeCount, int rootTableDwords) { - m_rayTypeCount = rayTypeCount; - m_raygenCount = raygenCount; - m_missCount = missCount * rayTypeCount; - m_hitGroupCount = hitGroupCount * rayTypeCount; - m_rootTableSizeInBytes = rootTableDwords * 4; - m_shaderRecordSizeInBytes = - ROUND_UP(m_rootTableSizeInBytes + SHADER_ID_SIZE_IN_BYTES, + void Init(ID3D12Device *Device, int RaygenCount, int MissCount, + int HitGroupCount, int RayTypeCount, int RootTableDwords) { + this->RayTypeCount = RayTypeCount; // parameters shadow the members + this->RaygenCount = RaygenCount; + this->MissCount = MissCount * RayTypeCount; + this->HitGroupCount = HitGroupCount * RayTypeCount; + RootTableSizeInBytes = RootTableDwords * 4; + ShaderRecordSizeInBytes = + ROUND_UP(RootTableSizeInBytes + SHADER_ID_SIZE_IN_BYTES, D3D12_RAYTRACING_SHADER_RECORD_BYTE_ALIGNMENT); - m_missStartIdx = m_raygenCount; - m_hitGroupStartIdx = m_missStartIdx + m_missCount; + MissStartIdx = RaygenCount; + HitGroupStartIdx = MissStartIdx + this->MissCount; - const int m_totalSizeInBytes = - (m_raygenCount + m_missCount + m_hitGroupCount) * - m_shaderRecordSizeInBytes; + const int TotalSizeInBytes = + (HitGroupStartIdx + this->HitGroupCount) * ShaderRecordSizeInBytes; - D3D12_RESOURCE_DESC desc = CD3DX12_RESOURCE_DESC::Buffer( - m_totalSizeInBytes, D3D12_RESOURCE_FLAG_NONE, + D3D12_RESOURCE_DESC Desc = CD3DX12_RESOURCE_DESC::Buffer( + TotalSizeInBytes, D3D12_RESOURCE_FLAG_NONE,
std::max(D3D12_RAYTRACING_SHADER_RECORD_BYTE_ALIGNMENT, D3D12_DEFAULT_RESOURCE_PLACEMENT_ALIGNMENT)); - CD3DX12_HEAP_PROPERTIES heap = + CD3DX12_HEAP_PROPERTIES Heap = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT); - VERIFY_SUCCEEDED(device->CreateCommittedResource( - &heap, D3D12_HEAP_FLAG_NONE, &desc, + VERIFY_SUCCEEDED(Device->CreateCommittedResource( + &Heap, D3D12_HEAP_FLAG_NONE, &Desc, D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE, nullptr, - IID_PPV_ARGS(&m_sbtResource))); - m_sbtResource->SetName(L"SBT Resource Heap"); - CD3DX12_HEAP_PROPERTIES upload = + IID_PPV_ARGS(&SBTResource))); + SBTResource->SetName(L"SBT Resource Heap"); + CD3DX12_HEAP_PROPERTIES Upload = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD); - VERIFY_SUCCEEDED(device->CreateCommittedResource( - &upload, D3D12_HEAP_FLAG_NONE, &desc, D3D12_RESOURCE_STATE_GENERIC_READ, - nullptr, IID_PPV_ARGS(&m_sbtUploadResource))); - m_sbtUploadResource->SetName(L"SBT Upload Heap"); + VERIFY_SUCCEEDED(Device->CreateCommittedResource( + &Upload, D3D12_HEAP_FLAG_NONE, &Desc, D3D12_RESOURCE_STATE_GENERIC_READ, + nullptr, IID_PPV_ARGS(&SBTUploadResource))); + SBTUploadResource->SetName(L"SBT Upload Heap"); - VERIFY_SUCCEEDED(m_sbtUploadResource->Map(0, nullptr, (void **)&m_hostPtr)); + VERIFY_SUCCEEDED(SBTUploadResource->Map(0, nullptr, (void **)&HostPtr)); } - void Upload(ID3D12GraphicsCommandList *cmdlist) { - CD3DX12_RESOURCE_BARRIER barrier = CD3DX12_RESOURCE_BARRIER::Transition( - m_sbtResource, D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE, + void Upload(ID3D12GraphicsCommandList *CmdList) { + CD3DX12_RESOURCE_BARRIER Barrier = CD3DX12_RESOURCE_BARRIER::Transition( + SBTResource, D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE, D3D12_RESOURCE_STATE_COPY_DEST); - cmdlist->ResourceBarrier(1, &barrier); - cmdlist->CopyResource(m_sbtResource, m_sbtUploadResource); - CD3DX12_RESOURCE_BARRIER barrier2 = CD3DX12_RESOURCE_BARRIER::Transition( - m_sbtResource, D3D12_RESOURCE_STATE_COPY_DEST, + 
CmdList->ResourceBarrier(1, &Barrier); + CmdList->CopyResource(SBTResource, SBTUploadResource); + CD3DX12_RESOURCE_BARRIER Barrier2 = CD3DX12_RESOURCE_BARRIER::Transition( + SBTResource, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE); - cmdlist->ResourceBarrier(1, &barrier2); + CmdList->ResourceBarrier(1, &Barrier2); } - int GetShaderRecordSizeInBytes() { return m_shaderRecordSizeInBytes; } + int GetShaderRecordSizeInBytes() { return ShaderRecordSizeInBytes; } - int GetRaygenShaderRecordIdx(int idx) { return idx; } - int GetMissShaderRecordIdx(int idx, int rayType) { - return m_missStartIdx + idx * m_rayTypeCount + rayType; + int GetRaygenShaderRecordIdx(int Idx) { return Idx; } + int GetMissShaderRecordIdx(int Idx, int RayType) { + return MissStartIdx + Idx * RayTypeCount + RayType; } - int GetHitGroupShaderRecordIdx(int idx, int rayType) { - return m_hitGroupStartIdx + idx * m_rayTypeCount + rayType; + int GetHitGroupShaderRecordIdx(int Idx, int RayType) { + return HitGroupStartIdx + Idx * RayTypeCount + RayType; } - void *GetRaygenShaderIdPtr(int idx) { - return m_hostPtr + - GetRaygenShaderRecordIdx(idx) * m_shaderRecordSizeInBytes; + void *GetRaygenShaderIdPtr(int Idx) { + return HostPtr + GetRaygenShaderRecordIdx(Idx) * ShaderRecordSizeInBytes; } - void *GetMissShaderIdPtr(int idx, int rayType) { - return m_hostPtr + - GetMissShaderRecordIdx(idx, rayType) * m_shaderRecordSizeInBytes; + void *GetMissShaderIdPtr(int Idx, int RayType) { + return HostPtr + + GetMissShaderRecordIdx(Idx, RayType) * ShaderRecordSizeInBytes; } - void *GetHitGroupShaderIdPtr(int idx, int rayType) { - return m_hostPtr + - GetHitGroupShaderRecordIdx(idx, rayType) * m_shaderRecordSizeInBytes; + void *GetHitGroupShaderIdPtr(int Idx, int RayType) { + return HostPtr + + GetHitGroupShaderRecordIdx(Idx, RayType) * ShaderRecordSizeInBytes; } - void *GetRaygenRootTablePtr(int idx) { - return (char *)GetRaygenShaderIdPtr(idx) + SHADER_ID_SIZE_IN_BYTES; + void 
*GetRaygenRootTablePtr(int Idx) { + return (char *)GetRaygenShaderIdPtr(Idx) + SHADER_ID_SIZE_IN_BYTES; } - void *GetMissRootTablePtr(int idx, int rayType) { - return (char *)GetMissShaderIdPtr(idx, rayType) + SHADER_ID_SIZE_IN_BYTES; + void *GetMissRootTablePtr(int Idx, int RayType) { + return (char *)GetMissShaderIdPtr(Idx, RayType) + SHADER_ID_SIZE_IN_BYTES; } - void *GetHitGroupRootTablePtr(int idx, int rayType) { - return (char *)GetHitGroupShaderIdPtr(idx, rayType) + + void *GetHitGroupRootTablePtr(int Idx, int RayType) { + return (char *)GetHitGroupShaderIdPtr(Idx, RayType) + SHADER_ID_SIZE_IN_BYTES; } - int GetRaygenRangeInBytes() { - return m_raygenCount * m_shaderRecordSizeInBytes; - } - int GetMissRangeInBytes() { return m_missCount * m_shaderRecordSizeInBytes; } + int GetRaygenRangeInBytes() { return RaygenCount * ShaderRecordSizeInBytes; } + int GetMissRangeInBytes() { return MissCount * ShaderRecordSizeInBytes; } int GetHitGroupRangeInBytes() { - return m_hitGroupCount * m_shaderRecordSizeInBytes; + return HitGroupCount * ShaderRecordSizeInBytes; } D3D12_GPU_VIRTUAL_ADDRESS GetRaygenStartGpuVA() { - return m_sbtResource->GetGPUVirtualAddress() + - GetRaygenShaderRecordIdx(0) * m_shaderRecordSizeInBytes; + return SBTResource->GetGPUVirtualAddress() + + GetRaygenShaderRecordIdx(0) * ShaderRecordSizeInBytes; } D3D12_GPU_VIRTUAL_ADDRESS GetMissStartGpuVA() { - return m_sbtResource->GetGPUVirtualAddress() + - GetMissShaderRecordIdx(0, 0) * m_shaderRecordSizeInBytes; + return SBTResource->GetGPUVirtualAddress() + + GetMissShaderRecordIdx(0, 0) * ShaderRecordSizeInBytes; } D3D12_GPU_VIRTUAL_ADDRESS GetHitGroupStartGpuVA() { - return m_sbtResource->GetGPUVirtualAddress() + - GetHitGroupShaderRecordIdx(0, 0) * m_shaderRecordSizeInBytes; + return SBTResource->GetGPUVirtualAddress() + + GetHitGroupShaderRecordIdx(0, 0) * ShaderRecordSizeInBytes; } private: - CComPtr m_sbtResource; - CComPtr m_sbtUploadResource; - char *m_hostPtr = nullptr; - int m_rayTypeCount = 
0; - int m_raygenCount = 0; - int m_missCount = 0; - int m_hitGroupCount = 0; - int m_rootTableSizeInBytes = 0; - int m_shaderRecordSizeInBytes = 0; - int m_missStartIdx = 0; - int m_hitGroupStartIdx = 0; + CComPtr SBTResource; + CComPtr SBTUploadResource; + char *HostPtr = nullptr; + int RayTypeCount = 0; + int RaygenCount = 0; + int MissCount = 0; + int HitGroupCount = 0; + int RootTableSizeInBytes = 0; + int ShaderRecordSizeInBytes = 0; + int MissStartIdx = 0; + int HitGroupStartIdx = 0; }; //----------------------------------------------------------------------------- void AllocateBuffer( - ID3D12Device *pDevice, UINT64 bufferSize, ID3D12Resource **ppResource, - bool allowUAV = false, - D3D12_RESOURCE_STATES initialResourceState = D3D12_RESOURCE_STATE_COMMON, - const wchar_t *resourceName = nullptr) { - auto uploadHeapProperties = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT); - auto bufferDesc = CD3DX12_RESOURCE_DESC::Buffer( - bufferSize, allowUAV ? D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS + ID3D12Device *Device, UINT64 BufferSize, ID3D12Resource **Resource, + bool AllowUAV = false, + D3D12_RESOURCE_STATES InitialResourceState = D3D12_RESOURCE_STATE_COMMON, + const wchar_t *ResourceName = nullptr) { + auto UploadHeapProperties = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT); + auto BufferDesc = CD3DX12_RESOURCE_DESC::Buffer( + BufferSize, AllowUAV ? 
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS : D3D12_RESOURCE_FLAG_NONE); - VERIFY_SUCCEEDED(pDevice->CreateCommittedResource( - &uploadHeapProperties, D3D12_HEAP_FLAG_NONE, &bufferDesc, - initialResourceState, nullptr, IID_PPV_ARGS(ppResource))); - if (resourceName) { - (*ppResource)->SetName(resourceName); + VERIFY_SUCCEEDED(Device->CreateCommittedResource( + &UploadHeapProperties, D3D12_HEAP_FLAG_NONE, &BufferDesc, + InitialResourceState, nullptr, IID_PPV_ARGS(Resource))); + if (ResourceName) { + (*Resource)->SetName(ResourceName); } } //----------------------------------------------------------------------------- -void ReallocScratchResource(ID3D12Device *pDevice, ID3D12Resource **ppResource, - UINT64 nbytes) { - - if (!(*ppResource) || (*ppResource)->GetDesc().Width < nbytes) { - AllocateBuffer(pDevice, nbytes, ppResource, true, +void ReallocScratchResource(ID3D12Device *Device, ID3D12Resource **Resource, + UINT64 NBytes) { + if (!(*Resource) || (*Resource)->GetDesc().Width < NBytes) { + AllocateBuffer(Device, NBytes, Resource, true, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, L"scratchResource"); } } //----------------------------------------------------------------------------- -void AllocateUploadBuffer(ID3D12Device *pDevice, const void *pData, - UINT64 datasize, ID3D12Resource **ppResource, - const wchar_t *resourceName = nullptr) { - auto uploadHeapProperties = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD); - auto bufferDesc = CD3DX12_RESOURCE_DESC::Buffer(datasize); - VERIFY_SUCCEEDED(pDevice->CreateCommittedResource( - &uploadHeapProperties, D3D12_HEAP_FLAG_NONE, &bufferDesc, - D3D12_RESOURCE_STATE_GENERIC_READ, nullptr, IID_PPV_ARGS(ppResource))); - if (resourceName) { - (*ppResource)->SetName(resourceName); - } - void *pMappedData; - VERIFY_SUCCEEDED((*ppResource)->Map(0, nullptr, &pMappedData)); - memcpy(pMappedData, pData, datasize); - (*ppResource)->Unmap(0, nullptr); +void AllocateUploadBuffer(ID3D12Device *Device, const void *Data, + UINT64 DataSize, 
ID3D12Resource **Resource, + const wchar_t *ResourceName = nullptr) { + auto UploadHeapProperties = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD); + auto BufferDesc = CD3DX12_RESOURCE_DESC::Buffer(DataSize); + VERIFY_SUCCEEDED(Device->CreateCommittedResource( + &UploadHeapProperties, D3D12_HEAP_FLAG_NONE, &BufferDesc, + D3D12_RESOURCE_STATE_GENERIC_READ, nullptr, IID_PPV_ARGS(Resource))); + if (ResourceName) { + (*Resource)->SetName(ResourceName); + } + void *MappedData; + VERIFY_SUCCEEDED((*Resource)->Map(0, nullptr, &MappedData)); + memcpy(MappedData, Data, DataSize); + (*Resource)->Unmap(0, nullptr); } //----------------------------------------------------------------------------- -void AllocateBufferFromUpload(ID3D12Device *pDevice, - ID3D12GraphicsCommandList *pCommandList, - ID3D12Resource *uploadSource, - ID3D12Resource **ppResource, - D3D12_RESOURCE_STATES targetResourceState, - const wchar_t *resourceName = nullptr) { - const bool allowUAV = - targetResourceState == D3D12_RESOURCE_STATE_UNORDERED_ACCESS; - AllocateBuffer(pDevice, uploadSource->GetDesc().Width, ppResource, allowUAV, - D3D12_RESOURCE_STATE_COPY_DEST, resourceName); - pCommandList->CopyResource(*ppResource, uploadSource); - CD3DX12_RESOURCE_BARRIER barrier = CD3DX12_RESOURCE_BARRIER::Transition( - *ppResource, D3D12_RESOURCE_STATE_COPY_DEST, targetResourceState); - pCommandList->ResourceBarrier(1, (const D3D12_RESOURCE_BARRIER *)&barrier); +void AllocateBufferFromUpload(ID3D12Device *Device, + ID3D12GraphicsCommandList *CommandList, + ID3D12Resource *UploadSource, + ID3D12Resource **Resource, + D3D12_RESOURCE_STATES TargetResourceState, + const wchar_t *ResourceName = nullptr) { + const bool AllowUAV = + TargetResourceState == D3D12_RESOURCE_STATE_UNORDERED_ACCESS; + AllocateBuffer(Device, UploadSource->GetDesc().Width, Resource, AllowUAV, + D3D12_RESOURCE_STATE_COPY_DEST, ResourceName); + CommandList->CopyResource(*Resource, UploadSource); + CD3DX12_RESOURCE_BARRIER Barrier = 
CD3DX12_RESOURCE_BARRIER::Transition( + *Resource, D3D12_RESOURCE_STATE_COPY_DEST, TargetResourceState); + CommandList->ResourceBarrier(1, (const D3D12_RESOURCE_BARRIER *)&Barrier); } //= DXR Utility diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index d921c54489..25c7170316 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -60,7 +60,6 @@ #include "ShaderOpTest.h" #include #include - #include "DXRUtil.h" // clang-format on @@ -294,17 +293,6 @@ class ExecutionTest { TEST_METHOD(SaturateTest); TEST_METHOD(SignTest); TEST_METHOD(Int64Test); - TEST_METHOD(SERBasicTest); - TEST_METHOD(SERScalarGetterTest); - TEST_METHOD(SERVectorGetterTest); - TEST_METHOD(SERMatrixGetterTest); - TEST_METHOD(SERRayQueryTest); - TEST_METHOD(SERIntersectionTest); - TEST_METHOD(SERGetAttributesTest); - TEST_METHOD(SERTraceHitMissNopTest); - TEST_METHOD(SERIsMissTest); - TEST_METHOD(SERShaderTableIndexTest); - TEST_METHOD(SERLoadLocalRootTableConstantTest); TEST_METHOD(LifetimeIntrinsicTest) TEST_METHOD(WaveIntrinsicsTest); TEST_METHOD(WaveIntrinsicsDDITest); @@ -514,6 +502,20 @@ class ExecutionTest { L"Table:ShaderOpArithTable.xml#PackUnpackOpTable") END_TEST_METHOD() + // Shader Execution Reordering tests + TEST_METHOD(SERBasicTest); + TEST_METHOD(SERScalarGetterTest); + TEST_METHOD(SERVectorGetterTest); + TEST_METHOD(SERMatrixGetterTest); + TEST_METHOD(SERRayQueryTest); + TEST_METHOD(SERIntersectionTest); + TEST_METHOD(SERGetAttributesTest); + TEST_METHOD(SERTraceHitMissNopTest); + TEST_METHOD(SERIsMissTest); + TEST_METHOD(SERShaderTableIndexTest); + TEST_METHOD(SERLoadLocalRootTableConstantTest); + TEST_METHOD(SERInvokeNoSBTTest); + dxc::DxcDllSupport m_support; bool m_D3DInitCompleted = false; @@ -1930,12 +1932,15 @@ class ExecutionTest { CComPtr &pRootSignature, LPCWSTR pTargetProfile, LPCWSTR *pOptions, int numOptions); - CComPtr - 
RunDXRTest(ID3D12Device *pDevice0, LPCSTR shader, - D3D_SHADER_MODEL shaderModel, LPCWSTR *pOptions, int numOptions, - std::vector &testData, int windowWidth, int windowHeight, - bool useMesh, bool useProceduralGeometry, bool useIS, - int payloadCount = 1, int attributeCount = 2); + bool CreateDXRDevice(ID3D12Device **ppDevice, D3D_SHADER_MODEL testModel, + bool skipUnsupported); + CComPtr RunDXRTest(ID3D12Device *Device0, LPCSTR ShaderSrc, + LPCWSTR TargetProfile, LPCWSTR *Options, + int NumOptions, std::vector &TestData, + int WindowWidth, int WindowHeight, + bool UseMesh, bool UseProceduralGeometry, + bool UseIS, int PayloadCount = 1, + int AttributeCount = 2); void SetDescriptorHeap(ID3D12GraphicsCommandList *pCommandList, ID3D12DescriptorHeap *pHeap) { @@ -2097,53 +2102,42 @@ void ExecutionTest::RunRWByteBufferComputeTest(ID3D12Device *pDevice, WaitForSignal(pCommandQueue, FO); } -CComPtr ExecutionTest::RunDXRTest( - ID3D12Device *pDevice0, LPCSTR shader, D3D_SHADER_MODEL shaderModel, - LPCWSTR *pOptions, int numOptions, std::vector &testData, - int windowWidth, int windowHeight, bool useMesh, bool useProceduralGeometry, - bool useIS, int payloadCount, int attributeCount) { - CComPtr pDevice; - VERIFY_SUCCEEDED(pDevice0->QueryInterface(IID_PPV_ARGS(&pDevice))); +bool ExecutionTest::CreateDXRDevice(ID3D12Device **ppDevice, + D3D_SHADER_MODEL testModel, + bool skipUnsupported) { + bool SupportsSM = CreateDevice(ppDevice, testModel, skipUnsupported); + if (!SupportsSM) + return false; - LPCWSTR pTargetProfile; - switch (shaderModel) { - case D3D_SHADER_MODEL_6_9: - pTargetProfile = L"lib_6_9"; - break; - case D3D_SHADER_MODEL_6_8: - pTargetProfile = L"lib_6_8"; - break; - case D3D_SHADER_MODEL_6_7: - pTargetProfile = L"lib_6_7"; - break; - case D3D_SHADER_MODEL_6_6: - pTargetProfile = L"lib_6_6"; - break; - case D3D_SHADER_MODEL_6_5: - pTargetProfile = L"lib_6_5"; - break; - case D3D_SHADER_MODEL_6_4: - pTargetProfile = L"lib_6_4"; - break; - case 
D3D_SHADER_MODEL_6_3: - pTargetProfile = L"lib_6_3"; - break; - default: - // DXR capable shader model not found. - LogErrorFmt(L"DXR capable shader model not found."); - return nullptr; + if (DoesDeviceSupportRayTracing(*ppDevice)) + return true; + + if (skipUnsupported) { + WEX::Logging::Log::Comment( + L"DXR test skipped: device does not support DXR."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); } + return false; +} + +CComPtr ExecutionTest::RunDXRTest( + ID3D12Device *Device0, LPCSTR ShaderSrc, LPCWSTR TargetProfile, + LPCWSTR *Options, int NumOptions, std::vector &TestData, + int WindowWidth, int WindowHeight, bool UseMesh, bool UseProceduralGeometry, + bool UseIS, int PayloadCount, int AttributeCount) { + CComPtr Device; + VERIFY_SUCCEEDED(Device0->QueryInterface(IID_PPV_ARGS(&Device))); FenceObj FO; - InitFenceObj(pDevice, &FO); + InitFenceObj(Device, &FO); // Setup Resources - CComPtr pTestBuffer; - CComPtr pTestBufferRead; - CComPtr pSceneConstantBuffer; + CComPtr TestBuffer; + CComPtr TestBufferRead; + CComPtr SceneConstantBuffer; // Descriptor heap - CComPtr pDescriptorHeap; + CComPtr DescriptorHeap; { // // UAV descriptor heap layout: @@ -2151,693 +2145,684 @@ CComPtr ExecutionTest::RunDXRTest( // 1 - vertex buffer SRV // 2 - index buffer SRV // - D3D12_DESCRIPTOR_HEAP_DESC descriptorHeapDesc = {}; - descriptorHeapDesc.NumDescriptors = 3; - descriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; - descriptorHeapDesc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; - pDevice->CreateDescriptorHeap(&descriptorHeapDesc, - IID_PPV_ARGS(&pDescriptorHeap)); - pDescriptorHeap->SetName(L"Descriptor Heap"); - } - int descriptorSize = pDevice->GetDescriptorHandleIncrementSize( + D3D12_DESCRIPTOR_HEAP_DESC DescriptorHeapDesc = {}; + DescriptorHeapDesc.NumDescriptors = 3; + DescriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + DescriptorHeapDesc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + 
Device->CreateDescriptorHeap(&DescriptorHeapDesc, + IID_PPV_ARGS(&DescriptorHeap)); + DescriptorHeap->SetName(L"Descriptor Heap"); + } + int DescriptorSize = Device->GetDescriptorHandleIncrementSize( D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV); // Testbuffer { - auto resDesc = CD3DX12_RESOURCE_DESC::Buffer( - testData.size() * sizeof(int), + auto ResDesc = CD3DX12_RESOURCE_DESC::Buffer( + TestData.size() * sizeof(int), D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); - auto defaultHeapProperties = + auto DefaultHeapProperties = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT); - VERIFY_SUCCEEDED(pDevice->CreateCommittedResource( - &defaultHeapProperties, D3D12_HEAP_FLAG_NONE, &resDesc, + VERIFY_SUCCEEDED(Device->CreateCommittedResource( + &DefaultHeapProperties, D3D12_HEAP_FLAG_NONE, &ResDesc, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, nullptr, - IID_PPV_ARGS(&pTestBuffer))); - pTestBuffer->SetName(L"Test Buffer"); + IID_PPV_ARGS(&TestBuffer))); + TestBuffer->SetName(L"Test Buffer"); - const int descriptorIndex = 0; - D3D12_CPU_DESCRIPTOR_HANDLE cpuDescriptorHandle = + const int DescriptorIndex = 0; + D3D12_CPU_DESCRIPTOR_HANDLE CPUDescriptorHandle = CD3DX12_CPU_DESCRIPTOR_HANDLE( - pDescriptorHeap->GetCPUDescriptorHandleForHeapStart(), - descriptorIndex, descriptorSize); + DescriptorHeap->GetCPUDescriptorHandleForHeapStart(), + DescriptorIndex, DescriptorSize); D3D12_UNORDERED_ACCESS_VIEW_DESC UAVDesc = {}; UAVDesc.Format = DXGI_FORMAT_UNKNOWN; UAVDesc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER; UAVDesc.Buffer.FirstElement = 0; - UAVDesc.Buffer.NumElements = (UINT)testData.size(); + UAVDesc.Buffer.NumElements = (UINT)TestData.size(); UAVDesc.Buffer.StructureByteStride = sizeof(int); UAVDesc.Buffer.CounterOffsetInBytes = 0; UAVDesc.Buffer.Flags = D3D12_BUFFER_UAV_FLAG_NONE; - pDevice->CreateUnorderedAccessView(pTestBuffer, nullptr, &UAVDesc, - cpuDescriptorHandle); + Device->CreateUnorderedAccessView(TestBuffer, nullptr, &UAVDesc, + CPUDescriptorHandle); } // Testbuffer 
Readback { - CD3DX12_HEAP_PROPERTIES readHeap(D3D12_HEAP_TYPE_READBACK); - CD3DX12_RESOURCE_DESC readDesc( - CD3DX12_RESOURCE_DESC::Buffer(testData.size() * sizeof(int))); - pDevice->CreateCommittedResource(&readHeap, D3D12_HEAP_FLAG_NONE, &readDesc, - D3D12_RESOURCE_STATE_COPY_DEST, nullptr, - IID_PPV_ARGS(&pTestBufferRead)); + CD3DX12_HEAP_PROPERTIES ReadHeap(D3D12_HEAP_TYPE_READBACK); + CD3DX12_RESOURCE_DESC ReadDesc( + CD3DX12_RESOURCE_DESC::Buffer(TestData.size() * sizeof(int))); + Device->CreateCommittedResource(&ReadHeap, D3D12_HEAP_FLAG_NONE, &ReadDesc, + D3D12_RESOURCE_STATE_COPY_DEST, nullptr, + IID_PPV_ARGS(&TestBufferRead)); } // Create CBV resource (sceneConstantBuffer), index 1 { - const int descriptorIndex = 1; - const UINT constantBufferSize = + const int DescriptorIndex = 1; + const UINT ConstantBufferSize = (sizeof(SceneConsts) + (D3D12_CONSTANT_BUFFER_DATA_PLACEMENT_ALIGNMENT - 1)) & ~(D3D12_CONSTANT_BUFFER_DATA_PLACEMENT_ALIGNMENT - 1); // must be a multiple 256 bytes - D3D12_CPU_DESCRIPTOR_HANDLE cpuDescriptorHandle = + D3D12_CPU_DESCRIPTOR_HANDLE CPUDescriptorHandle = CD3DX12_CPU_DESCRIPTOR_HANDLE( - pDescriptorHeap->GetCPUDescriptorHandleForHeapStart(), - descriptorIndex, descriptorSize); - auto resDesc = CD3DX12_RESOURCE_DESC::Buffer(constantBufferSize); - auto uploadHeapProperties = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD); - pDevice->CreateCommittedResource(&uploadHeapProperties, - D3D12_HEAP_FLAG_NONE, &resDesc, - D3D12_RESOURCE_STATE_GENERIC_READ, nullptr, - IID_PPV_ARGS(&pSceneConstantBuffer)); - - UINT8 *sceneConstantBufferWO; - CD3DX12_RANGE readRange( + DescriptorHeap->GetCPUDescriptorHandleForHeapStart(), + DescriptorIndex, DescriptorSize); + auto ResDesc = CD3DX12_RESOURCE_DESC::Buffer(ConstantBufferSize); + auto UploadHeapProperties = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD); + Device->CreateCommittedResource(&UploadHeapProperties, D3D12_HEAP_FLAG_NONE, + &ResDesc, D3D12_RESOURCE_STATE_GENERIC_READ, + nullptr, + 
IID_PPV_ARGS(&SceneConstantBuffer)); + + UINT8 *SceneConstantBufferWO; + CD3DX12_RANGE ReadRange( 0, 0); // We do not intend to read from this resource on the CPU. - pSceneConstantBuffer->Map( - 0, &readRange, reinterpret_cast(&sceneConstantBufferWO)); + SceneConstantBuffer->Map(0, &ReadRange, + reinterpret_cast(&SceneConstantBufferWO)); // Setup Scene Constants - SceneConsts sceneConsts = { + SceneConsts SceneConsts = { {25.f, -25.f, 700.f, 0.f}, {536.f, 0.f, 0.f, 0.f}, {0.f, 301.f, 0.f, 0.f}, {0.f, 0., -699.f, 0.f}, 100.f, - {(unsigned int)windowWidth, (unsigned int)windowHeight}, + {(unsigned int)WindowWidth, (unsigned int)WindowHeight}, 0x00}; - memcpy(sceneConstantBufferWO, &sceneConsts, sizeof(SceneConsts)); - pSceneConstantBuffer->Unmap(0, nullptr); + memcpy(SceneConstantBufferWO, &SceneConsts, sizeof(SceneConsts)); + SceneConstantBuffer->Unmap(0, nullptr); - D3D12_CONSTANT_BUFFER_VIEW_DESC desc = {}; - desc.SizeInBytes = constantBufferSize; - desc.BufferLocation = pSceneConstantBuffer->GetGPUVirtualAddress(); - pDevice->CreateConstantBufferView(&desc, cpuDescriptorHandle); + D3D12_CONSTANT_BUFFER_VIEW_DESC Desc = {}; + Desc.SizeInBytes = ConstantBufferSize; + Desc.BufferLocation = SceneConstantBuffer->GetGPUVirtualAddress(); + Device->CreateConstantBufferView(&Desc, CPUDescriptorHandle); } // Local (SBT) root signature - CComPtr pLocalRootSignature; + CComPtr LocalRootSignature; { - CD3DX12_DESCRIPTOR_RANGE bufferRanges[1]; - CD3DX12_ROOT_PARAMETER rootParameters[2]; - bufferRanges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 1, 0, + CD3DX12_DESCRIPTOR_RANGE BufferRanges[1]; + CD3DX12_ROOT_PARAMETER RootParameters[2]; + BufferRanges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 1, 0, 2); // vertexBuffer(t1), indexBuffer(t2) - rootParameters[0].InitAsDescriptorTable( - _countof(bufferRanges), bufferRanges, D3D12_SHADER_VISIBILITY_ALL); - rootParameters[1].InitAsConstants(4, 1, 0, D3D12_SHADER_VISIBILITY_ALL); + RootParameters[0].InitAsDescriptorTable( + 
_countof(BufferRanges), BufferRanges, D3D12_SHADER_VISIBILITY_ALL); + RootParameters[1].InitAsConstants(4, 1, 0, D3D12_SHADER_VISIBILITY_ALL); - CD3DX12_ROOT_SIGNATURE_DESC rootSignatureDesc; - rootSignatureDesc.Init(_countof(rootParameters), rootParameters, 0, nullptr, + CD3DX12_ROOT_SIGNATURE_DESC RootSignatureDesc; + RootSignatureDesc.Init(_countof(RootParameters), RootParameters, 0, nullptr, D3D12_ROOT_SIGNATURE_FLAG_LOCAL_ROOT_SIGNATURE); - CComPtr signature; - CComPtr error; + CComPtr Signature; + CComPtr Error; VERIFY_SUCCEEDED(D3D12SerializeRootSignature( - &rootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1, &signature, &error)); - VERIFY_SUCCEEDED(pDevice->CreateRootSignature( - 0, signature->GetBufferPointer(), signature->GetBufferSize(), - IID_PPV_ARGS(&pLocalRootSignature))); - pLocalRootSignature->SetName(L"Local Root Signature"); + &RootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1, &Signature, &Error)); + VERIFY_SUCCEEDED(Device->CreateRootSignature( + 0, Signature->GetBufferPointer(), Signature->GetBufferSize(), + IID_PPV_ARGS(&LocalRootSignature))); + LocalRootSignature->SetName(L"Local Root Signature"); } // Global root signature - CComPtr pGlobalRootSignature; + CComPtr GlobalRootSignature; { - CD3DX12_DESCRIPTOR_RANGE bufferRanges[1]; - CD3DX12_ROOT_PARAMETER rootParameters[3]; - bufferRanges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, + CD3DX12_DESCRIPTOR_RANGE BufferRanges[1]; + CD3DX12_ROOT_PARAMETER RootParameters[3]; + BufferRanges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0); // testBuffer(u0) - rootParameters[0].InitAsShaderResourceView( + RootParameters[0].InitAsShaderResourceView( 0, 0, D3D12_SHADER_VISIBILITY_ALL); // accelStruct(t0) - rootParameters[1].InitAsConstantBufferView(0); // sceneConstants(b0) - rootParameters[2].InitAsDescriptorTable( - _countof(bufferRanges), bufferRanges, D3D12_SHADER_VISIBILITY_ALL); + RootParameters[1].InitAsConstantBufferView(0); // sceneConstants(b0) + RootParameters[2].InitAsDescriptorTable( + 
_countof(BufferRanges), BufferRanges, D3D12_SHADER_VISIBILITY_ALL); - CD3DX12_ROOT_SIGNATURE_DESC rootSignatureDesc; - rootSignatureDesc.Init(_countof(rootParameters), rootParameters, 0, nullptr, + CD3DX12_ROOT_SIGNATURE_DESC RootSignatureDesc; + RootSignatureDesc.Init(_countof(RootParameters), RootParameters, 0, nullptr, D3D12_ROOT_SIGNATURE_FLAG_NONE); - CComPtr signature; - CComPtr error; + CComPtr Signature; + CComPtr Error; VERIFY_SUCCEEDED(D3D12SerializeRootSignature( - &rootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1, &signature, &error)); - VERIFY_SUCCEEDED(pDevice->CreateRootSignature( - 0, signature->GetBufferPointer(), signature->GetBufferSize(), - IID_PPV_ARGS(&pGlobalRootSignature))); - pGlobalRootSignature->SetName(L"Global Root Signature"); + &RootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1, &Signature, &Error)); + VERIFY_SUCCEEDED(Device->CreateRootSignature( + 0, Signature->GetBufferPointer(), Signature->GetBufferSize(), + IID_PPV_ARGS(&GlobalRootSignature))); + GlobalRootSignature->SetName(L"Global Root Signature"); } // Create command queue. - CComPtr pCommandQueue; - CreateCommandQueue(pDevice, L"RunDXRTest Command Queue", &pCommandQueue, + CComPtr CommandQueue; + CreateCommandQueue(Device, L"RunDXRTest Command Queue", &CommandQueue, D3D12_COMMAND_LIST_TYPE_DIRECT); // Compile raygen shader. - CComPtr pShaderLib; - CompileFromText(shader, L"raygen", pTargetProfile, &pShaderLib, pOptions, - numOptions); + CComPtr ShaderLib; + CompileFromText(ShaderSrc, L"raygen", TargetProfile, &ShaderLib, Options, + NumOptions); // Describe and create the RT pipeline state object (RTPSO). 
- CD3DX12_STATE_OBJECT_DESC stateObjectDesc( + CD3DX12_STATE_OBJECT_DESC StateObjectDesc( D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE); - auto lib = stateObjectDesc.CreateSubobject(); - CD3DX12_SHADER_BYTECODE byteCode(pShaderLib); - lib->SetDXILLibrary(&byteCode); - lib->DefineExport(L"raygen"); - lib->DefineExport(L"closesthit"); - lib->DefineExport(L"anyhit"); - lib->DefineExport(L"miss"); - if (useIS) { - lib->DefineExport(L"intersection"); - } - if (useMesh && useProceduralGeometry) { - lib->DefineExport(L"chAABB"); - } - - const int maxRecursion = 1; - stateObjectDesc.CreateSubobject() - ->Config(payloadCount * sizeof(float), attributeCount * sizeof(float)); - stateObjectDesc + auto Lib = StateObjectDesc.CreateSubobject(); + CD3DX12_SHADER_BYTECODE ByteCode(ShaderLib); + Lib->SetDXILLibrary(&ByteCode); + Lib->DefineExport(L"raygen"); + Lib->DefineExport(L"closesthit"); + Lib->DefineExport(L"anyhit"); + Lib->DefineExport(L"miss"); + if (UseIS) + Lib->DefineExport(L"intersection"); + if (UseMesh && UseProceduralGeometry) + Lib->DefineExport(L"chAABB"); + + const int MaxRecursion = 1; + StateObjectDesc.CreateSubobject() + ->Config(PayloadCount * sizeof(float), AttributeCount * sizeof(float)); + StateObjectDesc .CreateSubobject() - ->Config(maxRecursion); + ->Config(MaxRecursion); // Set Global Root Signature subobject. - auto globalRootSigSubObj = - stateObjectDesc + auto GlobalRootSigSubObj = + StateObjectDesc .CreateSubobject(); - globalRootSigSubObj->SetRootSignature(pGlobalRootSignature); + GlobalRootSigSubObj->SetRootSignature(GlobalRootSignature); // Set Local Root Signature subobject. 
- stateObjectDesc.CreateSubobject() - ->SetRootSignature(pLocalRootSignature); + StateObjectDesc.CreateSubobject() + ->SetRootSignature(LocalRootSignature); - auto exports = stateObjectDesc.CreateSubobject< + auto Exports = StateObjectDesc.CreateSubobject< CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT>(); - exports->SetSubobjectToAssociate(*globalRootSigSubObj); - exports->AddExport(L"raygen"); - exports->AddExport(L"closesthit"); - exports->AddExport(L"anyhit"); - exports->AddExport(L"miss"); - if (useIS) { - exports->AddExport(L"intersection"); - } - if (useMesh && useProceduralGeometry) { - exports->AddExport(L"chAABB"); - } - - auto hitGroup = - stateObjectDesc.CreateSubobject(); - hitGroup->SetClosestHitShaderImport(L"closesthit"); - hitGroup->SetAnyHitShaderImport(L"anyhit"); - if (useIS) { - hitGroup->SetIntersectionShaderImport(L"intersection"); - hitGroup->SetHitGroupType(D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE); - } - hitGroup->SetHitGroupExport(L"HitGroup"); - - if (useMesh && useProceduralGeometry) { - auto hitGroupAABB = - stateObjectDesc.CreateSubobject(); - hitGroupAABB->SetClosestHitShaderImport(L"chAABB"); - hitGroupAABB->SetAnyHitShaderImport(L"anyhit"); - if (useIS) { - hitGroup->SetIntersectionShaderImport(L"intersection"); - hitGroup->SetHitGroupType(D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE); - } - hitGroupAABB->SetHitGroupExport(L"HitGroupAABB"); - } - - CComPtr pStateObject; - CComPtr pStateObjectProperties; + Exports->SetSubobjectToAssociate(*GlobalRootSigSubObj); + Exports->AddExport(L"raygen"); + Exports->AddExport(L"closesthit"); + Exports->AddExport(L"anyhit"); + Exports->AddExport(L"miss"); + if (UseIS) + Exports->AddExport(L"intersection"); + if (UseMesh && UseProceduralGeometry) + Exports->AddExport(L"chAABB"); + + auto HitGroup = + StateObjectDesc.CreateSubobject(); + HitGroup->SetClosestHitShaderImport(L"closesthit"); + HitGroup->SetAnyHitShaderImport(L"anyhit"); + if (UseIS) { + 
HitGroup->SetIntersectionShaderImport(L"intersection"); + HitGroup->SetHitGroupType(D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE); + } + HitGroup->SetHitGroupExport(L"HitGroup"); + + if (UseMesh && UseProceduralGeometry) { + auto HitGroupAABB = + StateObjectDesc.CreateSubobject(); + HitGroupAABB->SetClosestHitShaderImport(L"chAABB"); + HitGroupAABB->SetAnyHitShaderImport(L"anyhit"); + if (UseIS) { + HitGroupAABB->SetIntersectionShaderImport(L"intersection"); + HitGroupAABB->SetHitGroupType(D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE); + } + HitGroupAABB->SetHitGroupExport(L"HitGroupAABB"); + } + + CComPtr StateObject; + CComPtr StateObjectProperties; VERIFY_SUCCEEDED( - pDevice->CreateStateObject(stateObjectDesc, IID_PPV_ARGS(&pStateObject))); - VERIFY_SUCCEEDED(pStateObject->QueryInterface(&pStateObjectProperties)); + Device->CreateStateObject(StateObjectDesc, IID_PPV_ARGS(&StateObject))); + VERIFY_SUCCEEDED(StateObject->QueryInterface(&StateObjectProperties)); // Create SBT - ShaderTable shaderTable; - shaderTable.Init(pDevice, + ShaderTable ShaderTable; + ShaderTable.Init(Device, 1, // raygen count 1, // miss count - useMesh && useProceduralGeometry ? 2 : 1, // hit group count + UseMesh && UseProceduralGeometry ? 
2 : 1, // hit group count 1, // ray type count 4 // dwords per root table ); - int localRootConsts[4] = {12, 34, 56, 78}; - memcpy(shaderTable.GetRaygenShaderIdPtr(0), - pStateObjectProperties->GetShaderIdentifier(L"raygen"), + int LocalRootConsts[4] = {12, 34, 56, 78}; + memcpy(ShaderTable.GetRaygenShaderIdPtr(0), + StateObjectProperties->GetShaderIdentifier(L"raygen"), SHADER_ID_SIZE_IN_BYTES); - memcpy(shaderTable.GetRaygenRootTablePtr(0), localRootConsts, - sizeof(localRootConsts)); - memcpy(shaderTable.GetMissShaderIdPtr(0, 0), - pStateObjectProperties->GetShaderIdentifier(L"miss"), + memcpy(ShaderTable.GetRaygenRootTablePtr(0), LocalRootConsts, + sizeof(LocalRootConsts)); + memcpy(ShaderTable.GetMissShaderIdPtr(0, 0), + StateObjectProperties->GetShaderIdentifier(L"miss"), SHADER_ID_SIZE_IN_BYTES); - memcpy(shaderTable.GetMissRootTablePtr(0, 0), localRootConsts, - sizeof(localRootConsts)); - memcpy(shaderTable.GetHitGroupShaderIdPtr(0, 0), - pStateObjectProperties->GetShaderIdentifier(L"HitGroup"), + memcpy(ShaderTable.GetMissRootTablePtr(0, 0), LocalRootConsts, + sizeof(LocalRootConsts)); + memcpy(ShaderTable.GetHitGroupShaderIdPtr(0, 0), + StateObjectProperties->GetShaderIdentifier(L"HitGroup"), SHADER_ID_SIZE_IN_BYTES); - memcpy(shaderTable.GetHitGroupRootTablePtr(0, 0), localRootConsts, - sizeof(localRootConsts)); - if (useMesh && useProceduralGeometry) { - memcpy(shaderTable.GetHitGroupShaderIdPtr(0, 1), - pStateObjectProperties->GetShaderIdentifier(L"HitGroupAABB"), + memcpy(ShaderTable.GetHitGroupRootTablePtr(0, 0), LocalRootConsts, + sizeof(LocalRootConsts)); + if (UseMesh && UseProceduralGeometry) + memcpy(ShaderTable.GetHitGroupShaderIdPtr(0, 1), + StateObjectProperties->GetShaderIdentifier(L"HitGroupAABB"), SHADER_ID_SIZE_IN_BYTES); - } - - // auto tbl = pDescriptorHeap->GetGPUDescriptorHandleForHeapStart().ptr; - // memcpy(shaderTable.GetHitGroupRootTablePtr(0, 0), &tbl, 8); // Create a command allocator and list. 
- CComPtr pCommandAllocator; - CComPtr pCommandList; - VERIFY_SUCCEEDED(pDevice->CreateCommandAllocator( - D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&pCommandAllocator))); - VERIFY_SUCCEEDED(pDevice->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, - pCommandAllocator, nullptr, - IID_PPV_ARGS(&pCommandList))); - pCommandList->SetName(L"ExecutionTest::RunDXRTest Command List"); - - pCommandList->Close(); - ExecuteCommandList(pCommandQueue, pCommandList); - WaitForSignal(pCommandQueue, FO); - - VERIFY_SUCCEEDED(pCommandAllocator->Reset()); - VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, nullptr)); + CComPtr CommandAllocator; + CComPtr CommandList; + VERIFY_SUCCEEDED(Device->CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&CommandAllocator))); + VERIFY_SUCCEEDED(Device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, + CommandAllocator, nullptr, + IID_PPV_ARGS(&CommandList))); + CommandList->SetName(L"ExecutionTest::RunDXRTest Command List"); + + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); + + VERIFY_SUCCEEDED(CommandAllocator->Reset()); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, nullptr)); // Create scene geometry. 
- CComPtr tlasResource; - CComPtr blasMeshResource; - CComPtr blasProceduralGeometryResource; - CComPtr instanceDescs; - CComPtr scratchResource; - - if (useMesh) { - CComPtr vertexBuffer; - CComPtr vertexBufferUpload; - CComPtr indexBuffer; - CComPtr indexBufferUpload; + CComPtr TLASResource; + CComPtr BLASMeshResource; + CComPtr BLASProceduralGeometryResource; + CComPtr InstanceDescs; + CComPtr ScratchResource; + + if (UseMesh) { + CComPtr VertexBuffer; + CComPtr VertexBufferUpload; + CComPtr IndexBuffer; + CComPtr IndexBufferUpload; // Define a Quad - const float verts[] = { + const float Verts[] = { -50.5f, 50.5f, 0.5f, // top left 50.5f, -50.5f, 0.5f, // bottom right -50.5f, -50.5f, 0.5f, // bottom left 50.5f, 50.5f, 0.5f // top right }; - const int indices[] = { + const int Indices[] = { 0, 1, 2, // first triangle 0, 3, 1 // second triangle }; - const UINT64 vertexDataSize = sizeof(verts); - const UINT64 indexDataSize = sizeof(indices); + const UINT64 VertexDataSize = sizeof(Verts); + const UINT64 IndexDataSize = sizeof(Indices); - AllocateUploadBuffer(pDevice, verts, vertexDataSize, &vertexBufferUpload, - L"vertexBufferUpload"); - AllocateUploadBuffer(pDevice, indices, indexDataSize, &indexBufferUpload, - L"indexBufferUpload"); + AllocateUploadBuffer(Device, Verts, VertexDataSize, &VertexBufferUpload, + L"VertexBufferUpload"); + AllocateUploadBuffer(Device, Indices, IndexDataSize, &IndexBufferUpload, + L"IndexBufferUpload"); AllocateBufferFromUpload( - pDevice, pCommandList, vertexBufferUpload, &vertexBuffer, - D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE, L"vertexBuffer"); + Device, CommandList, VertexBufferUpload, &VertexBuffer, + D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE, L"VertexBuffer"); AllocateBufferFromUpload( - pDevice, pCommandList, indexBufferUpload, &indexBuffer, - D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE, L"indexBuffer"); + Device, CommandList, IndexBufferUpload, &IndexBuffer, + D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE, 
L"IndexBuffer"); { - const int descriptorIndex = 1; - D3D12_CPU_DESCRIPTOR_HANDLE cpuDescriptorHandle = + const int DescriptorIndex = 1; + D3D12_CPU_DESCRIPTOR_HANDLE CpuDescriptorHandle = CD3DX12_CPU_DESCRIPTOR_HANDLE( - pDescriptorHeap->GetCPUDescriptorHandleForHeapStart(), - descriptorIndex, descriptorSize); - D3D12_SHADER_RESOURCE_VIEW_DESC srvDesc = {}; - srvDesc.ViewDimension = D3D12_SRV_DIMENSION_BUFFER; - srvDesc.Shader4ComponentMapping = + DescriptorHeap->GetCPUDescriptorHandleForHeapStart(), + DescriptorIndex, DescriptorSize); + D3D12_SHADER_RESOURCE_VIEW_DESC SrvDesc = {}; + SrvDesc.ViewDimension = D3D12_SRV_DIMENSION_BUFFER; + SrvDesc.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING; - srvDesc.Buffer.NumElements = - UINT(vertexDataSize / sizeof(DirectX::XMFLOAT3)); - srvDesc.Format = DXGI_FORMAT_UNKNOWN; - srvDesc.Buffer.Flags = D3D12_BUFFER_SRV_FLAG_NONE; - srvDesc.Buffer.StructureByteStride = sizeof(DirectX::XMFLOAT3); - pDevice->CreateShaderResourceView(vertexBuffer, &srvDesc, - cpuDescriptorHandle); + SrvDesc.Buffer.NumElements = + UINT(VertexDataSize / sizeof(DirectX::XMFLOAT3)); + SrvDesc.Format = DXGI_FORMAT_UNKNOWN; + SrvDesc.Buffer.Flags = D3D12_BUFFER_SRV_FLAG_NONE; + SrvDesc.Buffer.StructureByteStride = sizeof(DirectX::XMFLOAT3); + Device->CreateShaderResourceView(VertexBuffer, &SrvDesc, + CpuDescriptorHandle); } { - const int descriptorIndex = 2; - D3D12_CPU_DESCRIPTOR_HANDLE cpuDescriptorHandle = + const int DescriptorIndex = 2; + D3D12_CPU_DESCRIPTOR_HANDLE CpuDescriptorHandle = CD3DX12_CPU_DESCRIPTOR_HANDLE( - pDescriptorHeap->GetCPUDescriptorHandleForHeapStart(), - descriptorIndex, descriptorSize); - D3D12_SHADER_RESOURCE_VIEW_DESC srvDesc = {}; - srvDesc.ViewDimension = D3D12_SRV_DIMENSION_BUFFER; - srvDesc.Shader4ComponentMapping = + DescriptorHeap->GetCPUDescriptorHandleForHeapStart(), + DescriptorIndex, DescriptorSize); + D3D12_SHADER_RESOURCE_VIEW_DESC SrvDesc = {}; + SrvDesc.ViewDimension = 
D3D12_SRV_DIMENSION_BUFFER; + SrvDesc.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING; - srvDesc.Buffer.NumElements = UINT(indexDataSize / sizeof(int)); - srvDesc.Format = DXGI_FORMAT_UNKNOWN; - srvDesc.Buffer.Flags = D3D12_BUFFER_SRV_FLAG_NONE; - srvDesc.Buffer.StructureByteStride = sizeof(int); - pDevice->CreateShaderResourceView(indexBuffer, &srvDesc, - cpuDescriptorHandle); + SrvDesc.Buffer.NumElements = UINT(IndexDataSize / sizeof(int)); + SrvDesc.Format = DXGI_FORMAT_UNKNOWN; + SrvDesc.Buffer.Flags = D3D12_BUFFER_SRV_FLAG_NONE; + SrvDesc.Buffer.StructureByteStride = sizeof(int); + Device->CreateShaderResourceView(IndexBuffer, &SrvDesc, + CpuDescriptorHandle); } - pCommandList->Close(); - ExecuteCommandList(pCommandQueue, pCommandList); - WaitForSignal(pCommandQueue, FO); + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); - VERIFY_SUCCEEDED(pCommandAllocator->Reset()); - VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, nullptr)); + VERIFY_SUCCEEDED(CommandAllocator->Reset()); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, nullptr)); - if (!useIS) { + if (!UseIS) { // Build BLAS. 
{ - D3D12_RAYTRACING_GEOMETRY_DESC geometryDesc = {}; - geometryDesc.Type = D3D12_RAYTRACING_GEOMETRY_TYPE_TRIANGLES; - geometryDesc.Triangles.IndexBuffer = - indexBuffer->GetGPUVirtualAddress(); - geometryDesc.Triangles.IndexCount = - static_cast(indexBuffer->GetDesc().Width) / sizeof(int); - geometryDesc.Triangles.IndexFormat = DXGI_FORMAT_R32_UINT; - geometryDesc.Triangles.Transform3x4 = 0; - geometryDesc.Triangles.VertexFormat = DXGI_FORMAT_R32G32B32_FLOAT; - geometryDesc.Triangles.VertexCount = - static_cast(vertexBuffer->GetDesc().Width) / + D3D12_RAYTRACING_GEOMETRY_DESC GeometryDesc = {}; + GeometryDesc.Type = D3D12_RAYTRACING_GEOMETRY_TYPE_TRIANGLES; + GeometryDesc.Triangles.IndexBuffer = + IndexBuffer->GetGPUVirtualAddress(); + GeometryDesc.Triangles.IndexCount = + static_cast(IndexBuffer->GetDesc().Width) / sizeof(int); + GeometryDesc.Triangles.IndexFormat = DXGI_FORMAT_R32_UINT; + GeometryDesc.Triangles.Transform3x4 = 0; + GeometryDesc.Triangles.VertexFormat = DXGI_FORMAT_R32G32B32_FLOAT; + GeometryDesc.Triangles.VertexCount = + static_cast(VertexBuffer->GetDesc().Width) / sizeof(DirectX::XMFLOAT3); - geometryDesc.Triangles.VertexBuffer.StartAddress = - vertexBuffer->GetGPUVirtualAddress(); - geometryDesc.Triangles.VertexBuffer.StrideInBytes = + GeometryDesc.Triangles.VertexBuffer.StartAddress = + VertexBuffer->GetGPUVirtualAddress(); + GeometryDesc.Triangles.VertexBuffer.StrideInBytes = sizeof(DirectX::XMFLOAT3); - geometryDesc.Flags = + GeometryDesc.Flags = D3D12_RAYTRACING_GEOMETRY_FLAG_NONE; // Non-opaque to trigger // anyhit. 
- D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS buildFlags = + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS BuildFlags = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_TRACE; - D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS accelInputs = {}; - accelInputs.Type = + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS AccelInputs = {}; + AccelInputs.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL; - accelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; - accelInputs.pGeometryDescs = &geometryDesc; - accelInputs.NumDescs = 1; - accelInputs.Flags = buildFlags; - - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO prebuildInfo = {}; - pDevice->GetRaytracingAccelerationStructurePrebuildInfo(&accelInputs, - &prebuildInfo); - - scratchResource.Release(); - ReallocScratchResource(pDevice, &scratchResource, - prebuildInfo.ScratchDataSizeInBytes); - AllocateBuffer(pDevice, prebuildInfo.ResultDataMaxSizeInBytes, - &blasMeshResource, true, + AccelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; + AccelInputs.pGeometryDescs = &GeometryDesc; + AccelInputs.NumDescs = 1; + AccelInputs.Flags = BuildFlags; + + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO PrebuildInfo = {}; + Device->GetRaytracingAccelerationStructurePrebuildInfo(&AccelInputs, + &PrebuildInfo); + + ScratchResource.Release(); + ReallocScratchResource(Device, &ScratchResource, + PrebuildInfo.ScratchDataSizeInBytes); + AllocateBuffer(Device, PrebuildInfo.ResultDataMaxSizeInBytes, + &BLASMeshResource, true, D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, L"blasMesh"); - D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC buildDesc = {}; - buildDesc.Inputs = accelInputs; - buildDesc.ScratchAccelerationStructureData = - scratchResource->GetGPUVirtualAddress(); - buildDesc.DestAccelerationStructureData = - blasMeshResource->GetGPUVirtualAddress(); - - pCommandList->BuildRaytracingAccelerationStructure(&buildDesc, 0, - nullptr); - 
CD3DX12_RESOURCE_BARRIER barrier = - CD3DX12_RESOURCE_BARRIER::UAV(blasMeshResource); - pCommandList->ResourceBarrier(1, - (const D3D12_RESOURCE_BARRIER *)&barrier); + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC BuildDesc = {}; + BuildDesc.Inputs = AccelInputs; + BuildDesc.ScratchAccelerationStructureData = + ScratchResource->GetGPUVirtualAddress(); + BuildDesc.DestAccelerationStructureData = + BLASMeshResource->GetGPUVirtualAddress(); + + CommandList->BuildRaytracingAccelerationStructure(&BuildDesc, 0, + nullptr); + CD3DX12_RESOURCE_BARRIER Barrier = + CD3DX12_RESOURCE_BARRIER::UAV(BLASMeshResource); + CommandList->ResourceBarrier(1, + (const D3D12_RESOURCE_BARRIER *)&Barrier); } } - pCommandList->Close(); - ExecuteCommandList(pCommandQueue, pCommandList); - WaitForSignal(pCommandQueue, FO); + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); - VERIFY_SUCCEEDED(pCommandAllocator->Reset()); - VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, nullptr)); + VERIFY_SUCCEEDED(CommandAllocator->Reset()); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, nullptr)); } - if (useProceduralGeometry) { + if (UseProceduralGeometry) { // Define procedural geometry AABB for a plane - CComPtr aabbBuffer; - CComPtr aabbBufferUpload; + CComPtr AabbBuffer; + CComPtr AabbBufferUpload; // Define the AABB for the plane, matching the size of the quad defined by // verts[] - const D3D12_RAYTRACING_AABB aabb = { + const D3D12_RAYTRACING_AABB Aabb = { -150.5f, -500.5f, -1000.0f, // Min corner (x, y, z) 150.5f, -150.5f, 1000.0f // Max corner (x, y, z) }; - const UINT64 aabbDataSize = sizeof(aabb); + const UINT64 AabbDataSize = sizeof(Aabb); // Create an upload buffer for the AABB - AllocateUploadBuffer(pDevice, &aabb, aabbDataSize, &aabbBufferUpload, - L"aabbBufferUpload"); + AllocateUploadBuffer(Device, &Aabb, AabbDataSize, &AabbBufferUpload, + L"AabbBufferUpload"); // Create a GPU buffer for the AABB - 
AllocateBufferFromUpload( - pDevice, pCommandList, aabbBufferUpload, &aabbBuffer, - D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE, L"aabbBuffer"); + AllocateBufferFromUpload(Device, CommandList, AabbBufferUpload, &AabbBuffer, + D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE, + L"AabbBuffer"); // Describe the procedural geometry - D3D12_RAYTRACING_GEOMETRY_DESC procGeometryDesc = {}; - procGeometryDesc.Type = + D3D12_RAYTRACING_GEOMETRY_DESC ProcGeometryDesc = {}; + ProcGeometryDesc.Type = D3D12_RAYTRACING_GEOMETRY_TYPE_PROCEDURAL_PRIMITIVE_AABBS; - procGeometryDesc.AABBs.AABBs.StartAddress = - aabbBuffer->GetGPUVirtualAddress(); - procGeometryDesc.AABBs.AABBs.StrideInBytes = sizeof(D3D12_RAYTRACING_AABB); - procGeometryDesc.AABBs.AABBCount = 1; + ProcGeometryDesc.AABBs.AABBs.StartAddress = + AabbBuffer->GetGPUVirtualAddress(); + ProcGeometryDesc.AABBs.AABBs.StrideInBytes = sizeof(D3D12_RAYTRACING_AABB); + ProcGeometryDesc.AABBs.AABBCount = 1; // Build the BLAS for the procedural geometry - D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS blasInputs = {}; - blasInputs.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL; - blasInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; - blasInputs.NumDescs = 1; - blasInputs.pGeometryDescs = &procGeometryDesc; - blasInputs.Flags = + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS BLASInputs = {}; + BLASInputs.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL; + BLASInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; + BLASInputs.NumDescs = 1; + BLASInputs.pGeometryDescs = &ProcGeometryDesc; + BLASInputs.Flags = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_TRACE; - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO prebuildInfo = {}; - pDevice->GetRaytracingAccelerationStructurePrebuildInfo(&blasInputs, - &prebuildInfo); + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO PrebuildInfo = {}; + Device->GetRaytracingAccelerationStructurePrebuildInfo(&BLASInputs, + 
&PrebuildInfo); // Allocate scratch and result buffers for the BLAS - scratchResource.Release(); - ReallocScratchResource(pDevice, &scratchResource, - prebuildInfo.ScratchDataSizeInBytes); - AllocateBuffer(pDevice, prebuildInfo.ResultDataMaxSizeInBytes, - &blasProceduralGeometryResource, true, + ScratchResource.Release(); + ReallocScratchResource(Device, &ScratchResource, + PrebuildInfo.ScratchDataSizeInBytes); + AllocateBuffer(Device, PrebuildInfo.ResultDataMaxSizeInBytes, + &BLASProceduralGeometryResource, true, D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, - L"blasProceduralGeometry"); + L"BlasProceduralGeometry"); // Build the BLAS - D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC blasDesc = {}; - blasDesc.Inputs = blasInputs; - blasDesc.ScratchAccelerationStructureData = - scratchResource->GetGPUVirtualAddress(); - blasDesc.DestAccelerationStructureData = - blasProceduralGeometryResource->GetGPUVirtualAddress(); + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC BLASDesc = {}; + BLASDesc.Inputs = BLASInputs; + BLASDesc.ScratchAccelerationStructureData = + ScratchResource->GetGPUVirtualAddress(); + BLASDesc.DestAccelerationStructureData = + BLASProceduralGeometryResource->GetGPUVirtualAddress(); - pCommandList->BuildRaytracingAccelerationStructure(&blasDesc, 0, nullptr); + CommandList->BuildRaytracingAccelerationStructure(&BLASDesc, 0, nullptr); // Add a UAV barrier to ensure the BLAS is built before using it - CD3DX12_RESOURCE_BARRIER barrier = - CD3DX12_RESOURCE_BARRIER::UAV(blasProceduralGeometryResource); - pCommandList->ResourceBarrier(1, &barrier); + CD3DX12_RESOURCE_BARRIER Barrier = + CD3DX12_RESOURCE_BARRIER::UAV(BLASProceduralGeometryResource); + CommandList->ResourceBarrier(1, &Barrier); - pCommandList->Close(); - ExecuteCommandList(pCommandQueue, pCommandList); - WaitForSignal(pCommandQueue, FO); + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); - 
VERIFY_SUCCEEDED(pCommandAllocator->Reset()); - VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, nullptr)); + VERIFY_SUCCEEDED(CommandAllocator->Reset()); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, nullptr)); } // Build TLAS. { - if (useMesh) { - D3D12_RAYTRACING_INSTANCE_DESC instanceDesc = {}; - instanceDesc.Transform[0][0] = instanceDesc.Transform[1][1] = - instanceDesc.Transform[2][2] = 1; - instanceDesc.InstanceMask = 1; - instanceDesc.AccelerationStructure = - blasMeshResource->GetGPUVirtualAddress(); - - AllocateUploadBuffer(pDevice, &instanceDesc, sizeof(instanceDesc), - &instanceDescs, L"instanceDescs"); - - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS buildFlags = + if (UseMesh) { + D3D12_RAYTRACING_INSTANCE_DESC InstanceDesc = {}; + InstanceDesc.Transform[0][0] = InstanceDesc.Transform[1][1] = + InstanceDesc.Transform[2][2] = 1; + InstanceDesc.InstanceMask = 1; + InstanceDesc.AccelerationStructure = + BLASMeshResource->GetGPUVirtualAddress(); + + AllocateUploadBuffer(Device, &InstanceDesc, sizeof(InstanceDesc), + &InstanceDescs, L"InstanceDescs"); + + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS BuildFlags = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_BUILD; - D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS accelInputs = {}; - accelInputs.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL; - accelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; - accelInputs.NumDescs = 1; - accelInputs.Flags = buildFlags; - accelInputs.InstanceDescs = instanceDescs->GetGPUVirtualAddress(); + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS AccelInputs = {}; + AccelInputs.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL; + AccelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; + AccelInputs.NumDescs = 1; + AccelInputs.Flags = BuildFlags; + AccelInputs.InstanceDescs = InstanceDescs->GetGPUVirtualAddress(); - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO prebuildInfo = {}; - 
pDevice->GetRaytracingAccelerationStructurePrebuildInfo(&accelInputs, - &prebuildInfo); + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO PrebuildInfo = {}; + Device->GetRaytracingAccelerationStructurePrebuildInfo(&AccelInputs, + &PrebuildInfo); - scratchResource.Release(); - ReallocScratchResource(pDevice, &scratchResource, - prebuildInfo.ScratchDataSizeInBytes); + ScratchResource.Release(); + ReallocScratchResource(Device, &ScratchResource, + PrebuildInfo.ScratchDataSizeInBytes); AllocateBuffer( - pDevice, prebuildInfo.ResultDataMaxSizeInBytes, &tlasResource, true, + Device, PrebuildInfo.ResultDataMaxSizeInBytes, &TLASResource, true, D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, L"TLAS"); - D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC buildDesc = {}; - buildDesc.Inputs = accelInputs; - buildDesc.ScratchAccelerationStructureData = - scratchResource->GetGPUVirtualAddress(); - buildDesc.DestAccelerationStructureData = - tlasResource->GetGPUVirtualAddress(); + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC BuildDesc = {}; + BuildDesc.Inputs = AccelInputs; + BuildDesc.ScratchAccelerationStructureData = + ScratchResource->GetGPUVirtualAddress(); + BuildDesc.DestAccelerationStructureData = + TLASResource->GetGPUVirtualAddress(); - pCommandList->BuildRaytracingAccelerationStructure(&buildDesc, 0, 0); + CommandList->BuildRaytracingAccelerationStructure(&BuildDesc, 0, 0); } else { - D3D12_RAYTRACING_INSTANCE_DESC instanceDesc = {}; - instanceDesc.Transform[0][0] = instanceDesc.Transform[1][1] = - instanceDesc.Transform[2][2] = 1; - instanceDesc.InstanceMask = 1; - instanceDesc.AccelerationStructure = - blasProceduralGeometryResource->GetGPUVirtualAddress(); + D3D12_RAYTRACING_INSTANCE_DESC InstanceDesc = {}; + InstanceDesc.Transform[0][0] = InstanceDesc.Transform[1][1] = + InstanceDesc.Transform[2][2] = 1; + InstanceDesc.InstanceMask = 1; + InstanceDesc.AccelerationStructure = + BLASProceduralGeometryResource->GetGPUVirtualAddress(); - 
AllocateUploadBuffer(pDevice, &instanceDesc, sizeof(instanceDesc), - &instanceDescs, L"instanceDescs"); + AllocateUploadBuffer(Device, &InstanceDesc, sizeof(InstanceDesc), + &InstanceDescs, L"InstanceDescs"); - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS buildFlags = + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS BuildFlags = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_BUILD; - D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS accelInputs = {}; - accelInputs.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL; - accelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; - accelInputs.NumDescs = 1; - accelInputs.Flags = buildFlags; - accelInputs.InstanceDescs = instanceDescs->GetGPUVirtualAddress(); + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS AccelInputs = {}; + AccelInputs.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL; + AccelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; + AccelInputs.NumDescs = 1; + AccelInputs.Flags = BuildFlags; + AccelInputs.InstanceDescs = InstanceDescs->GetGPUVirtualAddress(); - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO prebuildInfo = {}; - pDevice->GetRaytracingAccelerationStructurePrebuildInfo(&accelInputs, - &prebuildInfo); + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO PrebuildInfo = {}; + Device->GetRaytracingAccelerationStructurePrebuildInfo(&AccelInputs, + &PrebuildInfo); - scratchResource.Release(); - ReallocScratchResource(pDevice, &scratchResource, - prebuildInfo.ScratchDataSizeInBytes); + ScratchResource.Release(); + ReallocScratchResource(Device, &ScratchResource, + PrebuildInfo.ScratchDataSizeInBytes); AllocateBuffer( - pDevice, prebuildInfo.ResultDataMaxSizeInBytes, &tlasResource, true, + Device, PrebuildInfo.ResultDataMaxSizeInBytes, &TLASResource, true, D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, L"TLAS"); - D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC buildDesc = {}; - buildDesc.Inputs = accelInputs; - 
buildDesc.ScratchAccelerationStructureData = - scratchResource->GetGPUVirtualAddress(); - buildDesc.DestAccelerationStructureData = - tlasResource->GetGPUVirtualAddress(); + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC BuildDesc = {}; + BuildDesc.Inputs = AccelInputs; + BuildDesc.ScratchAccelerationStructureData = + ScratchResource->GetGPUVirtualAddress(); + BuildDesc.DestAccelerationStructureData = + TLASResource->GetGPUVirtualAddress(); - pCommandList->BuildRaytracingAccelerationStructure(&buildDesc, 0, 0); + CommandList->BuildRaytracingAccelerationStructure(&BuildDesc, 0, 0); } - CD3DX12_RESOURCE_BARRIER barrier = - CD3DX12_RESOURCE_BARRIER::UAV(tlasResource); - pCommandList->ResourceBarrier(1, (const D3D12_RESOURCE_BARRIER *)&barrier); + CD3DX12_RESOURCE_BARRIER Barrier = + CD3DX12_RESOURCE_BARRIER::UAV(TLASResource); + CommandList->ResourceBarrier(1, (const D3D12_RESOURCE_BARRIER *)&Barrier); } // Set the local root constants. - pCommandList->SetComputeRootSignature(pLocalRootSignature); - pCommandList->SetComputeRoot32BitConstant(1, 12, 0); - pCommandList->SetComputeRoot32BitConstant(1, 34, 1); - pCommandList->SetComputeRoot32BitConstant(1, 56, 2); - pCommandList->SetComputeRoot32BitConstant(1, 78, 3); - - shaderTable.Upload(pCommandList); - - ID3D12DescriptorHeap *const pHeaps[1] = {pDescriptorHeap}; - pCommandList->SetDescriptorHeaps(1, pHeaps); - pCommandList->SetComputeRootSignature(pGlobalRootSignature); - pCommandList->SetComputeRootShaderResourceView( - 0, tlasResource->GetGPUVirtualAddress()); - pCommandList->SetComputeRootConstantBufferView( - 1, pSceneConstantBuffer->GetGPUVirtualAddress()); - pCommandList->SetComputeRootDescriptorTable( - 2, pDescriptorHeap->GetGPUDescriptorHandleForHeapStart()); - - D3D12_DISPATCH_RAYS_DESC dispatchDesc = {}; - dispatchDesc.RayGenerationShaderRecord.StartAddress = - shaderTable.GetRaygenStartGpuVA(); - dispatchDesc.RayGenerationShaderRecord.SizeInBytes = - shaderTable.GetRaygenRangeInBytes(); - 
dispatchDesc.MissShaderTable.StartAddress = shaderTable.GetMissStartGpuVA(); - dispatchDesc.MissShaderTable.SizeInBytes = shaderTable.GetMissRangeInBytes(); - dispatchDesc.MissShaderTable.StrideInBytes = - shaderTable.GetShaderRecordSizeInBytes(); - dispatchDesc.HitGroupTable.StartAddress = shaderTable.GetHitGroupStartGpuVA(); - dispatchDesc.HitGroupTable.SizeInBytes = - shaderTable.GetHitGroupRangeInBytes(); - dispatchDesc.HitGroupTable.StrideInBytes = - shaderTable.GetShaderRecordSizeInBytes(); - dispatchDesc.Width = windowWidth; - dispatchDesc.Height = windowHeight; - dispatchDesc.Depth = 1; - pCommandList->SetPipelineState1(pStateObject); - pCommandList->DispatchRays(&dispatchDesc); - - pCommandList->Close(); - ExecuteCommandList(pCommandQueue, pCommandList); - WaitForSignal(pCommandQueue, FO); - - VERIFY_SUCCEEDED(pCommandAllocator->Reset()); - VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, nullptr)); + CommandList->SetComputeRootSignature(LocalRootSignature); + CommandList->SetComputeRoot32BitConstant(1, 12, 0); + CommandList->SetComputeRoot32BitConstant(1, 34, 1); + CommandList->SetComputeRoot32BitConstant(1, 56, 2); + CommandList->SetComputeRoot32BitConstant(1, 78, 3); + + ShaderTable.Upload(CommandList); + + ID3D12DescriptorHeap *const Heaps[1] = {DescriptorHeap}; + CommandList->SetDescriptorHeaps(1, Heaps); + CommandList->SetComputeRootSignature(GlobalRootSignature); + CommandList->SetComputeRootShaderResourceView( + 0, TLASResource->GetGPUVirtualAddress()); + CommandList->SetComputeRootConstantBufferView( + 1, SceneConstantBuffer->GetGPUVirtualAddress()); + CommandList->SetComputeRootDescriptorTable( + 2, DescriptorHeap->GetGPUDescriptorHandleForHeapStart()); + + D3D12_DISPATCH_RAYS_DESC DispatchDesc = {}; + DispatchDesc.RayGenerationShaderRecord.StartAddress = + ShaderTable.GetRaygenStartGpuVA(); + DispatchDesc.RayGenerationShaderRecord.SizeInBytes = + ShaderTable.GetRaygenRangeInBytes(); + DispatchDesc.MissShaderTable.StartAddress = 
ShaderTable.GetMissStartGpuVA(); + DispatchDesc.MissShaderTable.SizeInBytes = ShaderTable.GetMissRangeInBytes(); + DispatchDesc.MissShaderTable.StrideInBytes = + ShaderTable.GetShaderRecordSizeInBytes(); + DispatchDesc.HitGroupTable.StartAddress = ShaderTable.GetHitGroupStartGpuVA(); + DispatchDesc.HitGroupTable.SizeInBytes = + ShaderTable.GetHitGroupRangeInBytes(); + DispatchDesc.HitGroupTable.StrideInBytes = + ShaderTable.GetShaderRecordSizeInBytes(); + DispatchDesc.Width = WindowWidth; + DispatchDesc.Height = WindowHeight; + DispatchDesc.Depth = 1; + CommandList->SetPipelineState1(StateObject); + CommandList->DispatchRays(&DispatchDesc); + + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); + + VERIFY_SUCCEEDED(CommandAllocator->Reset()); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, nullptr)); // Copy the testBuffer contents to CPU - D3D12_RESOURCE_BARRIER barriers[1]; - barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition( - pTestBuffer, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_BARRIER Barriers[1]; + Barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition( + TestBuffer, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE); - pCommandList->ResourceBarrier(1, barriers); - pCommandList->CopyResource(pTestBufferRead, pTestBuffer); - barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition( - pTestBuffer, D3D12_RESOURCE_STATE_COPY_SOURCE, + CommandList->ResourceBarrier(1, Barriers); + CommandList->CopyResource(TestBufferRead, TestBuffer); + Barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition( + TestBuffer, D3D12_RESOURCE_STATE_COPY_SOURCE, D3D12_RESOURCE_STATE_UNORDERED_ACCESS); - pCommandList->ResourceBarrier(1, barriers); + CommandList->ResourceBarrier(1, Barriers); - pCommandList->Close(); - ExecuteCommandList(pCommandQueue, pCommandList); - WaitForSignal(pCommandQueue, FO); + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, 
FO); // Copy the shader test data into 'testData'. - MappedData data(pTestBufferRead, (UINT32)testData.size() * sizeof(int)); - const int *pData = (int *)data.data(); + MappedData Data(TestBufferRead, (UINT32)TestData.size() * sizeof(int)); + const int *DataPtr = (int *)Data.data(); - for (int i = 0; i < testData.size(); i++) { - testData[i] = *pData++; - } + for (int i = 0; i < TestData.size(); i++) + TestData[i] = *DataPtr++; // Cleanup resources - pTestBuffer.Release(); - pTestBufferRead.Release(); - pSceneConstantBuffer.Release(); - pDescriptorHeap.Release(); - pCommandQueue.Release(); - pCommandAllocator.Release(); - pCommandList.Release(); - pStateObject.Release(); - pStateObjectProperties.Release(); - tlasResource.Release(); - blasMeshResource.Release(); - blasProceduralGeometryResource.Release(); - instanceDescs.Release(); - scratchResource.Release(); - - return pTestBufferRead; + TestBuffer.Release(); + TestBufferRead.Release(); + SceneConstantBuffer.Release(); + DescriptorHeap.Release(); + CommandQueue.Release(); + CommandAllocator.Release(); + CommandList.Release(); + StateObject.Release(); + StateObjectProperties.Release(); + TLASResource.Release(); + BLASMeshResource.Release(); + BLASProceduralGeometryResource.Release(); + InstanceDescs.Release(); + ScratchResource.Release(); + + return TestBufferRead; } // SER TESTS diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index 1c24795a0c..1a5c2eb7ab 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -10,9 +10,11 @@ // // /////////////////////////////////////////////////////////////////////////////// +#pragma once + TEST_F(ExecutionTest, SERScalarGetterTest) { // SER: Test basic function of HitObject getters. 
- static const char *pShader = R"( + static const char *ShaderSrc = R"( struct SceneConstants { float4 eye; @@ -20,7 +22,7 @@ struct SceneConstants float4 V; float4 W; float sceneScale; - uint2 windowSize; + uint2 WindowSize; int rayFlags; }; @@ -96,54 +98,33 @@ void closesthit(inout PerRayData payload, in Attrs attrs) } )"; - CComPtr pDevice; - bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " - L"supported device was found."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) return; - } - bool bDXRSupported = - bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); - - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support SM 6.9."); - } - if (!bDXRSupported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support DXR."); - } // Initialize test data. 
- const int windowSize = 64; - - if (!bDXRSupported) - return; - - WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); + const int WindowSize = 64; // RayTMin { - std::vector testData(windowSize * windowSize * 2, 0); - LPCWSTR args[] = {L"-HV 2021", + WEX::Logging::Log::Comment(L"Testing HitObject::GetRayTMin()"); + std::vector TestData(WindowSize * WindowSize * 2, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DVALTYPE=float", L"-DHIT_GET_SCALAR=RayTMin", L"-DMISS_GET_SCALAR=RayTMin", L"-DSER_GET_SCALAR=GetRayTMin"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/); - for (int id = 0; id < testData.size(); id += 2) { - float *resArray = (float *)(testData.data() + id); - float refVal = resArray[0]; - float serVal = resArray[1]; - const bool passRayTMin = CompareFloatEpsilon(serVal, refVal, 0.0008f); - if (!passRayTMin) { - VERIFY_IS_TRUE(passRayTMin); - WEX::Logging::Log::Comment(L"HitObject::GetRayTMin() FAILED"); + for (int Id = 0; Id < TestData.size(); Id += 2) { + float *ResArray = (float *)(TestData.data() + Id); + float RefVal = ResArray[0]; + float SerVal = ResArray[1]; + const bool PassRayTMin = CompareFloatEpsilon(SerVal, RefVal, 0.0008f); + if (!PassRayTMin) { + VERIFY_IS_TRUE(PassRayTMin); return; } } @@ -152,24 +133,24 @@ void closesthit(inout PerRayData payload, in Attrs attrs) // RayTCurrent { - std::vector testData(windowSize * windowSize * 2, 0); - LPCWSTR args[] = {L"-HV 2021", + WEX::Logging::Log::Comment(L"Testing HitObject::GetRayTCurrent()"); + std::vector TestData(WindowSize * WindowSize * 2, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DVALTYPE=float", L"-DHIT_GET_SCALAR=RayTCurrent", L"-DMISS_GET_SCALAR=RayTCurrent", L"-DSER_GET_SCALAR=GetRayTCurrent"}; - RunDXRTest(pDevice, pShader, 
D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/); - for (int id = 0; id < testData.size(); id += 2) { - float *resArray = (float *)(testData.data() + id); - float refVal = resArray[0]; - float serVal = resArray[1]; - const bool passRayTCurrent = CompareFloatEpsilon(serVal, refVal, 0.0008f); - if (!passRayTCurrent) { - VERIFY_IS_TRUE(passRayTCurrent); - WEX::Logging::Log::Comment(L"HitObject::GetRayTCurrent() FAILED"); + for (int Id = 0; Id < TestData.size(); Id += 2) { + float *ResArray = (float *)(TestData.data() + Id); + float RefVal = ResArray[0]; + float SerVal = ResArray[1]; + const bool PassRayTCurrent = CompareFloatEpsilon(SerVal, RefVal, 0.0008f); + if (!PassRayTCurrent) { + VERIFY_IS_TRUE(PassRayTCurrent); return; } } @@ -178,22 +159,22 @@ void closesthit(inout PerRayData payload, in Attrs attrs) // RayFlags { - std::vector testData(windowSize * windowSize * 2, 0); - LPCWSTR args[] = {L"-HV 2021", + WEX::Logging::Log::Comment(L"Testing HitObject::GetRayFlags()"); + std::vector TestData(WindowSize * WindowSize * 2, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DVALTYPE=uint", L"-DHIT_GET_SCALAR=RayFlags", L"-DMISS_GET_SCALAR=RayFlags", L"-DSER_GET_SCALAR=GetRayFlags"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/); - for (int id = 0; id < testData.size(); id += 2) { - const int refVal = testData[id]; - const int serVal = testData[id + 1]; - if (refVal != serVal) { - VERIFY_ARE_EQUAL(refVal, serVal); - WEX::Logging::Log::Comment(L"HitObject::GetRayFlags() FAILED"); + for (int Id = 0; Id < 
TestData.size(); Id += 2) { + const int RefVal = TestData[Id]; + const int SerVal = TestData[Id + 1]; + if (RefVal != SerVal) { + VERIFY_ARE_EQUAL(RefVal, SerVal); return; } } @@ -202,22 +183,22 @@ void closesthit(inout PerRayData payload, in Attrs attrs) // HitKind { - std::vector testData(windowSize * windowSize * 2, 0); - LPCWSTR args[] = {L"-HV 2021", + WEX::Logging::Log::Comment(L"Testing HitObject::GetHitKind()"); + std::vector TestData(WindowSize * WindowSize * 2, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DVALTYPE=uint", L"-DHIT_GET_SCALAR=HitKind", L"-DMISS_GET_SCALAR=getIntZero", L"-DSER_GET_SCALAR=GetHitKind"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/); - for (int id = 0; id < testData.size(); id += 2) { - const int refVal = testData[id]; - const int serVal = testData[id + 1]; - if (refVal != serVal) { - VERIFY_ARE_EQUAL(refVal, serVal); - WEX::Logging::Log::Comment(L"HitObject::GetHitKind() FAILED"); + for (int Id = 0; Id < TestData.size(); Id += 2) { + const int RefVal = TestData[Id]; + const int SerVal = TestData[Id + 1]; + if (RefVal != SerVal) { + VERIFY_ARE_EQUAL(RefVal, SerVal); return; } } @@ -226,22 +207,22 @@ void closesthit(inout PerRayData payload, in Attrs attrs) // GeometryIndex { - std::vector testData(windowSize * windowSize * 2, 0); - LPCWSTR args[] = {L"-HV 2021", + WEX::Logging::Log::Comment(L"Testing HitObject::GetGeometryIndex()"); + std::vector TestData(WindowSize * WindowSize * 2, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DVALTYPE=uint", L"-DHIT_GET_SCALAR=GeometryIndex", L"-DMISS_GET_SCALAR=getIntZero", L"-DSER_GET_SCALAR=GetGeometryIndex"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, 
+ RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/); - for (int id = 0; id < testData.size(); id += 2) { - const int refVal = testData[id]; - const int serVal = testData[id + 1]; - if (refVal != serVal) { - VERIFY_ARE_EQUAL(refVal, serVal); - WEX::Logging::Log::Comment(L"HitObject::GetGeometryIndex() FAILED"); + for (int Id = 0; Id < TestData.size(); Id += 2) { + const int RefVal = TestData[Id]; + const int SerVal = TestData[Id + 1]; + if (RefVal != SerVal) { + VERIFY_ARE_EQUAL(RefVal, SerVal); return; } } @@ -250,22 +231,22 @@ void closesthit(inout PerRayData payload, in Attrs attrs) // InstanceIndex { - std::vector testData(windowSize * windowSize * 2, 0); - LPCWSTR args[] = {L"-HV 2021", + WEX::Logging::Log::Comment(L"Testing HitObject::GetInstanceIndex()"); + std::vector TestData(WindowSize * WindowSize * 2, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DVALTYPE=uint", L"-DHIT_GET_SCALAR=InstanceIndex", L"-DMISS_GET_SCALAR=getIntZero", L"-DSER_GET_SCALAR=GetInstanceIndex"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/); - for (int id = 0; id < testData.size(); id += 2) { - const int refVal = testData[id]; - const int serVal = testData[id + 1]; - if (refVal != serVal) { - VERIFY_ARE_EQUAL(refVal, serVal); - WEX::Logging::Log::Comment(L"HitObject::GetInstanceIndex() FAILED"); + for (int Id = 0; Id < TestData.size(); Id += 2) { + const int RefVal = TestData[Id]; + const int SerVal = TestData[Id + 1]; + if (RefVal != SerVal) { + VERIFY_ARE_EQUAL(RefVal, SerVal); return; } } @@ -274,22 +255,22 @@ void closesthit(inout PerRayData payload, in Attrs attrs) // InstanceID { - std::vector 
testData(windowSize * windowSize * 2, 0); - LPCWSTR args[] = {L"-HV 2021", + WEX::Logging::Log::Comment(L"Testing HitObject::GetInstanceID()"); + std::vector TestData(WindowSize * WindowSize * 2, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DVALTYPE=uint", L"-DHIT_GET_SCALAR=InstanceID", L"-DMISS_GET_SCALAR=getIntZero", L"-DSER_GET_SCALAR=GetInstanceID"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/); - for (int id = 0; id < testData.size(); id += 2) { - const int refVal = testData[id]; - const int serVal = testData[id + 1]; - if (refVal != serVal) { - VERIFY_ARE_EQUAL(refVal, serVal); - WEX::Logging::Log::Comment(L"HitObject::GetInstanceID() FAILED"); + for (int Id = 0; Id < TestData.size(); Id += 2) { + const int RefVal = TestData[Id]; + const int SerVal = TestData[Id + 1]; + if (RefVal != SerVal) { + VERIFY_ARE_EQUAL(RefVal, SerVal); return; } } @@ -298,22 +279,22 @@ void closesthit(inout PerRayData payload, in Attrs attrs) // PrimitiveIndex { - std::vector testData(windowSize * windowSize * 2, 0); - LPCWSTR args[] = {L"-HV 2021", + WEX::Logging::Log::Comment(L"Testing HitObject::GetPrimitiveIndex()"); + std::vector TestData(WindowSize * WindowSize * 2, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DVALTYPE=uint", L"-DHIT_GET_SCALAR=PrimitiveIndex", L"-DMISS_GET_SCALAR=getIntZero", L"-DSER_GET_SCALAR=GetPrimitiveIndex"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/); - for (int id = 0; id < testData.size(); id += 2) { - const int refVal = testData[id]; - const 
int serVal = testData[id + 1]; - if (refVal != serVal) { - VERIFY_ARE_EQUAL(refVal, serVal); - WEX::Logging::Log::Comment(L"HitObject::GetPrimitiveIndex() FAILED"); + for (int Id = 0; Id < TestData.size(); Id += 2) { + const int RefVal = TestData[Id]; + const int SerVal = TestData[Id + 1]; + if (RefVal != SerVal) { + VERIFY_ARE_EQUAL(RefVal, SerVal); return; } } @@ -323,7 +304,7 @@ void closesthit(inout PerRayData payload, in Attrs attrs) TEST_F(ExecutionTest, SERVectorGetterTest) { // SER: Test basic function of HitObject getters. - static const char *pShader = R"( + static const char *ShaderSrc = R"( struct SceneConstants { float4 eye; @@ -331,7 +312,7 @@ struct SceneConstants float4 V; float4 W; float sceneScale; - uint2 windowSize; + uint2 WindowSize; int rayFlags; }; @@ -410,60 +391,39 @@ void closesthit(inout PerRayData payload, in Attrs attrs) } )"; - CComPtr pDevice; - bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " - L"supported device was found."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) return; - } - bool bDXRSupported = - bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); - - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support SM 6.9."); - } - if (!bDXRSupported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support DXR."); - } // Initialize test data. 
- const int windowSize = 64; - - if (!bDXRSupported) - return; - - WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); + const int WindowSize = 64; // WorldRayOrigin { - std::vector testData(windowSize * windowSize * 6, 0); - LPCWSTR args[] = {L"-HV 2021", L"-Vd", L"-DHIT_GET_VECTOR=WorldRayOrigin", + WEX::Logging::Log::Comment(L"Testing HitObject::GetWorldRayOrigin()"); + std::vector TestData(WindowSize * WindowSize * 6, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DHIT_GET_VECTOR=WorldRayOrigin", L"-DMISS_GET_VECTOR=WorldRayOrigin", L"-DSER_GET_VECTOR=GetWorldRayOrigin"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/, 3 /*payloadCount*/); - for (int id = 0; id < testData.size(); id += 6) { - float *resArray = (float *)(testData.data() + id); - float refX = resArray[0]; - float serX = resArray[1]; - float refY = resArray[2]; - float serY = resArray[3]; - float refZ = resArray[4]; - float serZ = resArray[5]; - const bool passX = CompareFloatEpsilon(serX, refX, 0.0008f); - const bool passY = CompareFloatEpsilon(serY, refY, 0.0008f); - const bool passZ = CompareFloatEpsilon(serZ, refZ, 0.0008f); - if (!passX || !passY || !passZ) { - VERIFY_ARE_EQUAL(serX, refX); - VERIFY_ARE_EQUAL(serY, refY); - VERIFY_ARE_EQUAL(serZ, refZ); - WEX::Logging::Log::Comment(L"HitObject::GetWorldRayOrigin() FAILED"); + for (int Id = 0; Id < TestData.size(); Id += 6) { + float *ResArray = (float *)(TestData.data() + Id); + float RefX = ResArray[0]; + float SerX = ResArray[1]; + float RefY = ResArray[2]; + float SerY = ResArray[3]; + float RefZ = ResArray[4]; + float SerZ = ResArray[5]; + const bool PassX = CompareFloatEpsilon(SerX, RefX, 0.0008f); + const bool PassY = CompareFloatEpsilon(SerY, RefY, 0.0008f); + const bool PassZ = 
CompareFloatEpsilon(SerZ, RefZ, 0.0008f); + if (!PassX || !PassY || !PassZ) { + VERIFY_ARE_EQUAL(SerX, RefX); + VERIFY_ARE_EQUAL(SerY, RefY); + VERIFY_ARE_EQUAL(SerZ, RefZ); break; } } @@ -472,32 +432,32 @@ void closesthit(inout PerRayData payload, in Attrs attrs) // WorldRayDirection { - std::vector testData(windowSize * windowSize * 6, 0); - LPCWSTR args[] = {L"-HV 2021", L"-Vd", + WEX::Logging::Log::Comment(L"Testing HitObject::GetWorldRayDirection()"); + std::vector TestData(WindowSize * WindowSize * 6, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DHIT_GET_VECTOR=WorldRayDirection", L"-DMISS_GET_VECTOR=WorldRayDirection", L"-DSER_GET_VECTOR=GetWorldRayDirection"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/, 3 /*payloadCount*/); - for (int id = 0; id < testData.size(); id += 6) { - float *resArray = (float *)(testData.data() + id); - float refX = resArray[0]; - float serX = resArray[1]; - float refY = resArray[2]; - float serY = resArray[3]; - float refZ = resArray[4]; - float serZ = resArray[5]; - const bool passX = CompareFloatEpsilon(serX, refX, 0.0008f); - const bool passY = CompareFloatEpsilon(serY, refY, 0.0008f); - const bool passZ = CompareFloatEpsilon(serZ, refZ, 0.0008f); - if (!passX || !passY || !passZ) { - VERIFY_ARE_EQUAL(serX, refX); - VERIFY_ARE_EQUAL(serY, refY); - VERIFY_ARE_EQUAL(serZ, refZ); - WEX::Logging::Log::Comment(L"HitObject::GetWorldRayDirection() FAILED"); - return; + for (int Id = 0; Id < TestData.size(); Id += 6) { + float *ResArray = (float *)(TestData.data() + Id); + float RefX = ResArray[0]; + float SerX = ResArray[1]; + float RefY = ResArray[2]; + float SerY = ResArray[3]; + float RefZ = ResArray[4]; + float SerZ = ResArray[5]; + const bool PassX = CompareFloatEpsilon(SerX, 
RefX, 0.0008f); + const bool PassY = CompareFloatEpsilon(SerY, RefY, 0.0008f); + const bool PassZ = CompareFloatEpsilon(SerZ, RefZ, 0.0008f); + if (!PassX || !PassY || !PassZ) { + VERIFY_ARE_EQUAL(SerX, RefX); + VERIFY_ARE_EQUAL(SerY, RefY); + VERIFY_ARE_EQUAL(SerZ, RefZ); + break; } } WEX::Logging::Log::Comment(L"HitObject::GetWorldRayDirection() PASSED"); @@ -505,30 +465,30 @@ void closesthit(inout PerRayData payload, in Attrs attrs) // ObjectRayOrigin { - std::vector testData(windowSize * windowSize * 6, 0); - LPCWSTR args[] = {L"-HV 2021", L"-Vd", L"-DHIT_GET_VECTOR=ObjectRayOrigin", + WEX::Logging::Log::Comment(L"Testing HitObject::GetObjectRayOrigin()"); + std::vector TestData(WindowSize * WindowSize * 6, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DHIT_GET_VECTOR=ObjectRayOrigin", L"-DMISS_GET_VECTOR=WorldRayOrigin", L"-DSER_GET_VECTOR=GetObjectRayOrigin"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/, 3 /*payloadCount*/); - for (int id = 0; id < testData.size(); id += 6) { - float *resArray = (float *)(testData.data() + id); - float refX = resArray[0]; - float serX = resArray[1]; - float refY = resArray[2]; - float serY = resArray[3]; - float refZ = resArray[4]; - float serZ = resArray[5]; - const bool passX = CompareFloatEpsilon(serX, refX, 0.0008f); - const bool passY = CompareFloatEpsilon(serY, refY, 0.0008f); - const bool passZ = CompareFloatEpsilon(serZ, refZ, 0.0008f); - if (!passX || !passY || !passZ) { - VERIFY_ARE_EQUAL(serX, refX); - VERIFY_ARE_EQUAL(serY, refY); - VERIFY_ARE_EQUAL(serZ, refZ); - WEX::Logging::Log::Comment(L"HitObject::GetObjectRayOrigin() FAILED"); + for (int Id = 0; Id < TestData.size(); Id += 6) { + float *ResArray = (float *)(TestData.data() + Id); + float RefX = 
ResArray[0]; + float SerX = ResArray[1]; + float RefY = ResArray[2]; + float SerY = ResArray[3]; + float RefZ = ResArray[4]; + float SerZ = ResArray[5]; + const bool PassX = CompareFloatEpsilon(SerX, RefX, 0.0008f); + const bool PassY = CompareFloatEpsilon(SerY, RefY, 0.0008f); + const bool PassZ = CompareFloatEpsilon(SerZ, RefZ, 0.0008f); + if (!PassX || !PassY || !PassZ) { + VERIFY_ARE_EQUAL(SerX, RefX); + VERIFY_ARE_EQUAL(SerY, RefY); + VERIFY_ARE_EQUAL(SerZ, RefZ); break; } } @@ -537,32 +497,31 @@ void closesthit(inout PerRayData payload, in Attrs attrs) // ObjectRayDirection { - std::vector testData(windowSize * windowSize * 6, 0); - LPCWSTR args[] = {L"-HV 2021", L"-Vd", + WEX::Logging::Log::Comment(L"Testing HitObject::GetObjectRayDirection()"); + std::vector TestData(WindowSize * WindowSize * 6, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DHIT_GET_VECTOR=ObjectRayDirection", L"-DMISS_GET_VECTOR=WorldRayDirection", L"-DSER_GET_VECTOR=GetObjectRayDirection"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/, 3 /*payloadCount*/); - for (int id = 0; id < testData.size(); id += 6) { - float *resArray = (float *)(testData.data() + id); - float refX = resArray[0]; - float serX = resArray[1]; - float refY = resArray[2]; - float serY = resArray[3]; - float refZ = resArray[4]; - float serZ = resArray[5]; - const bool passX = CompareFloatEpsilon(serX, refX, 0.0008f); - const bool passY = CompareFloatEpsilon(serY, refY, 0.0008f); - const bool passZ = CompareFloatEpsilon(serZ, refZ, 0.0008f); - if (!passX || !passY || !passZ) { - VERIFY_ARE_EQUAL(serX, refX); - VERIFY_ARE_EQUAL(serY, refY); - VERIFY_ARE_EQUAL(serZ, refZ); - WEX::Logging::Log::Comment( - L"HitObject::GetObjectRayDirection() FAILED"); + for (int Id = 0; Id 
< TestData.size(); Id += 6) { + float *ResArray = (float *)(TestData.data() + Id); + float RefX = ResArray[0]; + float SerX = ResArray[1]; + float RefY = ResArray[2]; + float SerY = ResArray[3]; + float RefZ = ResArray[4]; + float SerZ = ResArray[5]; + const bool PassX = CompareFloatEpsilon(SerX, RefX, 0.0008f); + const bool PassY = CompareFloatEpsilon(SerY, RefY, 0.0008f); + const bool PassZ = CompareFloatEpsilon(SerZ, RefZ, 0.0008f); + if (!PassX || !PassY || !PassZ) { + VERIFY_ARE_EQUAL(SerX, RefX); + VERIFY_ARE_EQUAL(SerY, RefY); + VERIFY_ARE_EQUAL(SerZ, RefZ); break; } } @@ -572,7 +531,7 @@ void closesthit(inout PerRayData payload, in Attrs attrs) TEST_F(ExecutionTest, SERMatrixGetterTest) { // SER: Test basic function of HitObject getters. - static const char *pShader = R"( + static const char *ShaderSrc = R"( struct SceneConstants { float4 eye; @@ -580,7 +539,7 @@ struct SceneConstants float4 V; float4 W; float sceneScale; - uint2 windowSize; + uint2 WindowSize; int rayFlags; }; @@ -679,59 +638,38 @@ void closesthit(inout PerRayData payload, in Attrs attrs) } )"; - CComPtr pDevice; - bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " - L"supported device was found."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) return; - } - bool bDXRSupported = - bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); - - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support SM 6.9."); - } - if (!bDXRSupported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support DXR."); - } - // Initialize test data. 
- const int windowSize = 64; - - if (!bDXRSupported) - return; - - WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); + const int WindowSize = 64; // WorldToObject3x4 { - std::vector testData(windowSize * windowSize * 24, 0); - LPCWSTR args[] = {L"-HV 2021", + WEX::Logging::Log::Comment(L"Testing HitObject::GetWorldToObject3x4()"); + std::vector TestData(WindowSize * WindowSize * 24, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DHIT_GET_MATRIX=WorldToObject3x4", L"-DMISS_GET_MATRIX=getMatIdentity", L"-DSER_GET_MATRIX=GetWorldToObject3x4", L"-DROWS=3", L"-DCOLS=4"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/, 12 /*payloadCount*/); const int ROWS = 3; const int COLS = 4; - for (int id = 0; id < testData.size(); id += 24) { - float *resArray = (float *)(testData.data() + id); - for (int r = 0; r < ROWS; r++) { - for (int c = 0; c < COLS; c++) { - int refIdx = 2 * (r * COLS + c); - float ref = resArray[refIdx]; - float ser = resArray[1 + refIdx]; - if (!CompareFloatEpsilon(ser, ref, 0.0008f)) { - VERIFY_ARE_EQUAL(ser, ref); + for (int Id = 0; Id < TestData.size(); Id += 24) { + float *ResArray = (float *)(TestData.data() + Id); + for (int RowIdx = 0; RowIdx < ROWS; RowIdx++) { + for (int ColIdx = 0; ColIdx < COLS; ColIdx++) { + int RefIdx = 2 * (RowIdx * COLS + ColIdx); + float Ref = ResArray[RefIdx]; + float Ser = ResArray[1 + RefIdx]; + if (!CompareFloatEpsilon(Ser, Ref, 0.0008f)) { + VERIFY_ARE_EQUAL(Ser, Ref); } } } @@ -741,29 +679,30 @@ void closesthit(inout PerRayData payload, in Attrs attrs) // WorldToObject4x3 { + WEX::Logging::Log::Comment(L"Testing HitObject::GetWorldToObject4x3()"); const int ROWS = 4; const int COLS = 3; - std::vector testData(windowSize * windowSize * 2 * ROWS * COLS, 0); - 
LPCWSTR args[] = {L"-HV 2021", + std::vector TestData(WindowSize * WindowSize * 2 * ROWS * COLS, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DHIT_GET_MATRIX=WorldToObject4x3", L"-DMISS_GET_MATRIX=getMatIdentity", L"-DSER_GET_MATRIX=GetWorldToObject4x3", L"-DROWS=4", L"-DCOLS=3"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/, 12 /*payloadCount*/); - for (int id = 0; id < testData.size(); id += 2 * ROWS * COLS) { - float *resArray = (float *)(testData.data() + id); - for (int r = 0; r < ROWS; r++) { - for (int c = 0; c < COLS; c++) { - int refIdx = 2 * (r * COLS + c); - float ref = resArray[refIdx]; - float ser = resArray[1 + refIdx]; - if (!CompareFloatEpsilon(ser, ref, 0.0008f)) { - VERIFY_ARE_EQUAL(ser, ref); + for (int Id = 0; Id < TestData.size(); Id += 2 * ROWS * COLS) { + float *ResArray = (float *)(TestData.data() + Id); + for (int RowIdx = 0; RowIdx < ROWS; RowIdx++) { + for (int ColIdx = 0; ColIdx < COLS; ColIdx++) { + int RefIdx = 2 * (RowIdx * COLS + ColIdx); + float Ref = ResArray[RefIdx]; + float Ser = ResArray[1 + RefIdx]; + if (!CompareFloatEpsilon(Ser, Ref, 0.0008f)) { + VERIFY_ARE_EQUAL(Ser, Ref); } } } @@ -773,29 +712,30 @@ void closesthit(inout PerRayData payload, in Attrs attrs) // ObjectToWorld3x4 { - std::vector testData(windowSize * windowSize * 24, 0); - LPCWSTR args[] = {L"-HV 2021", + WEX::Logging::Log::Comment(L"Testing HitObject::GetObjectToWorld3x4()"); + std::vector TestData(WindowSize * WindowSize * 24, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DHIT_GET_MATRIX=ObjectToWorld3x4", L"-DMISS_GET_MATRIX=getMatIdentity", L"-DSER_GET_MATRIX=GetObjectToWorld3x4", L"-DROWS=3", L"-DCOLS=4"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, 
windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/, 12 /*payloadCount*/); const int ROWS = 3; const int COLS = 4; - for (int id = 0; id < testData.size(); id += 24) { - float *resArray = (float *)(testData.data() + id); - for (int r = 0; r < ROWS; r++) { - for (int c = 0; c < COLS; c++) { - int refIdx = 2 * (r * COLS + c); - float ref = resArray[refIdx]; - float ser = resArray[1 + refIdx]; - if (!CompareFloatEpsilon(ser, ref, 0.0008f)) { - VERIFY_ARE_EQUAL(ser, ref); + for (int Id = 0; Id < TestData.size(); Id += 24) { + float *ResArray = (float *)(TestData.data() + Id); + for (int RowIdx = 0; RowIdx < ROWS; RowIdx++) { + for (int ColIdx = 0; ColIdx < COLS; ColIdx++) { + int RefIdx = 2 * (RowIdx * COLS + ColIdx); + float Ref = ResArray[RefIdx]; + float Ser = ResArray[1 + RefIdx]; + if (!CompareFloatEpsilon(Ser, Ref, 0.0008f)) { + VERIFY_ARE_EQUAL(Ser, Ref); } } } @@ -805,31 +745,30 @@ void closesthit(inout PerRayData payload, in Attrs attrs) // ObjectToWorld4x3 { - std::vector testData(windowSize * windowSize * 24, 0); - LPCWSTR args[] = {L"-HV 2021", + WEX::Logging::Log::Comment(L"Testing HitObject::GetObjectToWorld4x3()"); + std::vector TestData(WindowSize * WindowSize * 24, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DHIT_GET_MATRIX=ObjectToWorld4x3", L"-DMISS_GET_MATRIX=getMatIdentity", L"-DSER_GET_MATRIX=GetObjectToWorld4x3", L"-DROWS=4", L"-DCOLS=3"}; - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/, 12 /*payloadCount*/); const int ROWS = 4; const int COLS = 3; - for (int id = 0; id < testData.size(); id += 24) { - float *resArray = (float 
*)(testData.data() + id); - for (int r = 0; r < ROWS; r++) { - for (int c = 0; c < COLS; c++) { - int refIdx = 2 * (r * COLS + c); - float ref = resArray[refIdx]; - float ser = resArray[1 + refIdx]; - if (!CompareFloatEpsilon(ser, ref, 0.0008f)) { - VERIFY_ARE_EQUAL(ser, ref); - WEX::Logging::Log::Comment( - L"HitObject::GetObjectToWorld4x3() FAILED"); + for (int Id = 0; Id < TestData.size(); Id += 24) { + float *ResArray = (float *)(TestData.data() + Id); + for (int RowIdx = 0; RowIdx < ROWS; RowIdx++) { + for (int ColIdx = 0; ColIdx < COLS; ColIdx++) { + int RefIdx = 2 * (RowIdx * COLS + ColIdx); + float Ref = ResArray[RefIdx]; + float Ser = ResArray[1 + RefIdx]; + if (!CompareFloatEpsilon(Ser, Ref, 0.0008f)) { + VERIFY_ARE_EQUAL(Ser, Ref); break; } } @@ -841,7 +780,7 @@ void closesthit(inout PerRayData payload, in Attrs attrs) TEST_F(ExecutionTest, SERBasicTest) { // SER: Test basic functionality. - static const char *pShader = R"( + static const char *ShaderSrc = R"( struct SceneConstants { float4 eye; @@ -849,7 +788,7 @@ struct SceneConstants float4 V; float4 W; float sceneScale; - uint2 windowSize; + uint2 WindowSize; int rayFlags; }; @@ -922,50 +861,29 @@ void closesthit(inout PerRayData payload, in Attrs attrs) )"; - CComPtr pDevice; - bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " - L"supported device was found."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) return; - } - bool bDXRSupported = - bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support SM 6.9."); - } - if (!bDXRSupported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support DXR."); - } + const int WindowSize = 64; + std::vector 
TestData(WindowSize * WindowSize, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; - // Initialize test data. - const int windowSize = 64; - std::vector testData(windowSize * windowSize, 0); - LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; - - if (bDXRSupported) { - WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); - std::map histo; - for (int val : testData) { - ++histo[val]; - } - VERIFY_ARE_EQUAL(histo.size(), 2); - VERIFY_ARE_EQUAL(histo[2], 4030); - VERIFY_ARE_EQUAL(histo[5], 66); - } + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, + false /*useProceduralGeometry*/, false /*useIS*/); + std::map Histo; + for (int Val : TestData) + ++Histo[Val]; + VERIFY_ARE_EQUAL(Histo.size(), 2); + VERIFY_ARE_EQUAL(Histo[2], 4030); + VERIFY_ARE_EQUAL(Histo[5], 66); } TEST_F(ExecutionTest, SERShaderTableIndexTest) { // Test SER with HitObject::SetShaderTableIndex and // HitObject::GetShaderTableIndex - static const char *pShader = R"( + static const char *ShaderSrc = R"( struct SceneConstants { float4 eye; @@ -973,7 +891,7 @@ struct SceneConstants float4 V; float4 W; float sceneScale; - uint2 windowSize; + uint2 WindowSize; int rayFlags; }; @@ -1063,49 +981,30 @@ void chAABB(inout PerRayData payload, in Attrs attrs) )"; - CComPtr pDevice; - bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " - L"supported device was found."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) return; - } - bool bDXRSupported = - bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); - - if (!bSM_6_9_Supported) { - 
WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support SM 6.9."); - } - if (!bDXRSupported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support DXR."); - } // Initialize test data. - const int windowSize = 64; - std::vector testData(windowSize * windowSize, 0); - LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; - - if (bDXRSupported) { - WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*mesh*/, - true /*procedural geometry*/, false /*useIS*/); - std::map histo; - for (int val : testData) { - ++histo[val]; - } - VERIFY_ARE_EQUAL(histo.size(), 2); - VERIFY_ARE_EQUAL(histo[2], 4030); - VERIFY_ARE_EQUAL(histo[13], 66); - } + const int WindowSize = 64; + std::vector TestData(WindowSize * WindowSize, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; + + WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*mesh*/, + true /*procedural geometry*/, false /*useIS*/); + std::map Histo; + for (int Val : TestData) + ++Histo[Val]; + VERIFY_ARE_EQUAL(Histo.size(), 2); + VERIFY_ARE_EQUAL(Histo[2], 4030); + VERIFY_ARE_EQUAL(Histo[13], 66); } TEST_F(ExecutionTest, SERLoadLocalRootTableConstantTest) { // Test SER with HitObject::LoadLocalRootTableConstant - static const char *pShader = R"( + static const char *ShaderSrc = R"( struct SceneConstants { float4 eye; @@ -1113,7 +1012,7 @@ struct SceneConstants float4 V; float4 W; float sceneScale; - uint2 windowSize; + uint2 WindowSize; int rayFlags; }; @@ -1204,48 +1103,28 @@ void closesthit(inout PerRayData payload, in Attrs attrs) )"; - CComPtr pDevice; - bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " - L"supported device was found."); - 
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) return; - } - bool bDXRSupported = - bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); - - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support SM 6.9."); - } - if (!bDXRSupported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support DXR."); - } // Initialize test data. - const int windowSize = 64; - std::vector testData(windowSize * windowSize, 0); - LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; + const int WindowSize = 64; + std::vector TestData(WindowSize * WindowSize, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; - if (!bDXRSupported) - return; - WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, false /*useProceduralGeometry*/, false /*useIS*/); - std::map histo; - for (int val : testData) { - ++histo[val]; - } - VERIFY_ARE_EQUAL(histo.size(), 1); - VERIFY_ARE_EQUAL(histo[126], 4096); + std::map Histo; + for (int Val : TestData) + ++Histo[Val]; + VERIFY_ARE_EQUAL(Histo.size(), 1); + VERIFY_ARE_EQUAL(Histo[126], 4096); } TEST_F(ExecutionTest, SERRayQueryTest) { // Test SER RayQuery - static const char *pShader = R"( + static const char *ShaderSrc = R"( struct SceneConstants { float4 eye; @@ -1253,7 +1132,7 @@ struct SceneConstants float4 V; float4 W; float sceneScale; - uint2 windowSize; + uint2 WindowSize; int rayFlags; }; @@ -1381,49 +1260,29 @@ void closesthit(inout PerRayData payload, in Attrs attrs) )"; - CComPtr pDevice; - bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment(L"SERRayQueryTest requires 
shader model 6.9+ " - L"but no supported device was found."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) return; - } - bool bDXRSupported = - bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); - - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment( - L"SERRayQueryTest skipped, device does not support SM 6.9."); - } - if (!bDXRSupported) { - WEX::Logging::Log::Comment( - L"SERRayQueryTest skipped, device does not support DXR."); - } // Initialize test data. - const int windowSize = 64; - std::vector testData(windowSize * windowSize, 0); - LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; - - if (bDXRSupported) { - WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); - std::map histo; - for (int val : testData) { - ++histo[val]; - } - VERIFY_ARE_EQUAL(histo.size(), 2); - VERIFY_ARE_EQUAL(histo[0], 66); - VERIFY_ARE_EQUAL(histo[2], 4030); - } + const int WindowSize = 64; + std::vector TestData(WindowSize * WindowSize, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; + + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, + false /*useProceduralGeometry*/, false /*useIS*/); + std::map Histo; + for (int Val : TestData) + ++Histo[Val]; + VERIFY_ARE_EQUAL(Histo.size(), 2); + VERIFY_ARE_EQUAL(Histo[0], 66); + VERIFY_ARE_EQUAL(Histo[2], 4030); } TEST_F(ExecutionTest, SERIntersectionTest) { // Test SER with Intersection and procedural geometry - static const char *pShader = R"( + static const char *ShaderSrc = R"( struct SceneConstants { float4 eye; @@ -1431,7 +1290,7 @@ struct SceneConstants float4 V; float4 W; float sceneScale; - uint2 windowSize; + uint2 WindowSize; int rayFlags; }; @@ -1517,49 +1376,29 @@ void intersection() 
)"; - CComPtr pDevice; - bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " - L"supported device was found."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) return; - } - bool bDXRSupported = - bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); - - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support SM 6.9."); - } - if (!bDXRSupported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support DXR."); - } // Initialize test data. - const int windowSize = 64; - std::vector testData(windowSize * windowSize, 0); - LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; - - if (bDXRSupported) { - WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, false /*mesh*/, - true /*procedural geometry*/, true /*useIS*/); - std::map histo; - for (int val : testData) { - ++histo[val]; - } - VERIFY_ARE_EQUAL(histo.size(), 2); - VERIFY_ARE_EQUAL(histo[2], 3400); - VERIFY_ARE_EQUAL(histo[5], 696); - } + const int WindowSize = 64; + std::vector TestData(WindowSize * WindowSize, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; + + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, false /*mesh*/, + true /*procedural geometry*/, true /*useIS*/); + std::map Histo; + for (int Val : TestData) + ++Histo[Val]; + VERIFY_ARE_EQUAL(Histo.size(), 2); + VERIFY_ARE_EQUAL(Histo[2], 3400); + VERIFY_ARE_EQUAL(Histo[5], 696); } TEST_F(ExecutionTest, SERGetAttributesTest) { // Test SER with HitObject::GetAttributes - static const char *pShader = R"( + static const char *ShaderSrc = R"( struct SceneConstants { float4 eye; @@ -1567,7 +1406,7 @@ struct 
SceneConstants float4 V; float4 W; float sceneScale; - uint2 windowSize; + uint2 WindowSize; int rayFlags; }; @@ -1681,52 +1520,32 @@ void intersection() )"; - CComPtr pDevice; - bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " - L"supported device was found."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) return; - } - bool bDXRSupported = - bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); - - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support SM 6.9."); - } - if (!bDXRSupported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support DXR."); - } // Initialize test data. - const int windowSize = 64; - std::vector testData(windowSize * windowSize, 0); - LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; - - if (bDXRSupported) { - WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, false /*mesh*/, - true /*procedural geometry*/, true /*useIS*/); - std::map histo; - for (int val : testData) { - ++histo[val]; - } - VERIFY_ARE_EQUAL(histo.size(), 4); - VERIFY_ARE_EQUAL(histo[0], 328); - VERIFY_ARE_EQUAL(histo[1], 186); - VERIFY_ARE_EQUAL(histo[3], 182); - VERIFY_ARE_EQUAL(histo[255], 3400); - } + const int WindowSize = 64; + std::vector TestData(WindowSize * WindowSize, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; + + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, false /*mesh*/, + true /*procedural geometry*/, true /*useIS*/); + std::map Histo; + for (int Val : TestData) + ++Histo[Val]; + VERIFY_ARE_EQUAL(Histo.size(), 4); + VERIFY_ARE_EQUAL(Histo[0], 328); + VERIFY_ARE_EQUAL(Histo[1], 186); + 
VERIFY_ARE_EQUAL(Histo[3], 182); + VERIFY_ARE_EQUAL(Histo[255], 3400); } TEST_F(ExecutionTest, SERTraceHitMissNopTest) { // Test SER with conditional HitObject::TraceRay, HitObject::IsHit, // HitObject::IsMiss, HitObject::IsNop - static const char *pShader = R"( + static const char *ShaderSrc = R"( struct SceneConstants { float4 eye; @@ -1734,7 +1553,7 @@ struct SceneConstants float4 V; float4 W; float sceneScale; - uint2 windowSize; + uint2 WindowSize; int rayFlags; }; @@ -1819,56 +1638,36 @@ void closesthit(inout PerRayData payload, in Attrs attrs) )"; - CComPtr pDevice; - bool bSM_6_9_Supported = CreateDevice(&pDevice, D3D_SHADER_MODEL_6_9, false); - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " - L"supported device was found."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) return; - } - bool bDXRSupported = - bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); - - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support SM 6.9."); - } - if (!bDXRSupported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support DXR."); - } // Initialize test data. 
- const int windowSize = 64; - std::vector testData(windowSize * windowSize, 0); - LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; - - if (bDXRSupported) { - WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*mesh*/, - false /*procedural geometry*/, false /*useIS*/); - std::map histo; - for (int val : testData) { - ++histo[val]; - } - VERIFY_ARE_EQUAL(histo.size(), 3); - VERIFY_ARE_EQUAL( - histo[1], - 2048); // isNop && !isMiss && !isHit && !anyhit && !closesthit && !miss - VERIFY_ARE_EQUAL( - histo[18], - 2015); // !isNop && isMiss && !isHit && !anyhit && !closesthit && miss - VERIFY_ARE_EQUAL( - histo[44], - 33); // !isNop && !isMiss && isHit && anyhit && closesthit && !miss - } + const int WindowSize = 64; + std::vector TestData(WindowSize * WindowSize, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; + + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*mesh*/, + false /*procedural geometry*/, false /*useIS*/); + std::map Histo; + for (int Val : TestData) + ++Histo[Val]; + VERIFY_ARE_EQUAL(Histo.size(), 3); + VERIFY_ARE_EQUAL( + Histo[1], + 2048); // isNop && !isMiss && !isHit && !anyhit && !closesthit && !miss + VERIFY_ARE_EQUAL( + Histo[18], + 2015); // !isNop && isMiss && !isHit && !anyhit && !closesthit && miss + VERIFY_ARE_EQUAL( + Histo[44], + 33); // !isNop && !isMiss && isHit && anyhit && closesthit && !miss } TEST_F(ExecutionTest, SERIsMissTest) { // Test SER with HitObject::IsMiss - static const char *pShader = R"( + static const char *ShaderSrc = R"( struct SceneConstants { float4 eye; @@ -1876,7 +1675,7 @@ struct SceneConstants float4 V; float4 W; float sceneScale; - uint2 windowSize; + uint2 WindowSize; int rayFlags; }; @@ -1953,42 +1752,151 @@ void closesthit(inout PerRayData payload, in Attrs attrs) )"; - CComPtr pDevice; - bool bSM_6_9_Supported = CreateDevice(&pDevice, 
D3D_SHADER_MODEL_6_9, false); - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment(L"SERTest requires shader model 6.9+ but no " - L"supported device was found."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) return; - } - bool bDXRSupported = - bSM_6_9_Supported && DoesDeviceSupportRayTracing(pDevice); - - if (!bSM_6_9_Supported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support SM 6.9."); - } - if (!bDXRSupported) { - WEX::Logging::Log::Comment( - L"SER tests skipped, device does not support DXR."); - } // Initialize test data. - const int windowSize = 64; - std::vector testData(windowSize * windowSize, 0); - LPCWSTR args[] = {L"-HV 2021", L"-Vd"}; - - if (bDXRSupported) { - WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); - RunDXRTest(pDevice, pShader, D3D_SHADER_MODEL_6_9, args, _countof(args), - testData, windowSize, windowSize, true /*mesh*/, - false /*procedural geometry*/, false /*useIS*/); - std::map histo; - for (int val : testData) { - ++histo[val]; + const int WindowSize = 64; + std::vector TestData(WindowSize * WindowSize, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; + + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*mesh*/, + false /*procedural geometry*/, false /*useIS*/); + std::map Histo; + for (int Val : TestData) + ++Histo[Val]; + VERIFY_ARE_EQUAL(Histo.size(), 2); + VERIFY_ARE_EQUAL(Histo[2], 4030); + VERIFY_ARE_EQUAL(Histo[5], 66); +} + +TEST_F(ExecutionTest, SERInvokeNoSBTTest) { + // Test SER RayQuery with Invoke + static const char *ShaderSrc = R"( +struct SceneConstants +{ + float4 eye; + float4 U; + float4 V; + float4 W; + float sceneScale; + uint2 WindowSize; + int rayFlags; +}; + +struct[raypayload] PerRayData +{ + uint visited : read(anyhit,closesthit,miss,caller) : write(anyhit,miss,closesthit,caller); +}; + +struct Attrs +{ + float2 
barycentrics : BARYCENTRICS; +}; + +RWStructuredBuffer testBuffer : register(u0); +RaytracingAccelerationStructure topObject : register(t0); +ConstantBuffer sceneConstants : register(b0); + +RayDesc ComputeRay() +{ + float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; + RayDesc ray; + ray.Origin = sceneConstants.eye.xyz; + ray.Direction = normalize(d.x*sceneConstants.U.xyz + d.y*sceneConstants.V.xyz + sceneConstants.W.xyz); + ray.TMin = 0; + ray.TMax = 1e18; + + return ray; +} + +[shader("raygeneration")] +void raygen() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + RayDesc ray = ComputeRay(); + + PerRayData payload; + payload.visited = 0; + + // Template parameter set at runtime before compilation + RayQuery rayQ; + + // Funtion parameter set at runtime before compilation + rayQ.TraceRayInline(topObject, RAY_FLAG_NONE, 0xFF, ray); + + // Storage for procedural primitive hit attributes + Attrs attrs; + attrs.barycentrics = float2(1, 1); + + while (rayQ.Proceed()) + { + switch (rayQ.CandidateType()) + { + case CANDIDATE_NON_OPAQUE_TRIANGLE: + { + // The system has already determined that the candidate would be the closest + // hit so far in the ray extents + rayQ.CommitNonOpaqueTriangleHit(); + } + } } - VERIFY_ARE_EQUAL(histo.size(), 2); - VERIFY_ARE_EQUAL(histo[2], 4030); - VERIFY_ARE_EQUAL(histo[5], 66); - } + + dx::HitObject hit = dx::HitObject::FromRayQuery(rayQ); + dx::MaybeReorderThread(hit); + // Set the payload based on the HitObject. + if (hit.IsHit()) + payload.visited |= 8U; + else + payload.visited |= 16U; + // Invoke should not trigger any shader. 
+ dx::HitObject::Invoke(hit, payload); + + int id = launchIndex.x + launchIndex.y * launchDim.x; + testBuffer[id] = payload.visited; +} + +[shader("miss")] +void miss(inout PerRayData payload) +{ + payload.visited |= 2U; +} + +[shader("anyhit")] +void anyhit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 1U; + AcceptHitAndEndSearch(); } + +[shader("closesthit")] +void closesthit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 4U; +} + +)"; + + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, false)) + return; + + // Initialize test data. + const int WindowSize = 64; + std::vector TestData(WindowSize * WindowSize, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; + + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, + false /*useProceduralGeometry*/, false /*useIS*/); + std::map Histo; + for (int Val : TestData) + ++Histo[Val]; + VERIFY_ARE_EQUAL(Histo.size(), 2); + VERIFY_ARE_EQUAL(Histo[8], 66); + VERIFY_ARE_EQUAL(Histo[16], 4030); +} \ No newline at end of file From c2e2dee306ee735994790901a0062893aca1f3e7 Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Fri, 2 May 2025 18:28:44 +0200 Subject: [PATCH 16/31] Test all MaybeReorderThread variants (in wave-incoherent execution) --- .../unittests/HLSLExec/ExecutionTest.cpp | 1 + .../unittests/HLSLExec/ExecutionTest_SER.h | 112 ++++++++++++++++++ 2 files changed, 113 insertions(+) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index dfb2eb1328..061751e72a 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -631,6 +631,7 @@ class ExecutionTest { TEST_METHOD(SERShaderTableIndexTest); TEST_METHOD(SERLoadLocalRootTableConstantTest); TEST_METHOD(SERInvokeNoSBTTest); + TEST_METHOD(SERMaybeReorderThreadTest) dxc::DxcDllSupport m_support; diff --git 
a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index 1a5c2eb7ab..616433fb58 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -1899,4 +1899,116 @@ void closesthit(inout PerRayData payload, in Attrs attrs) VERIFY_ARE_EQUAL(Histo.size(), 2); VERIFY_ARE_EQUAL(Histo[8], 66); VERIFY_ARE_EQUAL(Histo[16], 4030); +} + +TEST_F(ExecutionTest, SERMaybeReorderThreadTest) { + // SER: Test MaybeReorderThread variants. + static const char *ShaderSrc = R"( +struct SceneConstants +{ + float4 eye; + float4 U; + float4 V; + float4 W; + float sceneScale; + uint2 windowSize; + int rayFlags; +}; + +struct[raypayload] PerRayData +{ + uint visited : read(anyhit,closesthit,miss,caller) : write(anyhit,miss,closesthit,caller); +}; + +struct Attrs +{ + float2 barycentrics : BARYCENTRICS; +}; + +RWStructuredBuffer testBuffer : register(u0); +RaytracingAccelerationStructure topObject : register(t0); +ConstantBuffer sceneConstants : register(b0); + +RayDesc ComputeRay() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; + RayDesc ray; + ray.Origin = sceneConstants.eye.xyz; + ray.Direction = normalize(d.x*sceneConstants.U.xyz + d.y*sceneConstants.V.xyz + sceneConstants.W.xyz); + ray.TMin = 0; + ray.TMax = 1e18; + + return ray; +} + +[shader("raygeneration")] +void raygen() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + RayDesc ray = ComputeRay(); + + PerRayData payload; + payload.visited = 0; + + dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); + + if (launchIndex.x % 3 == 0) { + dx::MaybeReorderThread(hitObject); + } + else if (launchIndex.x % 3 == 1) { + dx::MaybeReorderThread(hitObject, 0xFF, 7); + } + else { 
+ dx::MaybeReorderThread(0xFFF, 5); + } + + dx::HitObject::Invoke(hitObject, payload); + + int id = launchIndex.x + launchIndex.y * launchDim.x; + testBuffer[id] = payload.visited; +} + +[shader("miss")] +void miss(inout PerRayData payload) +{ + payload.visited |= 2U; +} + +[shader("anyhit")] +void anyhit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 1U; +} + +[shader("closesthit")] +void closesthit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 4U; +} + +)"; + + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, false)) + return; + + // Initialize test data. + const int WindowSize = 64; + std::vector TestData(WindowSize * WindowSize, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; + + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, + false /*useProceduralGeometry*/, false /*useIS*/); + std::map Histo; + for (int Val : TestData) + ++Histo[Val]; + VERIFY_ARE_EQUAL(Histo.size(), 2); + VERIFY_ARE_EQUAL(Histo[2], 4030); + VERIFY_ARE_EQUAL(Histo[5], 66); } \ No newline at end of file From 6ab79fe5bbbb69a643c69b804574fd6bc587b58a Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Mon, 5 May 2025 13:08:31 +0200 Subject: [PATCH 17/31] DynamicHitObjectArrayTest - Array of 8 HitObjects, each with different ray flags - Samples at 'random' positions to block SROA from breaking down the array --- .../unittests/HLSLExec/ExecutionTest.cpp | 1 + .../unittests/HLSLExec/ExecutionTest_SER.h | 162 ++++++++++++++++++ 2 files changed, 163 insertions(+) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 061751e72a..b626519e8b 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -632,6 +632,7 @@ class ExecutionTest { TEST_METHOD(SERLoadLocalRootTableConstantTest); TEST_METHOD(SERInvokeNoSBTTest); TEST_METHOD(SERMaybeReorderThreadTest) 
+ TEST_METHOD(SERDynamicHitObjectArrayTest); dxc::DxcDllSupport m_support; diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index 616433fb58..7bbdd17073 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -2011,4 +2011,166 @@ void closesthit(inout PerRayData payload, in Attrs attrs) VERIFY_ARE_EQUAL(Histo.size(), 2); VERIFY_ARE_EQUAL(Histo[2], 4030); VERIFY_ARE_EQUAL(Histo[5], 66); +} + +TEST_F(ExecutionTest, SERDynamicHitObjectArrayTest) { + // Test SER with dynamic access to local HitObject array + static const char *ShaderSrc = R"( +struct SceneConstants +{ + float4 eye; + float4 U; + float4 V; + float4 W; + float sceneScale; + uint2 windowSize; + int rayFlags; +}; + +struct Attrs +{ + float2 barycentrics : BARYCENTRICS; +}; + +struct[raypayload] PerRayData +{ + uint dummy : read(caller) : write(miss, closesthit); +}; + +struct LocalConstants +{ + int c0; + int c1; + int c2; + int c3; +}; + +RWStructuredBuffer testBuffer : register(u0); +RaytracingAccelerationStructure topObject : register(t0); +ConstantBuffer sceneConstants : register(b0); +ConstantBuffer localConstants : register(b1); + +RayDesc ComputeRay() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; + RayDesc ray; + ray.Origin = sceneConstants.eye.xyz; + ray.Direction = normalize(d.x * sceneConstants.U.xyz + d.y * sceneConstants.V.xyz + sceneConstants.W.xyz); + ray.TMin = 0; + ray.TMax = 1e18; + + return ray; +} + +[shader("raygeneration")] +void raygen() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + RayDesc ray = ComputeRay(); + + int constants[4] = { localConstants.c0, localConstants.c1, localConstants.c2, localConstants.c3 }; + + const int NUM_SAMPLES = 64; + 
const int NUM_HITOBJECTS = 8; + + // Generate wave-incoerent sample positions + int sampleIndices[NUM_SAMPLES]; + int threadOffset = launchIndex.x; + for (int i = 0; i < NUM_SAMPLES; i++) + { + int baseIndex = i % 4; // Cycle through the 4 constants + sampleIndices[i] = abs(constants[baseIndex] + threadOffset + i * 3) % NUM_HITOBJECTS; + } + + // Define an array of ray flags + uint rayFlagsArray[NUM_HITOBJECTS] = { + RAY_FLAG_NONE, + RAY_FLAG_FORCE_OPAQUE, + RAY_FLAG_FORCE_NON_OPAQUE, + RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH, + RAY_FLAG_SKIP_CLOSEST_HIT_SHADER, + RAY_FLAG_CULL_BACK_FACING_TRIANGLES, + RAY_FLAG_CULL_FRONT_FACING_TRIANGLES, + RAY_FLAG_CULL_OPAQUE + }; + + // Create a local array of HitObjects with TraceRay + dx::HitObject hitObjects[NUM_HITOBJECTS]; + for (uint i = 0; i < NUM_HITOBJECTS; ++i) + { + PerRayData payload; + uint expectedRayFlags = rayFlagsArray[i]; + hitObjects[i] = dx::HitObject::TraceRay( + topObject, // Acceleration structure + expectedRayFlags, // Unique ray flag + 0xFF, // Instance mask + 0, // Ray contribution to hit group index + 1, // Multiplier for geometry contribution + 0, // Miss shader index + ray, // Ray description + payload // Payload + ); + } + + // Evaluate at sample positions. 
+ int testVal = 0; + + for (uint i = 0; i < NUM_SAMPLES; i++) + { + int idx = sampleIndices[i]; + // Verify that the rayFlags match + uint actualRayFlags = hitObjects[idx].GetRayFlags(); + uint expectedRayFlags = rayFlagsArray[idx]; + if (expectedRayFlags != actualRayFlags) + { + testVal = 1; // Mark as failure if flags do not match + } + } + + int id = launchIndex.x + launchIndex.y * launchDim.x; + testBuffer[id] = testVal; +} + +[shader("miss")] +void miss(inout PerRayData payload) +{ + // UNUSED +} + +[shader("anyhit")] +void anyhit(inout PerRayData payload, in Attrs attrs) +{ + AcceptHitAndEndSearch(); +} + +[shader("closesthit")] +void closesthit(inout PerRayData payload, in Attrs attrs) +{ + // UNUSED +} + +)"; + + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, false)) + return; + + // Initialize test data. + const int WindowSize = 64; + + std::vector TestData(WindowSize * WindowSize, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*mesh*/, + false /*procedural geometry*/, false /*useIS*/); + std::map Histo; + for (int Val : TestData) + ++Histo[Val]; + VERIFY_ARE_EQUAL(Histo.size(), 1); + VERIFY_ARE_EQUAL(Histo[0], 4096); } \ No newline at end of file From d31a0ad3474ae459392a6d16f80f1b4afc1c4e9b Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Tue, 6 May 2025 10:48:17 +0200 Subject: [PATCH 18/31] Cleanup / fixes and update expected values based on changed geometry (procedural) - Support for procedural geometry and triangles at the same time - fix: make IS non-optional - Made payload/attributeCount in RunDXRTest non-defaulting - Use a circle for procedural geometry and make sure it fits into the AABB --- .../unittests/HLSLExec/ExecutionTest.cpp | 262 ++++++++---------- .../unittests/HLSLExec/ExecutionTest_SER.h | 165 +++++++---- 2 files changed, 228 insertions(+), 199 deletions(-) diff --git 
a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index b626519e8b..c650f275cc 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -2061,8 +2061,7 @@ class ExecutionTest { int NumOptions, std::vector &TestData, int WindowWidth, int WindowHeight, bool UseMesh, bool UseProceduralGeometry, - bool UseIS, int PayloadCount = 1, - int AttributeCount = 2); + int PayloadCount, int AttributeCount); void SetDescriptorHeap(ID3D12GraphicsCommandList *pCommandList, ID3D12DescriptorHeap *pHeap) { @@ -2246,7 +2245,7 @@ CComPtr ExecutionTest::RunDXRTest( ID3D12Device *Device0, LPCSTR ShaderSrc, LPCWSTR TargetProfile, LPCWSTR *Options, int NumOptions, std::vector &TestData, int WindowWidth, int WindowHeight, bool UseMesh, bool UseProceduralGeometry, - bool UseIS, int PayloadCount, int AttributeCount) { + int PayloadCount, int AttributeCount) { CComPtr Device; VERIFY_SUCCEEDED(Device0->QueryInterface(IID_PPV_ARGS(&Device))); @@ -2432,10 +2431,12 @@ CComPtr ExecutionTest::RunDXRTest( Lib->DefineExport(L"closesthit"); Lib->DefineExport(L"anyhit"); Lib->DefineExport(L"miss"); - if (UseIS) + if (UseProceduralGeometry) Lib->DefineExport(L"intersection"); - if (UseMesh && UseProceduralGeometry) + if (UseMesh && UseProceduralGeometry) { + Lib->DefineExport(L"ahAABB"); Lib->DefineExport(L"chAABB"); + } const int MaxRecursion = 1; StateObjectDesc.CreateSubobject() @@ -2460,30 +2461,32 @@ CComPtr ExecutionTest::RunDXRTest( Exports->AddExport(L"closesthit"); Exports->AddExport(L"anyhit"); Exports->AddExport(L"miss"); - if (UseIS) + if (UseProceduralGeometry) Exports->AddExport(L"intersection"); - if (UseMesh && UseProceduralGeometry) + if (UseMesh && UseProceduralGeometry) { + Exports->AddExport(L"ahAABB"); Exports->AddExport(L"chAABB"); + } auto HitGroup = StateObjectDesc.CreateSubobject(); HitGroup->SetClosestHitShaderImport(L"closesthit"); 
HitGroup->SetAnyHitShaderImport(L"anyhit"); - if (UseIS) { + if (!UseMesh && UseProceduralGeometry) { HitGroup->SetIntersectionShaderImport(L"intersection"); HitGroup->SetHitGroupType(D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE); + } else { + HitGroup->SetHitGroupType(D3D12_HIT_GROUP_TYPE_TRIANGLES); } HitGroup->SetHitGroupExport(L"HitGroup"); if (UseMesh && UseProceduralGeometry) { auto HitGroupAABB = StateObjectDesc.CreateSubobject(); + HitGroupAABB->SetAnyHitShaderImport(L"ahAABB"); HitGroupAABB->SetClosestHitShaderImport(L"chAABB"); - HitGroupAABB->SetAnyHitShaderImport(L"anyhit"); - if (UseIS) { - HitGroupAABB->SetIntersectionShaderImport(L"intersection"); - HitGroupAABB->SetHitGroupType(D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE); - } + HitGroupAABB->SetIntersectionShaderImport(L"intersection"); + HitGroupAABB->SetHitGroupType(D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE); HitGroupAABB->SetHitGroupExport(L"HitGroupAABB"); } @@ -2545,7 +2548,6 @@ CComPtr ExecutionTest::RunDXRTest( CComPtr TLASResource; CComPtr BLASMeshResource; CComPtr BLASProceduralGeometryResource; - CComPtr InstanceDescs; CComPtr ScratchResource; if (UseMesh) { @@ -2624,67 +2626,58 @@ CComPtr ExecutionTest::RunDXRTest( VERIFY_SUCCEEDED(CommandAllocator->Reset()); VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, nullptr)); - if (!UseIS) { - // Build BLAS. 
- { - D3D12_RAYTRACING_GEOMETRY_DESC GeometryDesc = {}; - GeometryDesc.Type = D3D12_RAYTRACING_GEOMETRY_TYPE_TRIANGLES; - GeometryDesc.Triangles.IndexBuffer = - IndexBuffer->GetGPUVirtualAddress(); - GeometryDesc.Triangles.IndexCount = - static_cast(IndexBuffer->GetDesc().Width) / sizeof(int); - GeometryDesc.Triangles.IndexFormat = DXGI_FORMAT_R32_UINT; - GeometryDesc.Triangles.Transform3x4 = 0; - GeometryDesc.Triangles.VertexFormat = DXGI_FORMAT_R32G32B32_FLOAT; - GeometryDesc.Triangles.VertexCount = - static_cast(VertexBuffer->GetDesc().Width) / - sizeof(DirectX::XMFLOAT3); - GeometryDesc.Triangles.VertexBuffer.StartAddress = - VertexBuffer->GetGPUVirtualAddress(); - GeometryDesc.Triangles.VertexBuffer.StrideInBytes = - sizeof(DirectX::XMFLOAT3); - GeometryDesc.Flags = - D3D12_RAYTRACING_GEOMETRY_FLAG_NONE; // Non-opaque to trigger - // anyhit. - - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS BuildFlags = - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_TRACE; - - D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS AccelInputs = {}; - AccelInputs.Type = - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL; - AccelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; - AccelInputs.pGeometryDescs = &GeometryDesc; - AccelInputs.NumDescs = 1; - AccelInputs.Flags = BuildFlags; - - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO PrebuildInfo = {}; - Device->GetRaytracingAccelerationStructurePrebuildInfo(&AccelInputs, - &PrebuildInfo); - - ScratchResource.Release(); - ReallocScratchResource(Device, &ScratchResource, - PrebuildInfo.ScratchDataSizeInBytes); - AllocateBuffer(Device, PrebuildInfo.ResultDataMaxSizeInBytes, - &BLASMeshResource, true, - D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, - L"blasMesh"); - - D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC BuildDesc = {}; - BuildDesc.Inputs = AccelInputs; - BuildDesc.ScratchAccelerationStructureData = - ScratchResource->GetGPUVirtualAddress(); - 
BuildDesc.DestAccelerationStructureData = - BLASMeshResource->GetGPUVirtualAddress(); - - CommandList->BuildRaytracingAccelerationStructure(&BuildDesc, 0, - nullptr); - CD3DX12_RESOURCE_BARRIER Barrier = - CD3DX12_RESOURCE_BARRIER::UAV(BLASMeshResource); - CommandList->ResourceBarrier(1, - (const D3D12_RESOURCE_BARRIER *)&Barrier); - } - } + // Build triangle BLAS. + D3D12_RAYTRACING_GEOMETRY_DESC GeometryDesc = {}; + GeometryDesc.Type = D3D12_RAYTRACING_GEOMETRY_TYPE_TRIANGLES; + GeometryDesc.Triangles.IndexBuffer = IndexBuffer->GetGPUVirtualAddress(); + GeometryDesc.Triangles.IndexCount = + static_cast(IndexBuffer->GetDesc().Width) / sizeof(int); + GeometryDesc.Triangles.IndexFormat = DXGI_FORMAT_R32_UINT; + GeometryDesc.Triangles.Transform3x4 = 0; + GeometryDesc.Triangles.VertexFormat = DXGI_FORMAT_R32G32B32_FLOAT; + GeometryDesc.Triangles.VertexCount = + static_cast(VertexBuffer->GetDesc().Width) / + sizeof(DirectX::XMFLOAT3); + GeometryDesc.Triangles.VertexBuffer.StartAddress = + VertexBuffer->GetGPUVirtualAddress(); + GeometryDesc.Triangles.VertexBuffer.StrideInBytes = + sizeof(DirectX::XMFLOAT3); + GeometryDesc.Flags = D3D12_RAYTRACING_GEOMETRY_FLAG_NONE; // Non-opaque to + // trigger anyhit. 
+ + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS BuildFlags = + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_TRACE; + + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS AccelInputs = {}; + AccelInputs.Type = + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL; + AccelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; + AccelInputs.pGeometryDescs = &GeometryDesc; + AccelInputs.NumDescs = 1; + AccelInputs.Flags = BuildFlags; + + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO PrebuildInfo = {}; + Device->GetRaytracingAccelerationStructurePrebuildInfo(&AccelInputs, + &PrebuildInfo); + + ScratchResource.Release(); + ReallocScratchResource(Device, &ScratchResource, + PrebuildInfo.ScratchDataSizeInBytes); + AllocateBuffer( + Device, PrebuildInfo.ResultDataMaxSizeInBytes, &BLASMeshResource, true, + D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, L"blasMesh"); + + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC BuildDesc = {}; + BuildDesc.Inputs = AccelInputs; + BuildDesc.ScratchAccelerationStructureData = + ScratchResource->GetGPUVirtualAddress(); + BuildDesc.DestAccelerationStructureData = + BLASMeshResource->GetGPUVirtualAddress(); + + CommandList->BuildRaytracingAccelerationStructure(&BuildDesc, 0, nullptr); + CD3DX12_RESOURCE_BARRIER Barrier = + CD3DX12_RESOURCE_BARRIER::UAV(BLASMeshResource); + CommandList->ResourceBarrier(1, (const D3D12_RESOURCE_BARRIER *)&Barrier); CommandList->Close(); ExecuteCommandList(CommandQueue, CommandList); @@ -2701,9 +2694,10 @@ CComPtr ExecutionTest::RunDXRTest( // Define the AABB for the plane, matching the size of the quad defined by // verts[] + const float BoxSize = 500.f; const D3D12_RAYTRACING_AABB Aabb = { - -150.5f, -500.5f, -1000.0f, // Min corner (x, y, z) - 150.5f, -150.5f, 1000.0f // Max corner (x, y, z) + -BoxSize, -BoxSize, -BoxSize, // Min corner (x, y, z) + BoxSize, BoxSize, BoxSize // Max corner (x, y, z) }; const UINT64 AabbDataSize = sizeof(Aabb); @@ -2771,88 
+2765,64 @@ CComPtr ExecutionTest::RunDXRTest( } // Build TLAS. + CComPtr InstanceDescs; { - if (UseMesh) { - D3D12_RAYTRACING_INSTANCE_DESC InstanceDesc = {}; - InstanceDesc.Transform[0][0] = InstanceDesc.Transform[1][1] = - InstanceDesc.Transform[2][2] = 1; - InstanceDesc.InstanceMask = 1; - InstanceDesc.AccelerationStructure = - BLASMeshResource->GetGPUVirtualAddress(); + D3D12_RAYTRACING_INSTANCE_DESC CPUInstanceDescs[2] = {}; + const int MeshIdx = 0; + const int ProcGeoIdx = UseMesh && UseProceduralGeometry ? 1 : 0; + const int NumInstanceDescs = ProcGeoIdx + 1; - AllocateUploadBuffer(Device, &InstanceDesc, sizeof(InstanceDesc), - &InstanceDescs, L"InstanceDescs"); - - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS BuildFlags = - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_BUILD; - - D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS AccelInputs = {}; - AccelInputs.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL; - AccelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; - AccelInputs.NumDescs = 1; - AccelInputs.Flags = BuildFlags; - AccelInputs.InstanceDescs = InstanceDescs->GetGPUVirtualAddress(); - - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO PrebuildInfo = {}; - Device->GetRaytracingAccelerationStructurePrebuildInfo(&AccelInputs, - &PrebuildInfo); - - ScratchResource.Release(); - ReallocScratchResource(Device, &ScratchResource, - PrebuildInfo.ScratchDataSizeInBytes); - AllocateBuffer( - Device, PrebuildInfo.ResultDataMaxSizeInBytes, &TLASResource, true, - D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, L"TLAS"); - - D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC BuildDesc = {}; - BuildDesc.Inputs = AccelInputs; - BuildDesc.ScratchAccelerationStructureData = - ScratchResource->GetGPUVirtualAddress(); - BuildDesc.DestAccelerationStructureData = - TLASResource->GetGPUVirtualAddress(); - - CommandList->BuildRaytracingAccelerationStructure(&BuildDesc, 0, 0); - } else { - D3D12_RAYTRACING_INSTANCE_DESC 
InstanceDesc = {}; + for (int i = 0; i < NumInstanceDescs; ++i) { + D3D12_RAYTRACING_INSTANCE_DESC &InstanceDesc = CPUInstanceDescs[i]; InstanceDesc.Transform[0][0] = InstanceDesc.Transform[1][1] = InstanceDesc.Transform[2][2] = 1; + InstanceDesc.InstanceID = i; + InstanceDesc.InstanceContributionToHitGroupIndex = i; InstanceDesc.InstanceMask = 1; - InstanceDesc.AccelerationStructure = + InstanceDesc.Flags = D3D12_RAYTRACING_INSTANCE_FLAG_NONE; + } + + if (UseMesh) + CPUInstanceDescs[MeshIdx].AccelerationStructure = + BLASMeshResource->GetGPUVirtualAddress(); + if (UseProceduralGeometry) + CPUInstanceDescs[ProcGeoIdx].AccelerationStructure = BLASProceduralGeometryResource->GetGPUVirtualAddress(); - AllocateUploadBuffer(Device, &InstanceDesc, sizeof(InstanceDesc), - &InstanceDescs, L"InstanceDescs"); + AllocateUploadBuffer(Device, &CPUInstanceDescs, + NumInstanceDescs * + sizeof(D3D12_RAYTRACING_INSTANCE_DESC), + &InstanceDescs, L"InstanceDescs"); - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS BuildFlags = - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_BUILD; + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS BuildFlags = + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_BUILD; - D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS AccelInputs = {}; - AccelInputs.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL; - AccelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; - AccelInputs.NumDescs = 1; - AccelInputs.Flags = BuildFlags; - AccelInputs.InstanceDescs = InstanceDescs->GetGPUVirtualAddress(); + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS AccelInputs = {}; + AccelInputs.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL; + AccelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY; + AccelInputs.NumDescs = NumInstanceDescs; + AccelInputs.Flags = BuildFlags; + AccelInputs.InstanceDescs = InstanceDescs->GetGPUVirtualAddress(); - D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO PrebuildInfo 
= {}; - Device->GetRaytracingAccelerationStructurePrebuildInfo(&AccelInputs, - &PrebuildInfo); + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO PrebuildInfo = {}; + Device->GetRaytracingAccelerationStructurePrebuildInfo(&AccelInputs, + &PrebuildInfo); - ScratchResource.Release(); - ReallocScratchResource(Device, &ScratchResource, - PrebuildInfo.ScratchDataSizeInBytes); - AllocateBuffer( - Device, PrebuildInfo.ResultDataMaxSizeInBytes, &TLASResource, true, - D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, L"TLAS"); + ScratchResource.Release(); + ReallocScratchResource(Device, &ScratchResource, + PrebuildInfo.ScratchDataSizeInBytes); + AllocateBuffer(Device, PrebuildInfo.ResultDataMaxSizeInBytes, &TLASResource, + true, D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, + L"TLAS"); - D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC BuildDesc = {}; - BuildDesc.Inputs = AccelInputs; - BuildDesc.ScratchAccelerationStructureData = - ScratchResource->GetGPUVirtualAddress(); - BuildDesc.DestAccelerationStructureData = - TLASResource->GetGPUVirtualAddress(); + D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC BuildDesc = {}; + BuildDesc.Inputs = AccelInputs; + BuildDesc.ScratchAccelerationStructureData = + ScratchResource->GetGPUVirtualAddress(); + BuildDesc.DestAccelerationStructureData = + TLASResource->GetGPUVirtualAddress(); - CommandList->BuildRaytracingAccelerationStructure(&BuildDesc, 0, 0); - } + CommandList->BuildRaytracingAccelerationStructure(&BuildDesc, 0, 0); CD3DX12_RESOURCE_BARRIER Barrier = CD3DX12_RESOURCE_BARRIER::UAV(TLASResource); diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index 7bbdd17073..be8df24b19 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -23,7 +23,7 @@ struct SceneConstants float4 W; float sceneScale; uint2 WindowSize; - int rayFlags; + int rayFlags; }; struct[raypayload] 
PerRayData @@ -117,7 +117,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) L"-DSER_GET_SCALAR=GetRayTMin"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); + false /*useProceduralGeometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); for (int Id = 0; Id < TestData.size(); Id += 2) { float *ResArray = (float *)(TestData.data() + Id); float RefVal = ResArray[0]; @@ -143,7 +144,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) L"-DSER_GET_SCALAR=GetRayTCurrent"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); + false /*useProceduralGeometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); for (int Id = 0; Id < TestData.size(); Id += 2) { float *ResArray = (float *)(TestData.data() + Id); float RefVal = ResArray[0]; @@ -169,7 +171,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) L"-DSER_GET_SCALAR=GetRayFlags"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); + false /*useProceduralGeometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); for (int Id = 0; Id < TestData.size(); Id += 2) { const int RefVal = TestData[Id]; const int SerVal = TestData[Id + 1]; @@ -193,7 +196,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) L"-DSER_GET_SCALAR=GetHitKind"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); + false /*useProceduralGeometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); for (int Id = 0; Id < TestData.size(); Id += 2) { const int RefVal = TestData[Id]; const int SerVal = TestData[Id + 1]; @@ -217,7 +221,8 @@ void closesthit(inout 
PerRayData payload, in Attrs attrs) L"-DSER_GET_SCALAR=GetGeometryIndex"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); + false /*useProceduralGeometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); for (int Id = 0; Id < TestData.size(); Id += 2) { const int RefVal = TestData[Id]; const int SerVal = TestData[Id + 1]; @@ -241,7 +246,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) L"-DSER_GET_SCALAR=GetInstanceIndex"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); + false /*useProceduralGeometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); for (int Id = 0; Id < TestData.size(); Id += 2) { const int RefVal = TestData[Id]; const int SerVal = TestData[Id + 1]; @@ -265,7 +271,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) L"-DSER_GET_SCALAR=GetInstanceID"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); + false /*useProceduralGeometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); for (int Id = 0; Id < TestData.size(); Id += 2) { const int RefVal = TestData[Id]; const int SerVal = TestData[Id + 1]; @@ -289,7 +296,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) L"-DSER_GET_SCALAR=GetPrimitiveIndex"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); + false /*useProceduralGeometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); for (int Id = 0; Id < TestData.size(); Id += 2) { const int RefVal = TestData[Id]; const int SerVal = TestData[Id + 1]; @@ -407,8 +415,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) 
L"-DSER_GET_VECTOR=GetWorldRayOrigin"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/, - 3 /*payloadCount*/); + false /*useProceduralGeometry*/, 3 /*payloadCount*/, + 2 /*attributeCount*/); for (int Id = 0; Id < TestData.size(); Id += 6) { float *ResArray = (float *)(TestData.data() + Id); float RefX = ResArray[0]; @@ -440,8 +448,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) L"-DSER_GET_VECTOR=GetWorldRayDirection"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/, - 3 /*payloadCount*/); + false /*useProceduralGeometry*/, 3 /*payloadCount*/, + 2 /*attributeCount*/); for (int Id = 0; Id < TestData.size(); Id += 6) { float *ResArray = (float *)(TestData.data() + Id); float RefX = ResArray[0]; @@ -472,8 +480,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) L"-DSER_GET_VECTOR=GetObjectRayOrigin"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/, - 3 /*payloadCount*/); + false /*useProceduralGeometry*/, 3 /*payloadCount*/, + 2 /*attributeCount*/); for (int Id = 0; Id < TestData.size(); Id += 6) { float *ResArray = (float *)(TestData.data() + Id); float RefX = ResArray[0]; @@ -505,8 +513,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) L"-DSER_GET_VECTOR=GetObjectRayDirection"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/, - 3 /*payloadCount*/); + false /*useProceduralGeometry*/, 3 /*payloadCount*/, + 2 /*attributeCount*/); for (int Id = 0; Id < TestData.size(); Id += 6) { float *ResArray = (float *)(TestData.data() + Id); float RefX = ResArray[0]; @@ 
-657,8 +665,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) L"-DCOLS=4"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/, - 12 /*payloadCount*/); + false /*useProceduralGeometry*/, 12 /*payloadCount*/, + 2 /*attributeCount*/); const int ROWS = 3; const int COLS = 4; for (int Id = 0; Id < TestData.size(); Id += 24) { @@ -692,8 +700,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) L"-DCOLS=3"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/, - 12 /*payloadCount*/); + false /*useProceduralGeometry*/, 12 /*payloadCount*/, + 2 /*attributeCount*/); for (int Id = 0; Id < TestData.size(); Id += 2 * ROWS * COLS) { float *ResArray = (float *)(TestData.data() + Id); for (int RowIdx = 0; RowIdx < ROWS; RowIdx++) { @@ -723,8 +731,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) L"-DCOLS=4"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/, - 12 /*payloadCount*/); + false /*useProceduralGeometry*/, 12 /*payloadCount*/, + 2 /*attributeCount*/); const int ROWS = 3; const int COLS = 4; for (int Id = 0; Id < TestData.size(); Id += 24) { @@ -756,8 +764,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) L"-DCOLS=3"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/, - 12 /*payloadCount*/); + false /*useProceduralGeometry*/, 12 /*payloadCount*/, + 2 /*attributeCount*/); const int ROWS = 4; const int COLS = 3; for (int Id = 0; Id < TestData.size(); Id += 24) { @@ -871,7 +879,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) RunDXRTest(Device, 
ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); + false /*useProceduralGeometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); std::map Histo; for (int Val : TestData) ++Histo[Val]; @@ -905,6 +914,11 @@ struct Attrs float2 barycentrics : BARYCENTRICS; }; +struct CustomAttrs +{ + float dist; +}; + RWStructuredBuffer testBuffer : register(u0); RaytracingAccelerationStructure topObject : register(t0); ConstantBuffer sceneConstants : register(b0); @@ -938,7 +952,6 @@ void raygen() // SER Test dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); dx::MaybeReorderThread(hitObject); - dx::HitObject::Invoke(hitObject, payload); if (hitObject.IsHit()) { @@ -947,7 +960,7 @@ void raygen() dx::HitObject::Invoke( hitObject, payload ); // Poison the test data if GetShaderTableIndex does not match SetShaderTableIndex. if (hitObject.GetShaderTableIndex() != 1) - payload.visited = 0; + payload.visited = 12345; } int id = launchIndex.x + launchIndex.y * launchDim.x; @@ -957,13 +970,14 @@ void raygen() [shader("miss")] void miss(inout PerRayData payload) { - payload.visited |= 2U; + payload.visited |= 1U; } +// Triangles [shader("anyhit")] void anyhit(inout PerRayData payload, in Attrs attrs) { - payload.visited |= 1U; + payload.visited |= 2U; AcceptHitAndEndSearch(); } @@ -973,10 +987,43 @@ void closesthit(inout PerRayData payload, in Attrs attrs) payload.visited |= 4U; } +// Procedural +[shader("intersection")] +void intersection() +{ + // Intersection with circle on a plane (base, n, radius) + // hitPos is intersection point with plane (base, n) + float3 base = {0.0f,0.0f,0.5f}; + float3 n = normalize(float3(0.2f,0.2f,0.5f)); + float radius = 150.f; + // Plane hit + float t = dot(n, base - ObjectRayOrigin()) / dot(n, ObjectRayDirection()); + if (t > RayTCurrent() || t < RayTMin()) { + return; + } + float3 hitPos = 
ObjectRayOrigin() + t * ObjectRayDirection(); + float3 relHitPos = hitPos - base; + // Circle hit + float hitDist = length(relHitPos); + if (hitDist > radius) + return; + + CustomAttrs attrs; + attrs.dist = hitDist; + ReportHit(t, 1, attrs); +} + +[shader("anyhit")] +void ahAABB(inout PerRayData payload, in CustomAttrs attrs) +{ + payload.visited |= 8U; + IgnoreHit(); +} + [shader("closesthit")] void chAABB(inout PerRayData payload, in Attrs attrs) { - payload.visited |= 8U; + payload.visited |= 16U; } )"; @@ -990,16 +1037,19 @@ void chAABB(inout PerRayData payload, in Attrs attrs) std::vector TestData(WindowSize * WindowSize, 0); LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; - WEX::Logging::Log::Comment(L"==== DXR lib_6_9 with SER"); RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*mesh*/, - true /*procedural geometry*/, false /*useIS*/); + true /*procedural geometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); std::map Histo; for (int Val : TestData) ++Histo[Val]; - VERIFY_ARE_EQUAL(Histo.size(), 2); - VERIFY_ARE_EQUAL(Histo[2], 4030); - VERIFY_ARE_EQUAL(Histo[13], 66); + VERIFY_ARE_EQUAL(Histo.size(), 3); + VERIFY_ARE_EQUAL(Histo[0], 3696); // Miss (not Invoked) + VERIFY_ARE_EQUAL(Histo[8], 334); // AABB ignored hit -> (Miss not Invoked) + VERIFY_ARE_EQUAL( + Histo[26], + 66); // AABB ignored hit + TriHit -> setSBT(1) -> chAABB invoked } TEST_F(ExecutionTest, SERLoadLocalRootTableConstantTest) { @@ -1114,7 +1164,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); + false /*useProceduralGeometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); std::map Histo; for (int Val : TestData) ++Histo[Val]; @@ -1271,7 +1322,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, 
_countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); + false /*useProceduralGeometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); std::map Histo; for (int Val : TestData) ++Histo[Val]; @@ -1387,13 +1439,13 @@ void intersection() RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, false /*mesh*/, - true /*procedural geometry*/, true /*useIS*/); + true /*procedural geometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); std::map Histo; for (int Val : TestData) ++Histo[Val]; - VERIFY_ARE_EQUAL(Histo.size(), 2); - VERIFY_ARE_EQUAL(Histo[2], 3400); - VERIFY_ARE_EQUAL(Histo[5], 696); + VERIFY_ARE_EQUAL(Histo.size(), 1); + VERIFY_ARE_EQUAL(Histo[5], 4096); // All rays hitting the procedural geometry } TEST_F(ExecutionTest, SERGetAttributesTest) { @@ -1531,15 +1583,17 @@ void intersection() RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, false /*mesh*/, - true /*procedural geometry*/, true /*useIS*/); + true /*procedural geometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); std::map Histo; for (int Val : TestData) ++Histo[Val]; - VERIFY_ARE_EQUAL(Histo.size(), 4); - VERIFY_ARE_EQUAL(Histo[0], 328); - VERIFY_ARE_EQUAL(Histo[1], 186); - VERIFY_ARE_EQUAL(Histo[3], 182); - VERIFY_ARE_EQUAL(Histo[255], 3400); + VERIFY_ARE_EQUAL(Histo.size(), 5); + VERIFY_ARE_EQUAL(Histo[0], 2009); + VERIFY_ARE_EQUAL(Histo[1], 561); + VERIFY_ARE_EQUAL(Histo[3], 587); + VERIFY_ARE_EQUAL(Histo[4], 454); + VERIFY_ARE_EQUAL(Histo[6], 485); } TEST_F(ExecutionTest, SERTraceHitMissNopTest) { @@ -1649,7 +1703,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*mesh*/, - false /*procedural geometry*/, false /*useIS*/); + false /*procedural geometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); std::map Histo; for (int 
Val : TestData) ++Histo[Val]; @@ -1763,7 +1818,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*mesh*/, - false /*procedural geometry*/, false /*useIS*/); + false /*procedural geometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); std::map Histo; for (int Val : TestData) ++Histo[Val]; @@ -1892,7 +1948,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); + false /*useProceduralGeometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); std::map Histo; for (int Val : TestData) ++Histo[Val]; @@ -2004,7 +2061,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, false /*useIS*/); + false /*useProceduralGeometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); std::map Histo; for (int Val : TestData) ++Histo[Val]; @@ -2167,7 +2225,8 @@ void closesthit(inout PerRayData payload, in Attrs attrs) LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, WindowSize, WindowSize, true /*mesh*/, - false /*procedural geometry*/, false /*useIS*/); + false /*procedural geometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); std::map Histo; for (int Val : TestData) ++Histo[Val]; From 8f1c59847e888921ca6475ddc0e4870372ce5bd0 Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Mon, 5 May 2025 14:00:03 +0200 Subject: [PATCH 19/31] WaveIncoherentHitTest --- .../unittests/HLSLExec/ExecutionTest.cpp | 1 + .../unittests/HLSLExec/ExecutionTest_SER.h | 182 ++++++++++++++++++ 2 files changed, 183 insertions(+) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp 
b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index c650f275cc..ffa696d88c 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -633,6 +633,7 @@ class ExecutionTest { TEST_METHOD(SERInvokeNoSBTTest); TEST_METHOD(SERMaybeReorderThreadTest) TEST_METHOD(SERDynamicHitObjectArrayTest); + TEST_METHOD(SERWaveIncoherentHitTest); dxc::DxcDllSupport m_support; diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index be8df24b19..9e69d7f50b 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -2232,4 +2232,186 @@ void closesthit(inout PerRayData payload, in Attrs attrs) ++Histo[Val]; VERIFY_ARE_EQUAL(Histo.size(), 1); VERIFY_ARE_EQUAL(Histo[0], 4096); +} + +TEST_F(ExecutionTest, SERWaveIncoherentHitTest) { + // Test SER with wave incoherent conditional assignment of HitObject values + // with and without procedural attributes. 
+ static const char *ShaderSrc = R"( +struct SceneConstants +{ + float4 eye; + float4 U; + float4 V; + float4 W; + float sceneScale; + uint2 windowSize; + int rayFlags; +}; + +struct[raypayload] PerRayData +{ + uint visited : read(anyhit,closesthit,miss,caller) : write(anyhit,miss,closesthit,caller); +}; + +struct Attrs +{ + float2 barycentrics : BARYCENTRICS; +}; + +struct CustomAttrs +{ + float dist; +}; + +RWStructuredBuffer testBuffer : register(u0); +RaytracingAccelerationStructure topObject : register(t0); +ConstantBuffer sceneConstants : register(b0); + +RayDesc ComputeRay() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; + RayDesc ray; + ray.Origin = sceneConstants.eye.xyz; + ray.Direction = normalize(d.x*sceneConstants.U.xyz + d.y*sceneConstants.V.xyz + sceneConstants.W.xyz); + ray.TMin = 0; + ray.TMax = 1e18; + + return ray; +} + +static const uint ProceduralHitKind = 11; + +[shader("raygeneration")] +void raygen() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + RayDesc ray = ComputeRay(); + + PerRayData payload; + payload.visited = 0; + + dx::HitObject hitObject; + + // Use wave incoherence to decide how to create the HitObject + if (launchIndex.x % 4 == 1) + { + ray.Origin.x += 2.0f; + hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_SKIP_CLOSEST_HIT_SHADER, 0xFF, 0, 0, 0, ray, payload); + } + else if (launchIndex.x % 4 == 2) + { + hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES, 0xFF, 0, 0, 0, ray, payload); + } + else if (launchIndex.x % 4 == 3) + { + hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_SKIP_TRIANGLES, 0xFF, 0, 0, 0, ray, payload); + } + + dx::MaybeReorderThread(hitObject); + + if (hitObject.IsNop()) + payload.visited |= 1U; + if (hitObject.IsMiss()) + payload.visited |= 2U; + + if 
(hitObject.GetHitKind() == ProceduralHitKind) + payload.visited |= 8U; + else if (hitObject.IsHit()) + payload.visited |= 4U; + + dx::HitObject::Invoke(hitObject, payload); + + // Store the result in the buffer + int id = launchIndex.x + launchIndex.y * launchDim.x; + testBuffer[id] = payload.visited; +} + +[shader("miss")] +void miss(inout PerRayData payload) +{ + // UNUSED +} + +// Triangles +[shader("anyhit")] +void anyhit(inout PerRayData payload, in Attrs attrs) +{ + AcceptHitAndEndSearch(); +} + +[shader("closesthit")] +void closesthit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 16U; +} + +// Procedural +[shader("closesthit")] +void chAABB(inout PerRayData payload, in CustomAttrs attrs) +{ + payload.visited |= 32U; +} + +[shader("anyhit")] +void ahAABB(inout PerRayData payload, in CustomAttrs attrs) +{ + // UNUSED +} + +[shader("intersection")] +void intersection() +{ + // Intersection with circle on a plane (base, n, radius) + // hitPos is intersection point with plane (base, n) + float3 base = {0.0f,0.0f,0.5f}; + float3 n = normalize(float3(0.0f,0.5f,0.5f)); + float radius = 1000.f; + // Plane hit + float t = dot(n, base - ObjectRayOrigin()) / dot(n, ObjectRayDirection()); + if (t > RayTCurrent() || t < RayTMin()) { + return; + } + float3 hitPos = ObjectRayOrigin() + t * ObjectRayDirection(); + float3 relHitPos = hitPos - base; + // Circle hit + float hitDist = length(relHitPos); + if (hitDist > radius) + return; + + CustomAttrs attrs; + attrs.dist = hitDist; + ReportHit(t, ProceduralHitKind, attrs); +} + +)"; + + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, false)) + return; + + // Initialize test data. 
+ const int WindowSize = 64; + std::vector TestData(WindowSize * WindowSize, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; + + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*mesh*/, + true /*procedural geometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); + std::map Histo; + for (int Val : TestData) + ++Histo[Val]; + VERIFY_ARE_EQUAL(Histo.size(), 6); + VERIFY_ARE_EQUAL(Histo[1], 1024); // nop + VERIFY_ARE_EQUAL(Histo[2], 1022); // miss + VERIFY_ARE_EQUAL(Histo[4], 12); // triangle hit, no ch + VERIFY_ARE_EQUAL(Histo[8], 1008); // procedural hit, no ch + VERIFY_ARE_EQUAL(Histo[20], 11); // triangle hit, 'closesthit' invoked + VERIFY_ARE_EQUAL(Histo[40], 1019); // procedural hit, 'chAABB' invoked } \ No newline at end of file From e1df92bfc1952717fc44388bbbb5c16d527eb1cf Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Tue, 6 May 2025 10:23:30 +0200 Subject: [PATCH 20/31] SERReorderCoherentTest --- .../unittests/HLSLExec/ExecutionTest.cpp | 1 + .../unittests/HLSLExec/ExecutionTest_SER.h | 137 ++++++++++++++++++ 2 files changed, 138 insertions(+) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index ffa696d88c..9be631fd83 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -634,6 +634,7 @@ class ExecutionTest { TEST_METHOD(SERMaybeReorderThreadTest) TEST_METHOD(SERDynamicHitObjectArrayTest); TEST_METHOD(SERWaveIncoherentHitTest); + TEST_METHOD(SERReorderCoherentTest); dxc::DxcDllSupport m_support; diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index 9e69d7f50b..18a4397e0d 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -2414,4 +2414,141 @@ void intersection() VERIFY_ARE_EQUAL(Histo[8], 1008); // procedural hit, no ch 
VERIFY_ARE_EQUAL(Histo[20], 11); // triangle hit, 'closesthit' invoked VERIFY_ARE_EQUAL(Histo[40], 1019); // procedural hit, 'chAABB' invoked +} + +TEST_F(ExecutionTest, SERReorderCoherentTest) { + // SER: Test reordercoherent + static const char *ShaderSrc = R"( +struct SceneConstants +{ + float4 eye; + float4 U; + float4 V; + float4 W; + float sceneScale; + uint2 windowSize; + int rayFlags; +}; + +struct[raypayload] PerRayData +{ + uint visited : read(anyhit,closesthit,miss,caller) : write(anyhit,miss,closesthit,caller); +}; + +struct Attrs +{ + float2 barycentrics : BARYCENTRICS; +}; + +reordercoherent RWStructuredBuffer testBuffer : register(u0); +RaytracingAccelerationStructure topObject : register(t0); +ConstantBuffer sceneConstants : register(b0); + +RayDesc ComputeRay() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; + RayDesc ray; + ray.Origin = sceneConstants.eye.xyz; + ray.Direction = normalize(d.x*sceneConstants.U.xyz + d.y*sceneConstants.V.xyz + sceneConstants.W.xyz); + ray.TMin = 0; + ray.TMax = 1e18; + + return ray; +} + +[shader("raygeneration")] +void raygen() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + uint threadId = launchIndex.x + launchIndex.y * launchDim.x; + + RayDesc ray = ComputeRay(); + + PerRayData payload; + payload.visited = 0; + + // Initial test value. + testBuffer[threadId] = threadId; + + dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); + + // Conditionally update the test value. + if (hitObject.IsHit()) + { + testBuffer[threadId] += 10; // Add 10 to hits + } + else + { + testBuffer[threadId] += 20; // Add 20 to misses + } + + Barrier(UAV_MEMORY, REORDER_SCOPE); + dx::MaybeReorderThread(hitObject); + + // Conditionally update the test value. 
+ if (threadId % 2 == 0) + { + testBuffer[threadId] += 1000; // Add 1000 to even threads + } + else + { + testBuffer[threadId] += 2000; // Add 2000 to odd threads + } + + // Verify test value. + uint expectedValue = (hitObject.IsHit() ? threadId + 10 : threadId + 20); + expectedValue += (threadId % 2 == 0 ? 1000 : 2000); + if (testBuffer[threadId] != expectedValue) + { + // Mark failure in the buffer if the result does not match + testBuffer[threadId] = 0; + } + else + { + testBuffer[threadId] = 1; + } +} + +[shader("miss")] +void miss(inout PerRayData payload) +{ + payload.visited |= 2U; +} + +[shader("anyhit")] +void anyhit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 1U; +} + +[shader("closesthit")] +void closesthit(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 4U; +} + +)"; + + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, false)) + return; + + // Initialize test data. + const int WindowSize = 64; + std::vector TestData(WindowSize * WindowSize, 0); + LPCWSTR Args[] = {L"-HV 2021", L"-Vd"}; + + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, + WindowSize, WindowSize, true /*useMesh*/, + false /*useProceduralGeometry*/, 1 /*payloadCount*/, + 2 /*attributeCount*/); + std::map Histo; + for (int Val : TestData) + ++Histo[Val]; + VERIFY_ARE_EQUAL(Histo.size(), 1); + VERIFY_ARE_EQUAL(Histo[1], 4096); } \ No newline at end of file From 8ab7045fc624681c2b015d0fc8f10339f326e982 Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Tue, 6 May 2025 19:34:50 -0400 Subject: [PATCH 21/31] [CoopVec] Initial CoopVec ExecutionTest support for Mul[Add] and OuterProductAccumulate (#7424) ExecutionTest::CoopVec_MulAdd: Functional verification for the Mul() and MulAdd() HLSL APIs. The driver matrix conversion API is tested as well. These tests should be considered as work-in-progress at this point. They include coverage primarily for SINT8, FLOAT16, FLOAT_E4M3, and FLOAT_E5M2. 
The test queries the driver for all supported configurations and runs each one, with a filtering mechanism to limit the set of tests to the minimal feature set. The set of tests can be further filtered by the following TE parameters: CoopVecMatrixInterp: SINT8, FLOAT16, FLOAT_E4M3, ... CoopVecMatrixLayout: ROW_MAJOR, COLUMN_MAJOR, MUL_OPTIMAL, OUTER_PRODUCT_OPTIMAL CoopVecBiasInterp: SINT32, FLOAT16, FLOAT_E4M3, ... CoopVecInputInterp: SINT8, FLOAT16, FLOAT_E4M3, ... CoopVecInputType: SINT8, UINT8, SINT16, UINT16, SINT32, UINT32, FLOAT16, FLOAT32, ... CoopVecOutputType: SINT32, UINT32, FLOAT16, FLOAT32, ... Filter example: $ TE.exe ... -p:CoopVecMatrixInterp=FLOAT16 -p:CoopVecMatrixLayout=MUL_OPTIMAL Precision coverage is minimal at this point, using an all-ones input matrix and test vector with ones in the first two components. This is enough to test basic functionality, but more comprehensive tests are needed. ExecutionTest::CoopVec_OuterProduct: Functional verification for the OuterProductAccumulate() HLSL API. This test queries the driver for all supported configurations and runs each one. No filtering is currently implemented. 
--- tools/clang/unittests/HLSLExec/CoopVec.h | 359 ++++ tools/clang/unittests/HLSLExec/CoopVecAPI.h | 178 ++ .../unittests/HLSLExec/ExecutionTest.cpp | 1451 ++++++++++++++++- 3 files changed, 1981 insertions(+), 7 deletions(-) create mode 100644 tools/clang/unittests/HLSLExec/CoopVec.h create mode 100644 tools/clang/unittests/HLSLExec/CoopVecAPI.h diff --git a/tools/clang/unittests/HLSLExec/CoopVec.h b/tools/clang/unittests/HLSLExec/CoopVec.h new file mode 100644 index 0000000000..f166c61f67 --- /dev/null +++ b/tools/clang/unittests/HLSLExec/CoopVec.h @@ -0,0 +1,359 @@ +#pragma once + +#if HAVE_COOPVEC_API + +#include +#include +#include + +#include "dxc/Support/microcom.h" + +#include "CoopVecAPI.h" + +struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { +private: + DXC_MICROCOM_REF_FIELD(RefCount) + dxc::DxcDllSupport &DxcSupport; + +public: + LinAlgHeaderIncludeHandler() = delete; + LinAlgHeaderIncludeHandler(dxc::DxcDllSupport &DxcSupport) + : RefCount(0), DxcSupport(DxcSupport) {} + + DXC_MICROCOM_ADDREF_RELEASE_IMPL(RefCount) + + HRESULT STDMETHODCALLTYPE LoadSource(LPCWSTR Filename, + IDxcBlob **IncludeSource) { + if (wcscmp(Filename, L"dx/linalg.h") == 0 || + wcscmp(Filename, L".\\dx\\linalg.h") == 0) { + WEX::Common::String ParamValue; + if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue( + L"LinAlgHeader", ParamValue))) { + return E_FAIL; + } + if (ParamValue.IsEmpty()) { + return E_FAIL; + } + LPCWSTR RealHeaderPath = + reinterpret_cast(ParamValue.GetBuffer()); + + CComPtr HeaderUtils; + + IFT(DxcSupport.CreateInstance(CLSID_DxcUtils, &HeaderUtils)); + + IDxcBlobEncoding *HeaderBlob; + IFT(HeaderUtils->LoadFile(RealHeaderPath, nullptr, &HeaderBlob)); + + *IncludeSource = HeaderBlob; + + return S_OK; + } + return E_FAIL; + } + + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID IID, void **Object) override { +// FIXME: This is a workaround for a warning-as-error about unused parameters. 
+#pragma warning(push) +#pragma warning(disable : 4100) + return DoBasicQueryInterface(this, IID, Object); +#pragma warning(pop) + } +}; + +namespace CoopVecHelpers { +template +static std::vector CreateAllOnesInputMatrix(uint32_t Width, + uint32_t Height) { + std::vector InputMatrix(Width * Height); + for (uint32_t i = 0; i < Width * Height; i++) { + if constexpr (std::is_same_v || + std::is_same_v) { + InputMatrix[i] = 1; + } else if constexpr (std::is_same_v) { + InputMatrix[i] = ConvertFloat32ToFloat16(1.0f); + } else if constexpr (std::is_same_v) { + InputMatrix[i] = 1.0f; + } else { + WEX::Logging::Log::Error(L"Unsupported input type"); + break; + } + } + + // Convert to uint8_t vector + std::vector Uint8InputMatrix(InputMatrix.size() * sizeof(EltTy)); + std::memcpy(Uint8InputMatrix.data(), InputMatrix.data(), + InputMatrix.size() * sizeof(EltTy)); + return Uint8InputMatrix; +} + +template +static std::vector CreateInputVector(uint32_t NumThreads, + uint32_t EltsPerThread) { + std::vector InputVector(NumThreads * EltsPerThread); + std::fill(InputVector.begin(), InputVector.end(), EltTy(0)); + if (EltsPerThread < 2) { + WEX::Logging::Log::Error(L"EltsPerThread must be at least 2"); + return std::vector(); + } + for (uint32_t TID = 0; TID < NumThreads; TID++) { + if constexpr (std::is_same_v || + std::is_same_v) { + InputVector[TID * EltsPerThread + 0] = 1; + InputVector[TID * EltsPerThread + 1] = 1; + } else if constexpr (std::is_same_v) { + InputVector[TID * EltsPerThread + 0] = ConvertFloat32ToFloat16(1.0f); + InputVector[TID * EltsPerThread + 1] = ConvertFloat32ToFloat16(1.0f); + } else if constexpr (std::is_same_v) { + InputVector[TID * EltsPerThread + 0] = 1.0f; + InputVector[TID * EltsPerThread + 1] = 1.0f; + } else { + WEX::Logging::Log::Error(L"Unsupported input type"); + break; + } + } + + // Convert to uint8_t vector + std::vector Uint8InputVector(InputVector.size() * sizeof(EltTy)); + std::memcpy(Uint8InputVector.data(), InputVector.data(), + 
InputVector.size() * sizeof(EltTy)); + return Uint8InputVector; +} + +template +static std::vector CreateInputBias(uint32_t NumElts) { + std::vector InputBias(NumElts); + if constexpr (std::is_same_v || + std::is_same_v) { + std::fill(InputBias.begin(), InputBias.end(), EltTy(1)); + } else if constexpr (std::is_same_v) { + std::fill(InputBias.begin(), InputBias.end(), + ConvertFloat32ToFloat16(1.0f)); + } else if constexpr (std::is_same_v) { + std::fill(InputBias.begin(), InputBias.end(), 1); + } else { + WEX::Logging::Log::Error(L"Unsupported bias type"); + } + // Convert to uint8_t vector + std::vector Uint8InputBias(InputBias.size() * sizeof(EltTy)); + std::memcpy(Uint8InputBias.data(), InputBias.data(), + InputBias.size() * sizeof(EltTy)); + return Uint8InputBias; +} + +static std::wstring +DataTypeToFilterString(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + switch (DataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + return L"SINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return L"UINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + return L"SINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + return L"UINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"SINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"UINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"SINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"UINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"FLOAT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"FLOAT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + return L"FLOAT_E4M3"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return L"FLOAT_E5M2"; + default: + return L""; + } +} + +static bool IsDataTypeInFilter(const wchar_t *FilterKey, + D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + WEX::Common::String ParamValue; + if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(FilterKey, + ParamValue))) { + // Filter 
not set, so treat as no filter + return true; + } + if (ParamValue.IsEmpty()) { + // Empty filter, so treat as no filter + return true; + } + + // Check if the filter matches the target data type + LPCWSTR FilterString = reinterpret_cast(ParamValue.GetBuffer()); + return DataTypeToFilterString(DataType) == FilterString; +} + +static std::wstring +MatrixLayoutToFilterString(D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + switch (MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + return L"ROW_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + return L"COLUMN_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL: + return L"MUL_OPTIMAL"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL: + return L"OUTER_PRODUCT_OPTIMAL"; + default: + return L""; + } +} + +static bool +IsMatrixLayoutInFilter(const wchar_t *FilterKey, + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + WEX::Common::String ParamValue; + if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(FilterKey, + ParamValue))) { + // Filter not set, so treat as no filter + return true; + } + if (ParamValue.IsEmpty()) { + // Empty filter, so treat as no filter + return true; + } + + // Check if the filter matches the target matrix layout + LPCWSTR FilterString = reinterpret_cast(ParamValue.GetBuffer()); + return MatrixLayoutToFilterString(MatrixLayout) == FilterString; +} + +static std::wstring MatrixLayoutToHlslLayoutString( + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + switch (MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + return L"MATRIX_LAYOUT_ROW_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + return L"MATRIX_LAYOUT_COLUMN_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL: + return L"MATRIX_LAYOUT_MUL_OPTIMAL"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL: + return L"MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL"; + default: + return L""; + } +} + +// This multiplier is 
used to compute the row/column stride for a matrix +// given its element size. +static int +GetStrideMultiplierForMatrixDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + switch (DataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return 1; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return 2; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return 4; + default: + WEX::Logging::Log::Error(L"Unsupported matrix data type"); + return 1; + } +} + +static int GetNumPackedElementsForInputDataType( + D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) { + // Int8 packed types are the only ones that have more than 1 element per + // shader variable + switch (InputInterpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return 4; + default: + return 1; + } +} + +// This type is used in generated HLSL source to represent the vector type +// for the given data type. 
+static std::wstring +GetHlslDataTypeForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + switch (DataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"int16_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"uint16_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"int32_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"uint32_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"half"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"float"; + default: + WEX::Logging::Log::Error(L"Unsupported input data type"); + return L""; + } +} + +static std::wstring +GetHlslInterpretationForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE Interpretation) { + switch (Interpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + return L"DATA_TYPE_SINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return L"DATA_TYPE_UINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + return L"DATA_TYPE_SINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + return L"DATA_TYPE_UINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"DATA_TYPE_SINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"DATA_TYPE_UINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"DATA_TYPE_SINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"DATA_TYPE_UINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"DATA_TYPE_FLOAT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"DATA_TYPE_FLOAT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + return L"DATA_TYPE_FLOAT8_E4M3"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return L"DATA_TYPE_FLOAT8_E5M2"; + default: + WEX::Logging::Log::Error(L"Unsupported interpretation"); + return L""; + } +} + +// The returned data type is used for matrix conversion. It is hard-coded +// for the test framework where all integer matrices start as SINT8 and +// all FP matrices start as FLOAT32. 
+static D3D12_LINEAR_ALGEBRA_DATATYPE +GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) { + switch (MatrixInterpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + default: + return D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; + } +} +}; // namespace CoopVecHelpers + +#endif // HAVE_COOPVEC_API diff --git a/tools/clang/unittests/HLSLExec/CoopVecAPI.h b/tools/clang/unittests/HLSLExec/CoopVecAPI.h new file mode 100644 index 0000000000..16c1105edc --- /dev/null +++ b/tools/clang/unittests/HLSLExec/CoopVecAPI.h @@ -0,0 +1,178 @@ +#pragma once +// clang-format off + +#if !defined(D3D12_PREVIEW_SDK_VERSION) || D3D12_PREVIEW_SDK_VERSION < 717 + +#ifdef __ID3D12GraphicsCommandList10_INTERFACE_DEFINED__ +#define HAVE_COOPVEC_API 1 + +// This file contains the definitions of the D3D12 cooperative vector API. +// It is used to test the cooperative vector API on older SDKs. + +constexpr int D3D12_FEATURE_D3D12_OPTIONS_EXPERIMENTAL = 9; +constexpr int D3D12_FEATURE_COOPERATIVE_VECTOR = 11; + +// -------------------------------------------------------------------------------------------------------------------------------- +// Experimental Feature: D3D12CooperativeVectorExperiment +// +// Use with D3D12CooperativeVectorExperiment to enable cooperative vector experimental feature. +// +// Enabling D3D12CooperativeVectorExperiment needs no configuration struct, pass NULL in the pConfigurationStructs array. 
+// +// -------------------------------------------------------------------------------------------------------------------------------- +static const UUID D3D12CooperativeVectorExperiment = { /* 384748be-cca5-471e-a125-5cc997e04d39 */ + 0x384748be, + 0xcca5, + 0x471e, + {0xa1, 0x25, 0x5c, 0xc9, 0x97, 0xe0, 0x4d, 0x39} +}; + +/* interface __MIDL_itf_d3d12_0000_0082 */ +/* [local] */ + +typedef +enum D3D12_COOPERATIVE_VECTOR_TIER + { + D3D12_COOPERATIVE_VECTOR_TIER_NOT_SUPPORTED = 0, + D3D12_COOPERATIVE_VECTOR_TIER_1_0 = 0x10, + D3D12_COOPERATIVE_VECTOR_TIER_1_1 = 0x11 + } D3D12_COOPERATIVE_VECTOR_TIER; + +typedef +enum D3D12_LINEAR_ALGEBRA_DATATYPE + { + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16 = 2, + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16 = 3, + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 = 4, + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32 = 5, + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 = 7, + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 = 8, + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED = 16, + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED = 17, + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8 = 18, + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 = 19, + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 = 20, + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2 = 21 + } D3D12_LINEAR_ALGEBRA_DATATYPE; + +typedef struct D3D12_FEATURE_DATA_D3D12_OPTIONS_EXPERIMENTAL + { + _Out_ D3D12_COOPERATIVE_VECTOR_TIER CooperativeVectorTier; + } D3D12_FEATURE_DATA_D3D12_OPTIONS_EXPERIMENTAL; + +typedef struct D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL + { + D3D12_LINEAR_ALGEBRA_DATATYPE InputType; + D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation; + D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation; + D3D12_LINEAR_ALGEBRA_DATATYPE BiasInterpretation; + D3D12_LINEAR_ALGEBRA_DATATYPE OutputType; + BOOL TransposeSupported; + } D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL; + +typedef struct D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE + { + D3D12_LINEAR_ALGEBRA_DATATYPE InputType; + D3D12_LINEAR_ALGEBRA_DATATYPE AccumulationType; + } 
D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE; + +typedef struct D3D12_FEATURE_DATA_COOPERATIVE_VECTOR + { + _Inout_ UINT MatrixVectorMulAddPropCount; + _Out_ D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL *pMatrixVectorMulAddProperties; + _Inout_ UINT OuterProductAccumulatePropCount; + _Out_ D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE *pOuterProductAccumulateProperties; + _Inout_ UINT VectorAccumulatePropCount; + _Out_ D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE *pVectorAccumulateProperties; + } D3D12_FEATURE_DATA_COOPERATIVE_VECTOR; + +typedef +enum D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT + { + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR = 0, + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR = ( D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR + 1 ) , + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL = ( D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR + 1 ) , + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL = ( D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL + 1 ) + } D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT; + +typedef struct D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO + { + _Inout_ UINT DestSize; + _In_ D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT DestLayout; + _In_ UINT DestStride; + _In_ UINT NumRows; + _In_ UINT NumColumns; + _In_ D3D12_LINEAR_ALGEBRA_DATATYPE DestDataType; + } D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO; + +typedef struct D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA + { + _Inout_ D3D12_GPU_VIRTUAL_ADDRESS DestVA; + _In_ D3D12_GPU_VIRTUAL_ADDRESS SrcVA; + } D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA; + +typedef struct D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO + { + _In_ UINT SrcSize; + _In_ D3D12_LINEAR_ALGEBRA_DATATYPE SrcDataType; + _In_ D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT SrcLayout; + _In_ UINT SrcStride; + } D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO; + +typedef struct D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO + { + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO DestInfo; + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO SrcInfo; 
+ D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA DataDesc; + } D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO; + + + +#ifndef __ID3D12DevicePreview_INTERFACE_DEFINED__ +#define __ID3D12DevicePreview_INTERFACE_DEFINED__ + +EXTERN_C const IID IID_ID3D12DevicePreview; + +MIDL_INTERFACE("55ea41d3-6bf5-4332-bbf9-905e6b4e2930") +ID3D12DevicePreview : public IUnknown +{ +public: + virtual void STDMETHODCALLTYPE GetLinearAlgebraMatrixConversionDestinationInfo( + _Inout_ D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO *pDesc) = 0; + +}; + +#endif /* __ID3D12DevicePreview_INTERFACE_DEFINED__ */ + + +#ifndef __ID3D12GraphicsCommandList11_INTERFACE_DEFINED__ +#define __ID3D12GraphicsCommandList11_INTERFACE_DEFINED__ + +EXTERN_C const IID IID_ID3D12GraphicsCommandList11; + +MIDL_INTERFACE("f0dcfabc-a84a-4fe3-b3b9-eab26b306c38") +ID3D12GraphicsCommandList11 : public ID3D12GraphicsCommandList10 +{ +public: + virtual void STDMETHODCALLTYPE Reserved0() = 0; + virtual void STDMETHODCALLTYPE Reserved1() = 0; + virtual void STDMETHODCALLTYPE Reserved2() = 0; + + virtual void STDMETHODCALLTYPE ConvertLinearAlgebraMatrix( + _In_ const D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO *pDesc, + _In_ UINT DescCount) = 0; + +}; + +#endif /* __ID3D12GraphicsCommandList11_INTERFACE_DEFINED__ */ + +#else // __ID3D12GraphicsCommandList10_INTERFACE_DEFINED__ +// The used d3d12.h header does not support ID3D12GraphicsCommandList10, +// so we cannot define ID3D12GraphicsCommandList11. 
+#define HAVE_COOPVEC_API 0 +#endif // __ID3D12GraphicsCommandList10_INTERFACE_DEFINED__ + +#else // D3D12_PREVIEW_SDK_VERSION < 717 +// Preview header has CoopVec support +#define HAVE_COOPVEC_API 1 +#endif // D3D12_PREVIEW_SDK_VERSION < 717 diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 1bef0b4f8d..55d569dd8d 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -51,6 +51,7 @@ // https://msdn.microsoft.com/en-us/library/windows/desktop/dn899120(v=vs.85).aspx // https://developer.microsoft.com/en-US/windows/downloads/windows-10-sdk // + #include #include #include @@ -63,6 +64,8 @@ #include #include #include "LongVectors.h" +#include "CoopVecAPI.h" +#include "CoopVec.h" // clang-format on #pragma comment(lib, "d3dcompiler.lib") @@ -617,6 +620,9 @@ class ExecutionTest { TEST_METHOD(LongVector_Clamp_uint64); TEST_METHOD(LongVector_Initialize_uint64); + TEST_METHOD(CoopVec_Mul); + TEST_METHOD(CoopVec_OuterProduct); + dxc::DxcDllSupport m_support; bool m_D3DInitCompleted = false; @@ -752,7 +758,7 @@ class ExecutionTest { #endif } - bool UseDebugIfaces() { return true; } + bool UseDebugIfaces() { return false; } bool SaveImages() { return GetTestParamBool(L"SaveImages"); } @@ -775,6 +781,42 @@ class ExecutionTest { void RunResourceTest(ID3D12Device *pDevice, const char *pShader, const wchar_t *sm, bool isDynamic); + void runCoopVecMulTest(); + void runCoopVecOuterProductTest(); + +#if HAVE_COOPVEC_API + struct CoopVecMulSubtestConfig { + int InputPerThread; + int OutputPerThread; + int NumThreads; + int NumLevels; + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout; + bool Bias; + }; + + void + runCoopVecMulTestConfig(ID3D12Device *D3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps); + void runCoopVecMulSubtest(ID3D12Device *D3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, + CoopVecMulSubtestConfig &Config); + 
+ struct CoopVecOuterProductSubtestConfig { + int DimM; // Row Count + int DimN; // Column Count + int NumThreads; + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout; + }; + + void runCoopVecOuterProductTestConfig( + ID3D12Device *D3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps); + void runCoopVecOuterProductSubtest( + ID3D12Device *D3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, + CoopVecOuterProductSubtestConfig &Config); +#endif // HAVE_COOPVEC_API + template void WaveIntrinsicsActivePrefixTest(TableParameter *pParameterList, size_t numParameter, bool isPrefix); @@ -834,7 +876,8 @@ class ExecutionTest { void CompileFromText(LPCSTR pText, LPCWSTR pEntryPoint, LPCWSTR pTargetProfile, ID3DBlob **ppBlob, - LPCWSTR *pOptions = nullptr, int numOptions = 0) { + LPCWSTR *pOptions = nullptr, int numOptions = 0, + IDxcIncludeHandler *pIncludeHandler = nullptr) { VERIFY_SUCCEEDED(m_support.Initialize()); CComPtr pCompiler; CComPtr pLibrary; @@ -847,7 +890,7 @@ class ExecutionTest { pText, (UINT32)strlen(pText), CP_UTF8, &pTextBlob)); VERIFY_SUCCEEDED(pCompiler->Compile(pTextBlob, L"hlsl.hlsl", pEntryPoint, pTargetProfile, pOptions, numOptions, - nullptr, 0, nullptr, &pResult)); + nullptr, 0, pIncludeHandler, &pResult)); VERIFY_SUCCEEDED(pResult->GetStatus(&resultCode)); if (FAILED(resultCode)) { #ifndef _HLK_CONF @@ -882,7 +925,8 @@ class ExecutionTest { ID3D12RootSignature *pRootSignature, LPCSTR pShader, LPCWSTR pTargetProfile, ID3D12PipelineState **ppComputeState, - LPCWSTR *pOptions = nullptr, int numOptions = 0) { + LPCWSTR *pOptions = nullptr, int numOptions = 0, + IDxcIncludeHandler *pIncludeHandler = nullptr) { CComPtr pComputeShader; // Load and compile shaders. 
@@ -892,7 +936,7 @@ class ExecutionTest { #endif } else { CompileFromText(pShader, L"main", pTargetProfile, &pComputeShader, - pOptions, numOptions); + pOptions, numOptions, pIncludeHandler); } // Describe and create the compute pipeline state object (PSO). @@ -1729,6 +1773,21 @@ class ExecutionTest { #endif } + bool DoesDeviceSupportCooperativeVector(ID3D12Device *Device) { +#if HAVE_COOPVEC_API + D3D12_FEATURE_DATA_D3D12_OPTIONS_EXPERIMENTAL O; + if (FAILED(Device->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS_EXPERIMENTAL, &O, + sizeof(O)))) + return false; + return O.CooperativeVectorTier != + D3D12_COOPERATIVE_VECTOR_TIER_NOT_SUPPORTED; +#else + UNREFERENCED_PARAMETER(Device); + return false; +#endif + } + bool IsFallbackPathEnabled() { // Enable fallback paths with: /p:"EnableFallback=1" UINT EnableFallbackValue = 0; @@ -1841,8 +1900,18 @@ class ExecutionTest { if (pD3D12EnableExperimentalFeatures == nullptr) { return HRESULT_FROM_WIN32(GetLastError()); } - return pD3D12EnableExperimentalFeatures(1, &D3D12ExperimentalShaderModelsID, - nullptr, nullptr); + + std::vector Features; + + Features.push_back(D3D12ExperimentalShaderModels); + +#if HAVE_COOPVEC_API + if (GetTestParamBool(L"CooperativeVectorExperimental")) { + Features.push_back(D3D12CooperativeVectorExperiment); + } +#endif + return pD3D12EnableExperimentalFeatures((UINT)Features.size(), + Features.data(), nullptr, nullptr); } static HRESULT EnableExperimentalShaderModels() { @@ -11912,6 +11981,1374 @@ VERIFY_SUCCEEDED(DoArraysMatch(OutputVector, ExpectedVector, TestConfig.Tolerance)); } +// Runs a set of tests for the Cooperative Vector Mul and MulAdd operations. +// The device will be queried for supported configurations and then each +// supported configuration will be tested against multiple matrix and vector +// sizes. To help reproduce individual test failures, the test will log the +// configuration it is running and the results of each test. 
The following +// filters can be used to limit test execution to a specific set of +// configurations: +// +// - CoopVecMatrixInterp: SINT8, FLOAT16, FLOAT_E4M3, ... +// - CoopVecMatrixLayout: ROW_MAJOR, COLUMN_MAJOR, MUL_OPTIMAL, +// OUTER_PRODUCT_OPTIMAL +// - CoopVecBiasInterp: SINT32, FLOAT16, FLOAT_E4M3, ... +// - CoopVecInputInterp: SINT8, FLOAT16, FLOAT_E4M3, ... +// - CoopVecInputType: SINT8, UINT8, SINT16, UINT16, SINT32, UINT32, FLOAT16, +// FLOAT32, ... +// - CoopVecOutputType: SINT32, UINT32, FLOAT16, FLOAT32, ... +// +// Filter example: +// TE.exe ... -p:CoopVecMatrixInterp=FLOAT16 +// -p:CoopVecMatrixLayout=MUL_OPTIMAL +// +// The current implementation will always write the final output data as float. +void ExecutionTest::runCoopVecMulTest() { +#if !HAVE_COOPVEC_API + WEX::Logging::Log::Comment( + "Cooperative vector API not supported in build configuration. Skipping."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; +#else + // Create device and verify coopvec support + CComPtr D3DDevice; + if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { + return; + } + if (!DoesDeviceSupportCooperativeVector(D3DDevice)) { + WEX::Logging::Log::Comment( + "Device does not support cooperative vector. Skipping."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + + // Query coopvec feature data. First call gets the size of the arrays. The + // second call populates the arrays using memory we allocate. 
+ D3D12_FEATURE_DATA_COOPERATIVE_VECTOR DevOptions = {}; + VERIFY_SUCCEEDED(D3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &DevOptions, + sizeof(DevOptions))); + + // Allocate memory for the arrays in DevOptions + std::vector MulAddProps( + DevOptions.MatrixVectorMulAddPropCount); + DevOptions.pMatrixVectorMulAddProperties = MulAddProps.data(); + + VERIFY_SUCCEEDED(D3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &DevOptions, + sizeof(DevOptions))); + + // Test each supported data type and matrix layout + for (auto MulAddConfig : MulAddProps) { + // Filter on preview test support + bool PreviewConfig = false; + if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + PreviewConfig = true; + } + + if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + PreviewConfig = true; + } + + if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2 && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + PreviewConfig = true; + } + + if (MulAddConfig.InputType == 
D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { + PreviewConfig = true; + } + + if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { + PreviewConfig = true; + } + + if (!PreviewConfig) { + continue; + } + + // Apply filters + bool IsInFilter = + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecMatrixInterp", + MulAddConfig.MatrixInterpretation) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecBiasInterp", + MulAddConfig.BiasInterpretation) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecInputInterp", + MulAddConfig.InputInterpretation) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecInputType", + MulAddConfig.InputType) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecOutputType", + MulAddConfig.OutputType); + if (!IsInFilter) { + continue; + } + + // Run the test + runCoopVecMulTestConfig(D3DDevice, MulAddConfig); + } +#endif // HAVE_COOPVEC_API +} + +#if HAVE_COOPVEC_API +void ExecutionTest::runCoopVecMulTestConfig( + ID3D12Device *D3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps) { + + LogCommentFmt( + L"Running test for MatrixInterpretation: %s, BiasInterpretation: %s, " + L"InputInterpretation: %s, InputType: %s, OutputType: %s", + CoopVecHelpers::DataTypeToFilterString(MulProps.MatrixInterpretation) + .c_str(), + CoopVecHelpers::DataTypeToFilterString(MulProps.BiasInterpretation) + .c_str(), + 
CoopVecHelpers::DataTypeToFilterString(MulProps.InputInterpretation) + .c_str(), + CoopVecHelpers::DataTypeToFilterString(MulProps.InputType).c_str(), + CoopVecHelpers::DataTypeToFilterString(MulProps.OutputType).c_str()); + + constexpr CoopVecMulSubtestConfig TestConfigs[] = { + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {32, 8, 32, 1, 
D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + }; + + for (auto Config : TestConfigs) { + if ((MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) && + (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR || + Config.MatrixLayout == + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR)) { + continue; + } + + bool IsInFilter = CoopVecHelpers::IsMatrixLayoutInFilter( + L"CoopVecMatrixLayout", Config.MatrixLayout); + if (!IsInFilter) { + continue; + } + + runCoopVecMulSubtest(D3DDevice, MulProps, Config); + } +} + +void ExecutionTest::runCoopVecMulSubtest( + ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, + CoopVecMulSubtestConfig &Config) { + + LogCommentFmt( + L"Running test for InputPerThread: %d, OutputPerThread: %d, NumThreads: " + L"%d, NumLevels: %d, Bias: %s, MatrixLayout: %s", + Config.InputPerThread, Config.OutputPerThread, Config.NumThreads, + Config.NumLevels, Config.Bias ? 
L"true" : L"false", + CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str()); + + const int OutputBufferSize = (Config.OutputPerThread * Config.NumThreads * 4); + + // Create root signature with a single root entry for all SRVs and UAVs + CComPtr RootSignature; + { + CD3DX12_DESCRIPTOR_RANGE Ranges[2]; + Ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 3, 0, + 0); // InputVector, InputMatrix, InputBias + Ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // OutputBuffer + CreateRootSignatureFromRanges(D3DDevice, &RootSignature, Ranges, 2, nullptr, + 0); + } + + // Create descriptor heap with space for 4 descriptors: 3 SRVs and 1 UAV + CComPtr DescriptorHeap; + { + D3D12_DESCRIPTOR_HEAP_DESC Desc = {}; + Desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + Desc.NumDescriptors = 4; + Desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + VERIFY_SUCCEEDED( + D3DDevice->CreateDescriptorHeap(&Desc, IID_PPV_ARGS(&DescriptorHeap))); + } + CD3DX12_CPU_DESCRIPTOR_HANDLE BaseHandle( + DescriptorHeap->GetCPUDescriptorHandleForHeapStart()); + + // Create the compute pipeline state for the CoopVec shader + CComPtr ComputePipelineState; + { + std::string ShaderSource = R"( +#include "dx/linalg.h" + +ByteAddressBuffer InputVector : register(t0); +ByteAddressBuffer InputBias : register(t1); +ByteAddressBuffer InputMatrix : register(t2); +RWByteAddressBuffer OutputBuffer: register(u0); + +[shader("compute")] +[numthreads(NUM_THREADS, 1, 1)] +void main(uint threadIdx : SV_GroupThreadID) +{ + using namespace dx::linalg; + + // Ensure 4-byte alignment for vector loads + uint inputOffset = (INPUT_PER_THREAD * threadIdx * (sizeof(INPUT_DATA_TYPE) / INPUT_DIVISOR)); + inputOffset = (inputOffset + 3) & ~3; // Align to 4 bytes + vector input = InputVector.Load >(inputOffset); + + MatrixRef mat = { InputMatrix, 0, STRIDE }; + + vector accum; + + if (USE_BIAS) { + VectorRef biasVec = { InputBias, 0 }; + accum = MulAdd(mat, MakeInterpretedVector(input), biasVec); + 
} else { + accum = Mul(mat, MakeInterpretedVector(input)); + } + + vector result = (vector)accum; + + // Ensure 4-byte alignment for vector store + uint outputOffset = OUTPUT_PER_THREAD * threadIdx * sizeof(float); + outputOffset = (outputOffset + 3) & ~3; // Align to 4 bytes + OutputBuffer.Store >(outputOffset, result); +} + )"; + + auto CreateDefineFromInt = [](const wchar_t *Name, int Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; + + auto CreateDefineFromString = [](const wchar_t *Name, + const std::wstring &Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; + + int Stride = 0; + const std::wstring HlslMatrixLayout = + CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout); + int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType( + MulProps.MatrixInterpretation); + switch (Config.MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + Stride = Config.InputPerThread * StrideMultiplier; + break; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + Stride = Config.OutputPerThread * StrideMultiplier; + break; + } + + const int InputDivisor = + CoopVecHelpers::GetNumPackedElementsForInputDataType( + MulProps.InputInterpretation); + const std::wstring InputDataType = + CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.InputType); + const std::wstring AccumDataType = + CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.BiasInterpretation); + const std::wstring MatrixDataTypeEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + MulProps.MatrixInterpretation); + const std::wstring InputInterpretationEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + MulProps.InputInterpretation); + const std::wstring AccumInterpretationEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + MulProps.BiasInterpretation); + + auto InputPerThreadDefine = + 
CreateDefineFromInt(L"INPUT_PER_THREAD", Config.InputPerThread); + auto OutputPerThreadDefine = + CreateDefineFromInt(L"OUTPUT_PER_THREAD", Config.OutputPerThread); + auto NumThreadsDefine = + CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads); + auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride); + auto InputDataTypeDefine = + CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType); + auto InputDivisorDefine = + CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor); + auto AccumDataTypeDefine = + CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType); + auto InputInterpretationEnumDefine = CreateDefineFromString( + L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum); + auto HlslMatrixLayoutDefine = + CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout); + auto MatrixDataTypeEnumDefine = + CreateDefineFromString(L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum); + auto UseBiasDefine = CreateDefineFromInt(L"USE_BIAS", Config.Bias ? 1 : 0); + auto AccumInterpretationEnumDefine = CreateDefineFromString( + L"ACCUM_INTERPRETATION_ENUM", AccumInterpretationEnum); + + LPCWSTR Options[] = { + L"-enable-16bit-types", + InputPerThreadDefine.c_str(), + OutputPerThreadDefine.c_str(), + NumThreadsDefine.c_str(), + StrideDefine.c_str(), + InputDataTypeDefine.c_str(), + InputDivisorDefine.c_str(), + AccumDataTypeDefine.c_str(), + InputInterpretationEnumDefine.c_str(), + HlslMatrixLayoutDefine.c_str(), + MatrixDataTypeEnumDefine.c_str(), + UseBiasDefine.c_str(), + AccumInterpretationEnumDefine.c_str(), + }; + + CComPtr IncludeHandler = + new LinAlgHeaderIncludeHandler(m_support); + + CreateComputePSO(D3DDevice, RootSignature, ShaderSource.c_str(), L"cs_6_9", + &ComputePipelineState, Options, _countof(Options), + IncludeHandler); + } + + // Create a command list for the compute shader. 
+ CComPtr CommandList; + CComPtr CommandAllocator; + CComPtr CommandQueue; + FenceObj FO; + CreateCommandQueue(D3DDevice, L"CoopVec Test Command Queue", &CommandQueue, + D3D12_COMMAND_LIST_TYPE_DIRECT); + InitFenceObj(D3DDevice, &FO); + VERIFY_SUCCEEDED(D3DDevice->CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&CommandAllocator))); + VERIFY_SUCCEEDED(D3DDevice->CreateCommandList( + 0, D3D12_COMMAND_LIST_TYPE_DIRECT, CommandAllocator, ComputePipelineState, + IID_PPV_ARGS(&CommandList))); + + // Setup input data + auto ExpectedOutputBuffer = + std::make_unique(Config.OutputPerThread * Config.NumThreads); + + // Setup input matrix as all-ones in sint8 format. This will later be + // converted to the appropriate data type by the matrix conversion API. + CComPtr InputMatrixSRVResource, InputMatrixSRVUploadResource; + std::vector InputMatrix; + if (MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED || + MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix( + Config.InputPerThread, Config.OutputPerThread); + } else if (MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + // Matrix source data is fp32, which gets converted to fp16 during matrix + // conversion + InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix( + Config.InputPerThread, Config.OutputPerThread); + } else { + WEX::Logging::Log::Error(L"Unsupported matrix data type"); + return; + } + + CreateTestResources(D3DDevice, CommandList, InputMatrix.data(), + InputMatrix.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputMatrix.size()), + 
&InputMatrixSRVResource, &InputMatrixSRVUploadResource); + + // Create input vector of an appropriate type. All integer types start as + // SINT8 for now. + CComPtr InputVecSRVResource, InputVecSRVUploadResource; + std::vector InputVector; + + if ((MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32 && + (MulProps.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + MulProps.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED)) || + MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + InputVector = CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); + } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + InputVector = + CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); + } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + InputVector = CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); + } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { + InputVector = CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); + } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32) { + InputVector = CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); + } else { + WEX::Logging::Log::Error(L"Unsupported input data type"); + return; + } + if (InputVector.size() % 4 != 0) { + // Align size to 4 bytes for ByteAddressBuffer + InputVector.resize(InputVector.size() + 4 - (InputVector.size() % 4)); + } + CreateTestResources(D3DDevice, CommandList, InputVector.data(), + InputVector.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputVector.size()), + &InputVecSRVResource, &InputVecSRVUploadResource); + + // This increments 
baseHandle + CreateRawSRV(D3DDevice, BaseHandle, + (UINT)(InputVector.size() / sizeof(int32_t)), + InputVecSRVResource); + + // Create input bias + CComPtr InputBiasSRVResource, InputBiasSRVUploadResource; + std::vector InputBias; + + if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED || + MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + InputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { + InputBias = + CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32) { + InputBias = + CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + InputBias = CoopVecHelpers::CreateInputBias( + Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + InputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else { + WEX::Logging::Log::Error(L"Unsupported bias data type"); + return; + } + + if (InputBias.size() % 4 != 0) { + // Align size to 4 bytes for ByteAddressBuffer + InputBias.resize(InputBias.size() + 4 - (InputBias.size() % 4)); + } + CreateTestResources(D3DDevice, CommandList, InputBias.data(), + InputBias.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputBias.size()), + &InputBiasSRVResource, &InputBiasSRVUploadResource); + + // This increments baseHandle + CreateRawSRV(D3DDevice, BaseHandle, + (UINT)(InputBias.size() / sizeof(int32_t)), + InputBiasSRVResource); + + // Calculate reference output + // FIXME: This does not capture all cases, but is sufficient for the preview + // feature set + if (MulProps.MatrixInterpretation == 
D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8) { + // The input bias is really an array of int32_t + std::vector InputBiasI32(InputBias.size() / sizeof(int32_t)); + std::memcpy(InputBiasI32.data(), InputBias.data(), InputBias.size()); + + // The input vector is really an array of float if our vector input type is + // FLOAT32 + std::vector InputVectorF32(InputVector.size() / sizeof(int32_t)); + if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + std::memcpy(InputVectorF32.data(), InputVector.data(), + InputVector.size()); + } + + for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { + for (int OutputIdx = 0; OutputIdx < Config.OutputPerThread; ++OutputIdx) { + int Acc = 0; + + for (int InputIdx = 0; InputIdx < Config.InputPerThread; ++InputIdx) { + int InputElem; + if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + InputElem = (int) + InputVectorF32[ThreadIdx * Config.InputPerThread + InputIdx]; + } else { + InputElem = + InputVector[ThreadIdx * Config.InputPerThread + InputIdx]; + } + int const MatrixElem = + InputMatrix[OutputIdx * Config.InputPerThread + InputIdx]; + Acc += InputElem * MatrixElem; + } + + if (Config.Bias) { + Acc += InputBiasI32[OutputIdx]; + } + + float Result = float(Acc); + ExpectedOutputBuffer[ThreadIdx * Config.OutputPerThread + OutputIdx] = + Result; + } + } + } else if (MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + // The input bias/vector is really an array of float16 + std::vector InputVectorFP16( + InputVector.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(InputVectorFP16.data(), InputVector.data(), InputVector.size()); + + std::vector InputBiasFP16( + InputBias.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(InputBiasFP16.data(), InputBias.data(), InputBias.size()); + + // The 
CPU reference matrix is float + std::vector InputMatrixFP32(InputMatrix.size() / sizeof(float)); + std::memcpy(InputMatrixFP32.data(), InputMatrix.data(), InputMatrix.size()); + + for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { + for (int OutputIdx = 0; OutputIdx < Config.OutputPerThread; ++OutputIdx) { + float Acc = 0; + + for (int InputIdx = 0; InputIdx < Config.InputPerThread; ++InputIdx) { + float const InputElem = ConvertFloat16ToFloat32( + InputVectorFP16[ThreadIdx * Config.InputPerThread + InputIdx]); + float const MatrixElem = + InputMatrixFP32[OutputIdx * Config.InputPerThread + InputIdx]; + Acc += InputElem * MatrixElem; + } + + if (Config.Bias) { + Acc += ConvertFloat16ToFloat32(InputBiasFP16[OutputIdx]); + } + + float Result = Acc; + ExpectedOutputBuffer[ThreadIdx * Config.OutputPerThread + OutputIdx] = + Result; + } + } + } + + CComPtr ConvertedMatrixResource; + { + // Create source matrix info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO ConvertInfo = {}; + ConvertInfo.SrcInfo.SrcDataType = + CoopVecHelpers::GetMatrixSrcDataType(MulProps.MatrixInterpretation); + ConvertInfo.SrcInfo.SrcLayout = + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; + + // Create destination matrix info + ConvertInfo.DestInfo.DestSize = 0; // Will be populated by driver + int SrcEltSize = 0; + int DestEltSize = 0; + switch (MulProps.MatrixInterpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + SrcEltSize = 1; + DestEltSize = 1; + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16; + SrcEltSize = 4; // FP32 + DestEltSize = 2; // FP16 + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + ConvertInfo.DestInfo.DestDataType = + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3; + SrcEltSize = 4; // FP32 + DestEltSize = 1; // FP8 + break; + case 
D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + ConvertInfo.DestInfo.DestDataType = + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2; + SrcEltSize = 4; // FP32 + DestEltSize = 1; // FP8 + break; + } + ConvertInfo.SrcInfo.SrcStride = Config.InputPerThread * SrcEltSize; + ConvertInfo.SrcInfo.SrcSize = + Config.InputPerThread * Config.OutputPerThread * SrcEltSize; + + ConvertInfo.DestInfo.DestLayout = Config.MatrixLayout; + ConvertInfo.DestInfo.DestStride = 0; + ConvertInfo.DestInfo.NumRows = Config.OutputPerThread; + ConvertInfo.DestInfo.NumColumns = Config.InputPerThread; + + if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) { + ConvertInfo.DestInfo.DestStride = Config.InputPerThread * DestEltSize; + } else if (Config.MatrixLayout == + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { + ConvertInfo.DestInfo.DestStride = Config.OutputPerThread * DestEltSize; + } + + // Get destination size using preview interface + { + CComPtr PreviewDevice; + VERIFY_SUCCEEDED(D3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), + (void **)&PreviewDevice)); + + // Query required destination size + PreviewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( + &ConvertInfo.DestInfo); + } + + // Create resource to hold matrix copy + CreateTestResources( + D3DDevice, CommandList, nullptr, 0, + CD3DX12_RESOURCE_DESC::Buffer(ConvertInfo.DestInfo.DestSize), + &ConvertedMatrixResource, nullptr); + + // Set up data descriptors + ConvertInfo.DataDesc.DestVA = + ConvertedMatrixResource->GetGPUVirtualAddress(); + ConvertInfo.DataDesc.SrcVA = InputMatrixSRVResource->GetGPUVirtualAddress(); + + // Get command list interface and perform conversion + CComPtr CommandList11; + VERIFY_SUCCEEDED(CommandList->QueryInterface( + __uuidof(ID3D12GraphicsCommandList11), (void **)&CommandList11)); + CommandList11->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); + + // This increments baseHandle + if ((ConvertInfo.DestInfo.DestSize % 4) != 0) { + WEX::Logging::Log::Error(L"DestSize is not 
aligned to 4 bytes"); + return; + } + CreateRawSRV(D3DDevice, BaseHandle, + ConvertInfo.DestInfo.DestSize / sizeof(int32_t), + ConvertedMatrixResource); + } + + CComPtr UavResource; + CComPtr UavUploadResource; + CComPtr UavReadResource; + + // Create buffer for output and fill with 0xFF to make it obvious if it's not + // written in the shader. + std::vector OutputBufferInit(OutputBufferSize); + std::fill(OutputBufferInit.begin(), OutputBufferInit.end(), (uint8_t)0xFF); + + CreateTestUavs(D3DDevice, CommandList, OutputBufferInit.data(), + OutputBufferSize, &UavResource, &UavUploadResource, + &UavReadResource); + CreateRawUAV(D3DDevice, BaseHandle, OutputBufferSize / 4, UavResource); + + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); + VERIFY_SUCCEEDED(CommandAllocator->Reset()); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, ComputePipelineState)); + + SetDescriptorHeap(CommandList, DescriptorHeap); + + CD3DX12_GPU_DESCRIPTOR_HANDLE ResHandle( + DescriptorHeap->GetGPUDescriptorHandleForHeapStart()); + + CommandList->SetComputeRootSignature(RootSignature); + CommandList->SetComputeRootDescriptorTable(0, ResHandle); + CommandList->SetPipelineState(ComputePipelineState); + CommandList->Dispatch(1, 1, 1); + RecordTransitionBarrier(CommandList, UavResource, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COPY_SOURCE); + CommandList->CopyResource(UavReadResource, UavResource); + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); + + { + MappedData MappedData(UavReadResource, OutputBufferSize); + + float *ResultBuffer = (float *)MappedData.data(); + bool Equal = true; + for (int i = 0; i < OutputBufferSize / sizeof(float); i++) { + if (isnan(ResultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || + fabs(ResultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { + LogErrorFmt(L"Result mismatch at index %d", i); + 
LogErrorFmt(L"ResultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, + ResultBuffer[i], i, ExpectedOutputBuffer[i]); + Equal = false; + break; + } + } + VERIFY_IS_TRUE(Equal); + } +} +#endif // HAVE_COOPVEC_API + +TEST_F(ExecutionTest, CoopVec_Mul) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + runCoopVecMulTest(); +} + +void ExecutionTest::runCoopVecOuterProductTest() { +#if !HAVE_COOPVEC_API + WEX::Logging::Log::Comment( + "Cooperative vector API not supported in build configuration. Skipping."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; +#else + // Create device and verify coopvec support + CComPtr D3DDevice; + if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { + return; + } + if (!DoesDeviceSupportCooperativeVector(D3DDevice)) { + WEX::Logging::Log::Comment( + "Device does not support cooperative vector. Skipping."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + + // Query coopvec feature data. First call gets the size of the arrays. The + // second call populates the arrays using memory we allocate. 
+ D3D12_FEATURE_DATA_COOPERATIVE_VECTOR DevOptions = {}; + VERIFY_SUCCEEDED(D3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &DevOptions, + sizeof(DevOptions))); + + // Allocate memory for the arrays in DevOptions + std::vector AccumulateProps( + DevOptions.OuterProductAccumulatePropCount); + DevOptions.pOuterProductAccumulateProperties = AccumulateProps.data(); + + VERIFY_SUCCEEDED(D3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &DevOptions, + sizeof(DevOptions))); + + // Test each supported data type and matrix layout + for (auto AccumulateConfig : AccumulateProps) { + // Run the test + runCoopVecOuterProductTestConfig(D3DDevice, AccumulateConfig); + } +#endif // HAVE_COOPVEC_API +} + +#if HAVE_COOPVEC_API +void ExecutionTest::runCoopVecOuterProductTestConfig( + ID3D12Device *D3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps) { + LogCommentFmt( + L"Running test for InputType: %s, AccumulationType: %s", + CoopVecHelpers::DataTypeToFilterString(AccumulateProps.InputType).c_str(), + CoopVecHelpers::DataTypeToFilterString(AccumulateProps.AccumulationType) + .c_str()); + + constexpr CoopVecOuterProductSubtestConfig TestConfigs[] = { + {4, 4, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL}, + }; + + for (auto Config : TestConfigs) { + if ((AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) && + (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR || + Config.MatrixLayout == + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR)) { + continue; + } + + runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config); + } +} + +void ExecutionTest::runCoopVecOuterProductSubtest( + ID3D12Device *D3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, + CoopVecOuterProductSubtestConfig &Config) { + + LogCommentFmt( + 
L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s", + Config.DimM, Config.DimN, Config.NumThreads, + CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str()); + + // Create root signature with a single root entry for all SRVs and UAVs + CComPtr RootSignature; + { + CD3DX12_DESCRIPTOR_RANGE ranges[2]; + ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0, + 0); // InputVector1, InputVector2 + ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // AccumMatrix + CreateRootSignatureFromRanges(D3DDevice, &RootSignature, ranges, 2, nullptr, + 0); + } + + // Create descriptor heap with space for 3 descriptors: 2 SRVs and 1 UAV + CComPtr DescriptorHeap; + { + D3D12_DESCRIPTOR_HEAP_DESC Desc = {}; + Desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + Desc.NumDescriptors = 3; + Desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + VERIFY_SUCCEEDED( + D3DDevice->CreateDescriptorHeap(&Desc, IID_PPV_ARGS(&DescriptorHeap))); + } + CD3DX12_CPU_DESCRIPTOR_HANDLE BaseHandle( + DescriptorHeap->GetCPUDescriptorHandleForHeapStart()); + + // Create a compute pipeline state object. 
+ CComPtr ComputePipelineState; + { + std::string ShaderSource = R"( +#include "dx/linalg.h" + +ByteAddressBuffer InputVector1 : register(t0); +ByteAddressBuffer InputVector2 : register(t1); +RWByteAddressBuffer AccumMatrix : register(u0); + +[shader("compute")] +[numthreads(NUM_THREADS, 1, 1)] +void main(uint threadIdx : SV_GroupThreadID) +{ +#if 1 + using namespace dx::linalg; + + // Ensure 4-byte alignment for vector loads + uint inputOffset1 = (DIM_M * threadIdx * sizeof(INPUT_DATA_TYPE)); + inputOffset1 = (inputOffset1 + 3) & ~3; // Align to 4 bytes + vector input1 = InputVector1.Load >(inputOffset1); + + uint inputOffset2 = (DIM_N * threadIdx * sizeof(INPUT_DATA_TYPE)); + inputOffset2 = (inputOffset2 + 3) & ~3; // Align to 4 bytes + vector input2 = InputVector2.Load >(inputOffset2); + + RWMatrixRef mat = { AccumMatrix, 0, STRIDE }; + + OuterProductAccumulate(input1, input2, mat); +#endif +} + )"; + + auto CreateDefineFromInt = [](const wchar_t *Name, int Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; + + auto CreateDefineFromString = [](const wchar_t *Name, + const wchar_t *Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; + + int Stride = 0; + const std::wstring HlslMatrixLayout = + CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout); + int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType( + AccumulateProps.AccumulationType); + switch (Config.MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + Stride = Config.DimN * StrideMultiplier; + break; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + Stride = Config.DimM * StrideMultiplier; + break; + } + + const int InputDivisor = + CoopVecHelpers::GetNumPackedElementsForInputDataType( + AccumulateProps.InputType); + const std::wstring InputDataType = + CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType); + const 
std::wstring AccumDataType = + CoopVecHelpers::GetHlslDataTypeForDataType( + AccumulateProps.AccumulationType); + const std::wstring MatrixDataTypeEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + AccumulateProps.AccumulationType); + const std::wstring InputInterpretationEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + AccumulateProps.InputType); + + auto DimMDefine = CreateDefineFromInt(L"DIM_M", Config.DimM); + auto DimNDefine = CreateDefineFromInt(L"DIM_N", Config.DimN); + auto NumThreadsDefine = + CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads); + auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride); + auto InputDataTypeDefine = + CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str()); + auto InputDivisorDefine = + CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor); + auto AccumDataTypeDefine = + CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str()); + auto InputInterpretationEnumDefine = CreateDefineFromString( + L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str()); + auto HlslMatrixLayoutDefine = + CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str()); + auto MatrixDataTypeEnumDefine = CreateDefineFromString( + L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str()); + + LPCWSTR Options[] = { + L"-enable-16bit-types", + DimMDefine.c_str(), + DimNDefine.c_str(), + NumThreadsDefine.c_str(), + StrideDefine.c_str(), + InputDataTypeDefine.c_str(), + InputDivisorDefine.c_str(), + AccumDataTypeDefine.c_str(), + InputInterpretationEnumDefine.c_str(), + HlslMatrixLayoutDefine.c_str(), + MatrixDataTypeEnumDefine.c_str(), + }; + + CComPtr IncludeHandler = + new LinAlgHeaderIncludeHandler(m_support); + + CreateComputePSO(D3DDevice, RootSignature, ShaderSource.c_str(), L"cs_6_9", + &ComputePipelineState, Options, _countof(Options), + IncludeHandler); + } + + // Create a command list for the compute shader. 
+ CComPtr CommandList; + CComPtr CommandAllocator; + CComPtr CommandQueue; + FenceObj FO; + CreateCommandQueue(D3DDevice, L"CoopVec Test Command Queue", &CommandQueue, + D3D12_COMMAND_LIST_TYPE_DIRECT); + InitFenceObj(D3DDevice, &FO); + VERIFY_SUCCEEDED(D3DDevice->CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&CommandAllocator))); + VERIFY_SUCCEEDED(D3DDevice->CreateCommandList( + 0, D3D12_COMMAND_LIST_TYPE_DIRECT, CommandAllocator, ComputePipelineState, + IID_PPV_ARGS(&CommandList))); + + // Setup input matrix as all-ones in sint8/fp32 format. This will later be + // converted to the appropriate data type by the matrix conversion API. + CComPtr InputMatrixSRVResource, InputMatrixSRVUploadResource; + std::vector InputMatrix; + if (AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, + Config.DimM); + } else if (AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + // Matrix source data is fp32, which gets converted to fp16 during matrix + // conversion + InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, + Config.DimM); + } else { + WEX::Logging::Log::Error(L"Unsupported matrix data type"); + return; + } + + CreateTestResources(D3DDevice, CommandList, InputMatrix.data(), + InputMatrix.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputMatrix.size()), + &InputMatrixSRVResource, &InputMatrixSRVUploadResource); + + // Create input vectors + CComPtr InputVecSRVResource1, InputVecSRVUploadResource1; + std::vector InputVector1; + CComPtr InputVecSRVResource2, InputVecSRVUploadResource2; + std::vector InputVector2; + + if (AccumulateProps.InputType == 
D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + InputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + Config.DimM); + InputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + Config.DimN); + } else if (AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + InputVector1 = + CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.DimM); + InputVector2 = + CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.DimN); + } else if (AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + InputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + Config.DimM); + InputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + Config.DimN); + } else { + WEX::Logging::Log::Error(L"Unsupported input data type"); + return; + } + if (InputVector1.size() % 4 != 0) { + // Align size to 4 bytes for ByteAddressBuffer + InputVector1.resize(InputVector1.size() + 4 - (InputVector1.size() % 4)); + } + if (InputVector2.size() % 4 != 0) { + // Align size to 4 bytes for ByteAddressBuffer + InputVector2.resize(InputVector2.size() + 4 - (InputVector2.size() % 4)); + } + CreateTestResources(D3DDevice, CommandList, InputVector1.data(), + InputVector1.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputVector1.size()), + &InputVecSRVResource1, &InputVecSRVUploadResource1); + CreateTestResources(D3DDevice, CommandList, InputVector2.data(), + InputVector2.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputVector2.size()), + &InputVecSRVResource2, &InputVecSRVUploadResource2); + + // This increments baseHandle + CreateRawSRV(D3DDevice, BaseHandle, + (UINT)(InputVector1.size() / sizeof(int32_t)), + InputVecSRVResource1); + CreateRawSRV(D3DDevice, BaseHandle, + (UINT)(InputVector2.size() / 
sizeof(int32_t)), + InputVecSRVResource2); + + // Calculate reference output + auto ExpectedOutputBufferI8 = + CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, Config.DimM); + std::vector ExpectedOutputBuffer(ExpectedOutputBufferI8.size() / + sizeof(float)); + std::memcpy(ExpectedOutputBuffer.data(), ExpectedOutputBufferI8.data(), + ExpectedOutputBufferI8.size()); + + if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + std::vector InputVector1FP16( + InputVector1.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(InputVector1FP16.data(), InputVector1.data(), + InputVector1.size()); + + std::vector InputVector2FP16( + InputVector2.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(InputVector2FP16.data(), InputVector2.data(), + InputVector2.size()); + + for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { + for (int M = 0; M < Config.DimM; ++M) { + for (int N = 0; N < Config.DimN; ++N) { + float acc = ConvertFloat16ToFloat32(InputVector1FP16[M]) * + ConvertFloat16ToFloat32(InputVector2FP16[N]); + ExpectedOutputBuffer[M * Config.DimN + N] += acc; + } + } + } + } else if (AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + std::vector InputVector1FP32(InputVector1.size() / sizeof(float)); + std::memcpy(InputVector1FP32.data(), InputVector1.data(), + InputVector1.size()); + + std::vector InputVector2FP32(InputVector2.size() / sizeof(float)); + std::memcpy(InputVector2FP32.data(), InputVector2.data(), + InputVector2.size()); + + for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { + for (int M = 0; M < Config.DimM; ++M) { + for (int N = 0; N < Config.DimN; ++N) { + float Acc = InputVector1FP32[ThreadIdx * Config.DimM + M] * + InputVector2FP32[ThreadIdx * Config.DimN + N]; + ExpectedOutputBuffer[M * Config.DimN + N] += Acc; + } + } + } + } + + CComPtr ConvertedMatrixResource, ConvertedMatrixReadResource; + int ConvertedMatrixSize = 0; + { + // Create source matrix 
info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO SrcInfo = {}; + SrcInfo.SrcDataType = + CoopVecHelpers::GetMatrixSrcDataType(AccumulateProps.AccumulationType); + SrcInfo.SrcLayout = D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; + + // Create destination matrix info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO DestInfo = {}; + DestInfo.DestSize = 0; // Will be populated by driver + int SrcEltSize = 0; + int DestEltSize = 0; + switch (AccumulateProps.AccumulationType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + SrcEltSize = 1; + DestEltSize = 1; + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16; + SrcEltSize = 4; // FP32 + DestEltSize = 2; // FP16 + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3; + SrcEltSize = 4; // FP32 + DestEltSize = 1; // FP8 + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2; + SrcEltSize = 4; // FP32 + DestEltSize = 1; // FP8 + break; + } + SrcInfo.SrcStride = Config.DimM * SrcEltSize; + SrcInfo.SrcSize = Config.DimM * Config.DimN * SrcEltSize; + + DestInfo.DestLayout = Config.MatrixLayout; + DestInfo.DestStride = 0; + DestInfo.NumRows = Config.DimM; + DestInfo.NumColumns = Config.DimN; + + if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) { + DestInfo.DestStride = Config.DimM * DestEltSize; + } else if (Config.MatrixLayout == + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { + DestInfo.DestStride = Config.DimM * DestEltSize; + } + + // Create conversion info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO ConvertInfo = {}; + ConvertInfo.SrcInfo = SrcInfo; + ConvertInfo.DestInfo = DestInfo; + + // Get preview device interface + { + CComPtr PreviewDevice; + 
VERIFY_SUCCEEDED(D3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), + (void **)&PreviewDevice)); + + // Query required destination size + PreviewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( + &ConvertInfo.DestInfo); + } + + ConvertedMatrixSize = ConvertInfo.DestInfo.DestSize; + + // Hack to prevent read resource from being created with size 0 + std::vector TempData(ConvertInfo.DestInfo.DestSize); + CreateTestUavs(D3DDevice, CommandList, TempData.data(), TempData.size(), + &ConvertedMatrixResource, nullptr, + &ConvertedMatrixReadResource); + + // Set up data descriptors + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA DataDesc = {}; + DataDesc.DestVA = ConvertedMatrixResource->GetGPUVirtualAddress(); + DataDesc.SrcVA = InputMatrixSRVResource->GetGPUVirtualAddress(); + ConvertInfo.DataDesc = DataDesc; + + // Get command list interface and perform conversion + CComPtr CommandList11; + VERIFY_SUCCEEDED(CommandList->QueryInterface( + __uuidof(ID3D12GraphicsCommandList11), (void **)&CommandList11)); + CommandList11->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); + + // This increments baseHandle + if ((ConvertInfo.DestInfo.DestSize % 4) != 0) { + WEX::Logging::Log::Error(L"DestSize is not aligned to 4 bytes"); + return; + } + CreateRawUAV(D3DDevice, BaseHandle, + ConvertInfo.DestInfo.DestSize / sizeof(int32_t), + ConvertedMatrixResource); + } + + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); + VERIFY_SUCCEEDED(CommandAllocator->Reset()); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, ComputePipelineState)); + + SetDescriptorHeap(CommandList, DescriptorHeap); + + CD3DX12_GPU_DESCRIPTOR_HANDLE ResHandle( + DescriptorHeap->GetGPUDescriptorHandleForHeapStart()); + + CommandList->SetComputeRootSignature(RootSignature); + CommandList->SetComputeRootDescriptorTable(0, ResHandle); + CommandList->SetPipelineState(ComputePipelineState); + CommandList->Dispatch(1, 1, 1); + CommandList->Close(); + 
ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); + + VERIFY_SUCCEEDED(CommandAllocator->Reset()); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, ComputePipelineState)); + + // Convert matrix to sint8/fp32 row-major format before reading back to the + // CPU. A new resource is created, along with a readback resource, for the + // matrix copy. + CComPtr MatrixRowMajorResource, MatrixRowMajorReadResource; + { + // Create source matrix info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO ConvertInfo = {}; + ConvertInfo.SrcInfo.SrcLayout = Config.MatrixLayout; + ConvertInfo.SrcInfo.SrcSize = ConvertedMatrixSize; + ConvertInfo.SrcInfo.SrcDataType = AccumulateProps.AccumulationType; + ConvertInfo.SrcInfo.SrcStride = 0; // OUTER_PRODUCT_OPTIMAL + + // Create destination matrix info + ConvertInfo.DestInfo.DestSize = 0; // Will be populated by driver + ConvertInfo.DestInfo.DestLayout = + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; + ConvertInfo.DestInfo.NumRows = Config.DimM; + ConvertInfo.DestInfo.NumColumns = Config.DimN; + + if (AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; + ConvertInfo.DestInfo.DestStride = Config.DimN * sizeof(float); + } else { + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + ConvertInfo.DestInfo.DestStride = Config.DimN * sizeof(int8_t); + } + + // Get destination size using preview interface + { + CComPtr PreviewDevice; + VERIFY_SUCCEEDED(D3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), + (void **)&PreviewDevice)); + + // Query required destination size + PreviewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( + 
&ConvertInfo.DestInfo); + } + + // Create resource to hold matrix copy and a readback resource for it + // Init vector is a hack to prevent read resource from being created with + // size 0 + // TODO: Fix CreateTestUavs to allow creating readback resource without init + // data + std::vector TempData(ConvertInfo.DestInfo.DestSize); + CreateTestUavs(D3DDevice, CommandList, TempData.data(), TempData.size(), + &MatrixRowMajorResource, nullptr, + &MatrixRowMajorReadResource); + + // Set up data descriptors + ConvertInfo.DataDesc.DestVA = + MatrixRowMajorResource->GetGPUVirtualAddress(); + ConvertInfo.DataDesc.SrcVA = + ConvertedMatrixResource->GetGPUVirtualAddress(); + + // Get command list interface and perform conversion + CComPtr CommandList11; + VERIFY_SUCCEEDED(CommandList->QueryInterface( + __uuidof(ID3D12GraphicsCommandList11), (void **)&CommandList11)); + CommandList11->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); + } + + RecordTransitionBarrier(CommandList, MatrixRowMajorResource, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COPY_SOURCE); + CommandList->CopyResource(MatrixRowMajorReadResource, MatrixRowMajorResource); + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); + + { + MappedData MappedData(MatrixRowMajorReadResource, (UINT)InputMatrix.size()); + + float *ResultBuffer = (float *)MappedData.data(); + bool Equal = true; + for (int i = 0; i < (UINT)InputMatrix.size() / sizeof(float); i++) { + if (isnan(ResultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || + fabs(ResultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { + LogErrorFmt(L"Result mismatch at index %d", i); + LogErrorFmt(L"ResultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, + ResultBuffer[i], i, ExpectedOutputBuffer[i]); + Equal = false; + break; + } + } + VERIFY_IS_TRUE(Equal); + } +} +#endif // HAVE_COOPVEC_API + +TEST_F(ExecutionTest, CoopVec_OuterProduct) { + WEX::TestExecution::SetVerifyOutput verifySettings( + 
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + runCoopVecOuterProductTest(); +} + // This test expects a that retrieves a signal value from each of a // few resources that are initialized here. determines if it uses // the 6.6 Dynamic Resources feature. Values are read back from the result UAV From c6fce45d1160bf289244b8523ccc46822a620869 Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Thu, 8 May 2025 10:08:37 +0200 Subject: [PATCH 22/31] Fixed SERShaderTableIndexTest (not relying on is/ah any longer) --- .../unittests/HLSLExec/ExecutionTest_SER.h | 72 +++++++++---------- 1 file changed, 32 insertions(+), 40 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index 18a4397e0d..130eeea744 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -949,15 +949,23 @@ void raygen() PerRayData payload; payload.visited = 0; - // SER Test - dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); + dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES, 0xFF, 0, 1, 0, ray, payload); dx::MaybeReorderThread(hitObject); + // Invoke hit/miss for triangle + dx::HitObject::Invoke( hitObject, payload ); + if (hitObject.IsHit()) { - // Alter the hit object to point to a new shader index to hit chAABB. + // Transform to an 'aabb' hit. hitObject.SetShaderTableIndex( 1 ); - dx::HitObject::Invoke( hitObject, payload ); + } + + // Invoke hit/miss for aabb + dx::HitObject::Invoke( hitObject, payload ); + + if (hitObject.IsHit()) + { // Poison the test data if GetShaderTableIndex does not match SetShaderTableIndex. 
if (hitObject.GetShaderTableIndex() != 1) payload.visited = 12345; @@ -970,60 +978,44 @@ void raygen() [shader("miss")] void miss(inout PerRayData payload) { - payload.visited |= 1U; + if ((payload.visited & 4U) == 0) + payload.visited |= 4U; // First 'miss' invocation + else + payload.visited |= 8U; // Second 'miss' invocation } // Triangles [shader("anyhit")] void anyhit(inout PerRayData payload, in Attrs attrs) { - payload.visited |= 2U; AcceptHitAndEndSearch(); } +// Triangle closest hit [shader("closesthit")] void closesthit(inout PerRayData payload, in Attrs attrs) { - payload.visited |= 4U; + payload.visited |= 1U; +} + +// AABB closest hit +[shader("closesthit")] +void chAABB(inout PerRayData payload, in Attrs attrs) +{ + payload.visited |= 2U; } // Procedural [shader("intersection")] void intersection() { - // Intersection with circle on a plane (base, n, radius) - // hitPos is intersection point with plane (base, n) - float3 base = {0.0f,0.0f,0.5f}; - float3 n = normalize(float3(0.2f,0.2f,0.5f)); - float radius = 150.f; - // Plane hit - float t = dot(n, base - ObjectRayOrigin()) / dot(n, ObjectRayDirection()); - if (t > RayTCurrent() || t < RayTMin()) { - return; - } - float3 hitPos = ObjectRayOrigin() + t * ObjectRayDirection(); - float3 relHitPos = hitPos - base; - // Circle hit - float hitDist = length(relHitPos); - if (hitDist > radius) - return; - - CustomAttrs attrs; - attrs.dist = hitDist; - ReportHit(t, 1, attrs); + // UNUSED } [shader("anyhit")] void ahAABB(inout PerRayData payload, in CustomAttrs attrs) { - payload.visited |= 8U; - IgnoreHit(); -} - -[shader("closesthit")] -void chAABB(inout PerRayData payload, in Attrs attrs) -{ - payload.visited |= 16U; + // UNUSED } )"; @@ -1044,12 +1036,12 @@ void chAABB(inout PerRayData payload, in Attrs attrs) std::map Histo; for (int Val : TestData) ++Histo[Val]; - VERIFY_ARE_EQUAL(Histo.size(), 3); - VERIFY_ARE_EQUAL(Histo[0], 3696); // Miss (not Invoked) - VERIFY_ARE_EQUAL(Histo[8], 334); // AABB 
ignored hit -> (Miss not Invoked) + + VERIFY_ARE_EQUAL(Histo.size(), 2); VERIFY_ARE_EQUAL( - Histo[26], - 66); // AABB ignored hit + TriHit -> setSBT(1) -> chAABB invoked + Histo[3], + 66); // 'closesthit' invoked at index 0, then 'chAABB' invoked at index 1 + VERIFY_ARE_EQUAL(Histo[12], 4030); // Miss shader invoked twice } TEST_F(ExecutionTest, SERLoadLocalRootTableConstantTest) { From c3c399f308197d61bf4ee8f31c19bae63d381d5f Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Thu, 8 May 2025 10:15:43 +0200 Subject: [PATCH 23/31] Turn ShaderTable::Init into a ctor with initializer list --- tools/clang/unittests/HLSLExec/DXRUtil.h | 22 +++++++++---------- .../unittests/HLSLExec/ExecutionTest.cpp | 14 ++++++------ 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/DXRUtil.h b/tools/clang/unittests/HLSLExec/DXRUtil.h index 54828f4857..14bbf5bf1b 100644 --- a/tools/clang/unittests/HLSLExec/DXRUtil.h +++ b/tools/clang/unittests/HLSLExec/DXRUtil.h @@ -42,18 +42,16 @@ struct Instance { class ShaderTable { public: - void Init(ID3D12Device *Device, int RaygenCount, int MissCount, - int HitGroupCount, int RayTypeCount, int RootTableDwords) { - RayTypeCount = RayTypeCount; - RaygenCount = RaygenCount; - MissCount = MissCount * RayTypeCount; - HitGroupCount = HitGroupCount * RayTypeCount; - RootTableSizeInBytes = RootTableDwords * 4; - ShaderRecordSizeInBytes = - ROUND_UP(RootTableSizeInBytes + SHADER_ID_SIZE_IN_BYTES, - D3D12_RAYTRACING_SHADER_RECORD_BYTE_ALIGNMENT); - MissStartIdx = RaygenCount; - HitGroupStartIdx = MissStartIdx + MissCount; + ShaderTable(ID3D12Device *Device, int RaygenCount, int MissCount, + int HitGroupCount, int RayTypeCount, int RootTableDwords) + : RayTypeCount(RayTypeCount), RaygenCount(RaygenCount), + MissCount(MissCount * RayTypeCount), + HitGroupCount(HitGroupCount * RayTypeCount), + RootTableSizeInBytes(RootTableDwords * 4), + ShaderRecordSizeInBytes( + ROUND_UP(RootTableSizeInBytes + 
SHADER_ID_SIZE_IN_BYTES, + D3D12_RAYTRACING_SHADER_RECORD_BYTE_ALIGNMENT)), + MissStartIdx(RaygenCount), HitGroupStartIdx(MissStartIdx + MissCount) { const int TotalSizeInBytes = (RaygenCount + MissCount + HitGroupCount) * ShaderRecordSizeInBytes; diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 9be631fd83..02652deef0 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -2499,13 +2499,13 @@ CComPtr ExecutionTest::RunDXRTest( VERIFY_SUCCEEDED(StateObject->QueryInterface(&StateObjectProperties)); // Create SBT - ShaderTable ShaderTable; - ShaderTable.Init(Device, - 1, // raygen count - 1, // miss count - UseMesh && UseProceduralGeometry ? 2 : 1, // hit group count - 1, // ray type count - 4 // dwords per root table + ShaderTable ShaderTable( + Device, + 1, // raygen count + 1, // miss count + UseMesh && UseProceduralGeometry ? 2 : 1, // hit group count + 1, // ray type count + 4 // dwords per root table ); int LocalRootConsts[4] = {12, 34, 56, 78}; From 15140b12156312f5bb35430230a41de2f10147ee Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Thu, 8 May 2025 10:28:53 +0200 Subject: [PATCH 24/31] Fix GetAttributesTest: procedural not contained in AABB / use integer arith for hitKind computation --- .../unittests/HLSLExec/ExecutionTest_SER.h | 53 ++++++++++++------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index 130eeea744..e3f5758d78 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -1509,7 +1509,7 @@ void raygen() } else { - // Use 255 to keep outside the HitKind range [0, 127] we passthru for hits. + // Use 255 to keep outside the HitKind range [0,15] we passthru for hits. 
testVal = 255; } int id = launchIndex.x + launchIndex.y * launchDim.x; @@ -1537,28 +1537,35 @@ void closesthit(inout PerRayData payload, in CustomAttrs attrs) [shader("intersection")] void intersection() { + // Intersection with circle on a plane (base, n, radius) // hitPos is intersection point with plane (base, n) float3 base = {0.0f,0.0f,0.5f}; float3 n = normalize(float3(0.0f,0.5f,0.5f)); + float radius = 500.f; + // Plane hit float t = dot(n, base - ObjectRayOrigin()) / dot(n, ObjectRayDirection()); - if (t > RayTCurrent() || t < RayTMin()) { + if (t > RayTCurrent() || t < RayTMin()) return; - } float3 hitPos = ObjectRayOrigin() + t * ObjectRayDirection(); float3 relHitPos = hitPos - base; - // Encode some hit information in hitKind - int hitKind = 0; - if (relHitPos.y >= 0.0f) - hitKind = 1; - hitKind *= 2; - if (relHitPos.x >= 0.0f) - hitKind += 1; - hitKind *= 2; - if (relHitPos.z >= 0.0f) - hitKind += 1; + // Circle hit + float hitDist = length(relHitPos); + if (hitDist > radius) + return; CustomAttrs attrs; - attrs.dist = length(relHitPos); + attrs.dist = hitDist; + + // Generate wave-incoherent hitKind + uint2 launchIndex = DispatchRaysIndex().xy; + uint hitKind = 1U; + if (launchIndex.x >= 32) + hitKind |= 2U; + if (launchIndex.y >= 32) + hitKind |= 4U; + if ((launchIndex.x + launchIndex.y) % 2 == 0) + hitKind |= 8U; + ReportHit(t, hitKind, attrs); } @@ -1580,12 +1587,18 @@ void intersection() std::map Histo; for (int Val : TestData) ++Histo[Val]; - VERIFY_ARE_EQUAL(Histo.size(), 5); - VERIFY_ARE_EQUAL(Histo[0], 2009); - VERIFY_ARE_EQUAL(Histo[1], 561); - VERIFY_ARE_EQUAL(Histo[3], 587); - VERIFY_ARE_EQUAL(Histo[4], 454); - VERIFY_ARE_EQUAL(Histo[6], 485); + + VERIFY_ARE_EQUAL(Histo.size(), 10); + VERIFY_ARE_EQUAL(Histo[0], 1587); + VERIFY_ARE_EQUAL(Histo[1], 277); + VERIFY_ARE_EQUAL(Histo[3], 256); + VERIFY_ARE_EQUAL(Histo[5], 167); + VERIFY_ARE_EQUAL(Histo[7], 153); + VERIFY_ARE_EQUAL(Histo[9], 249); + VERIFY_ARE_EQUAL(Histo[11], 260); + 
VERIFY_ARE_EQUAL(Histo[13], 158); + VERIFY_ARE_EQUAL(Histo[15], 142); + VERIFY_ARE_EQUAL(Histo[255], 847); } TEST_F(ExecutionTest, SERTraceHitMissNopTest) { From 6e42757a4e182e17e5a62f12c5e4577a7c0ef16c Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Thu, 8 May 2025 10:50:36 +0200 Subject: [PATCH 25/31] Fix SERWaveIncoherentHitTest: Use ray_flags to be independent of aabb/tri hit order --- .../unittests/HLSLExec/ExecutionTest_SER.h | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index e3f5758d78..30b6547730 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -2303,17 +2303,20 @@ void raygen() dx::HitObject hitObject; + int cat = (launchIndex.x + launchIndex.y) % 4; + // Use wave incoherence to decide how to create the HitObject - if (launchIndex.x % 4 == 1) + if (cat == 1) { - ray.Origin.x += 2.0f; - hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_SKIP_CLOSEST_HIT_SHADER, 0xFF, 0, 0, 0, ray, payload); + // Turn this into an expected miss by moving eye behind triangles + ray.Origin.z -= 1000.0f; + hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES, 0xFF, 0, 0, 0, ray, payload); } - else if (launchIndex.x % 4 == 2) + else if (cat == 2) { hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES, 0xFF, 0, 0, 0, ray, payload); } - else if (launchIndex.x % 4 == 3) + else if (cat == 3) { hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_SKIP_TRIANGLES, 0xFF, 0, 0, 0, ray, payload); } @@ -2376,7 +2379,7 @@ void intersection() // hitPos is intersection point with plane (base, n) float3 base = {0.0f,0.0f,0.5f}; float3 n = normalize(float3(0.0f,0.5f,0.5f)); - float radius = 1000.f; + float radius = 500.f; // Plane hit float t = dot(n, base - ObjectRayOrigin()) / dot(n, ObjectRayDirection()); 
if (t > RayTCurrent() || t < RayTMin()) { @@ -2412,13 +2415,12 @@ void intersection() std::map Histo; for (int Val : TestData) ++Histo[Val]; - VERIFY_ARE_EQUAL(Histo.size(), 6); - VERIFY_ARE_EQUAL(Histo[1], 1024); // nop - VERIFY_ARE_EQUAL(Histo[2], 1022); // miss - VERIFY_ARE_EQUAL(Histo[4], 12); // triangle hit, no ch - VERIFY_ARE_EQUAL(Histo[8], 1008); // procedural hit, no ch - VERIFY_ARE_EQUAL(Histo[20], 11); // triangle hit, 'closesthit' invoked - VERIFY_ARE_EQUAL(Histo[40], 1019); // procedural hit, 'chAABB' invoked + + VERIFY_ARE_EQUAL(Histo.size(), 4); + VERIFY_ARE_EQUAL(Histo[1], 1024); // nop + VERIFY_ARE_EQUAL(Histo[2], 2243); // miss + VERIFY_ARE_EQUAL(Histo[20], 16); // triangle hit + VERIFY_ARE_EQUAL(Histo[40], 813); // procedural hit } TEST_F(ExecutionTest, SERReorderCoherentTest) { From 0aee7ff984dbf0ad6cddd6084fe25eb047a75571 Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Thu, 8 May 2025 10:54:28 +0200 Subject: [PATCH 26/31] nfc: formatting --- tools/clang/unittests/HLSLExec/ExecutionTest_SER.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index 30b6547730..8bffa410d5 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -2419,7 +2419,7 @@ void intersection() VERIFY_ARE_EQUAL(Histo.size(), 4); VERIFY_ARE_EQUAL(Histo[1], 1024); // nop VERIFY_ARE_EQUAL(Histo[2], 2243); // miss - VERIFY_ARE_EQUAL(Histo[20], 16); // triangle hit + VERIFY_ARE_EQUAL(Histo[20], 16); // triangle hit VERIFY_ARE_EQUAL(Histo[40], 813); // procedural hit } From 57a36b3c4d2baec0208cc58189cb9e4e1f4041ac Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Wed, 27 Aug 2025 07:54:20 +0200 Subject: [PATCH 27/31] [SER] Execution test update Issue: https://github.com/microsoft/hlsl-specs/issues/613 Changes: - Added permutation testing for getters and intersection attributes. 
Combinations of: * Producing HitObjects (HitObject::TraceRay, HitObject::FromRayQuery, HitObject::MakeMiss) * Querying property from HitObject (Direct HitObject getter or HitObject::Invoke+classic getter) * HitObject tested in raygen,closesthit or miss (adds recursion) * with and without reorder - Fused all Scalar/VectorMatrixGetter tests into one SERGetterPermutationTest - Added SERAttributesPermutationTest: Testing procedural and triangle intersection attributes - Added SERNOPValuesTest: Explicit test for default return values in NOP HitObjects. - Added SERMultiPayloadTest: Testing multiple live HitObjects with differing payload types, SBT indices and loop control flow. --- .../unittests/HLSLExec/ExecutionTest.cpp | 218 +- .../unittests/HLSLExec/ExecutionTest_SER.h | 1768 +++++++++++------ 2 files changed, 1345 insertions(+), 641 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 0632fd91ca..16ccf9f11c 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -623,9 +623,7 @@ class ExecutionTest { // Shader Execution Reordering tests TEST_METHOD(SERBasicTest); - TEST_METHOD(SERScalarGetterTest); - TEST_METHOD(SERVectorGetterTest); - TEST_METHOD(SERMatrixGetterTest); + TEST_METHOD(SERNOPValuesTest); TEST_METHOD(SERRayQueryTest); TEST_METHOD(SERIntersectionTest); TEST_METHOD(SERGetAttributesTest); @@ -638,6 +636,9 @@ class ExecutionTest { TEST_METHOD(SERDynamicHitObjectArrayTest); TEST_METHOD(SERWaveIncoherentHitTest); TEST_METHOD(SERReorderCoherentTest); + TEST_METHOD(SERGetterPermutationTest); + TEST_METHOD(SERAttributesPermutationTest); + TEST_METHOD(SERMultiPayloadTest); // CoopVec tests TEST_METHOD(CoopVec_Mul); @@ -2128,12 +2129,40 @@ class ExecutionTest { int numOptions); bool CreateDXRDevice(ID3D12Device **ppDevice, D3D_SHADER_MODEL testModel, bool skipUnsupported); + struct DXRRunConfig { + int WindowWidth = 64; + 
int WindowHeight = 64; + bool UseMesh = true; + bool UseProceduralGeometry = false; + int PayloadCount = 1; + int AttributeCount = 2; + int MaxRecursion = 1; + int NumMissShaders = 1; + int NumHitGroups = 1; + }; + CComPtr + RunDXRTest(ID3D12Device *Device0, LPCSTR ShaderSrc, LPCWSTR TargetProfile, + LPCWSTR *Options, int NumOptions, std::vector &TestData, + const DXRRunConfig &Config); + CComPtr RunDXRTest(ID3D12Device *Device0, LPCSTR ShaderSrc, LPCWSTR TargetProfile, LPCWSTR *Options, int NumOptions, std::vector &TestData, int WindowWidth, int WindowHeight, bool UseMesh, bool UseProceduralGeometry, - int PayloadCount, int AttributeCount); + int PayloadCount, int AttributeCount) { + DXRRunConfig Config = {WindowWidth, + WindowHeight, + UseMesh, + UseProceduralGeometry, + PayloadCount, + AttributeCount, + 1, + 1, + 1}; + return RunDXRTest(Device0, ShaderSrc, TargetProfile, Options, NumOptions, + TestData, Config); + } void SetDescriptorHeap(ID3D12GraphicsCommandList *pCommandList, ID3D12DescriptorHeap *pHeap) { @@ -2313,11 +2342,11 @@ bool ExecutionTest::CreateDXRDevice(ID3D12Device **ppDevice, return false; } -CComPtr ExecutionTest::RunDXRTest( - ID3D12Device *Device0, LPCSTR ShaderSrc, LPCWSTR TargetProfile, - LPCWSTR *Options, int NumOptions, std::vector &TestData, - int WindowWidth, int WindowHeight, bool UseMesh, bool UseProceduralGeometry, - int PayloadCount, int AttributeCount) { +CComPtr +ExecutionTest::RunDXRTest(ID3D12Device *Device0, LPCSTR ShaderSrc, + LPCWSTR TargetProfile, LPCWSTR *Options, + int NumOptions, std::vector &TestData, + const DXRRunConfig &Config) { CComPtr Device; VERIFY_SUCCEEDED(Device0->QueryInterface(IID_PPV_ARGS(&Device))); @@ -2421,7 +2450,7 @@ CComPtr ExecutionTest::RunDXRTest( {0.f, 301.f, 0.f, 0.f}, {0.f, 0., -699.f, 0.f}, 100.f, - {(unsigned int)WindowWidth, (unsigned int)WindowHeight}, + {(unsigned int)Config.WindowWidth, (unsigned int)Config.WindowHeight}, 0x00}; memcpy(SceneConstantBufferWO, &SceneConsts, 
sizeof(SceneConsts)); @@ -2493,29 +2522,77 @@ CComPtr ExecutionTest::RunDXRTest( CompileFromText(ShaderSrc, L"raygen", TargetProfile, &ShaderLib, Options, NumOptions); + // Construct HitGroups + struct HitGroupDesc { + std::wstring ClosestHit; + std::wstring AnyHit; + std::wstring Intersection; + std::wstring HitGroupName; + const bool IsProcedural() const { return !Intersection.empty();} + }; + std::vector HitGroupDescs; + + const bool PrimaryHitGroupsAreAABB = !Config.UseMesh && Config.UseProceduralGeometry; + const bool EnableSecondaryHitGroups = Config.UseMesh && Config.UseProceduralGeometry; + + // Base hit group + HitGroupDesc PrimaryHitGroup{L"closesthit", L"anyhit", L"", L"HitGroup"}; + if (PrimaryHitGroupsAreAABB) + PrimaryHitGroup.Intersection = L"intersection"; + HitGroupDescs.push_back(PrimaryHitGroup); + + for (int i = 1; i < Config.NumHitGroups; i++) { + std::wstring ClosestHit = L"closesthit" + std::to_wstring(i); + std::wstring AnyHit = L"anyhit" + std::to_wstring(i); + std::wstring Intersection = L""; + if (PrimaryHitGroupsAreAABB) + Intersection = L"intersection" + std::to_wstring(i); + std::wstring HitGroupName = L"HitGroup" + std::to_wstring(i); + HitGroupDescs.push_back( + HitGroupDesc{ClosestHit, AnyHit, Intersection, HitGroupName}); + } + + if (EnableSecondaryHitGroups) { + HitGroupDescs.push_back( + HitGroupDesc{L"chAABB", L"ahAABB", L"intersection", L"HitGroupAABB"}); + for (int i = 1; i < Config.NumHitGroups; i++) { + std::wstring ClosestHit = L"chAABB" + std::to_wstring(i); + std::wstring AnyHit = L"ahAABB" + std::to_wstring(i); + std::wstring Intersection = L"intersection" + std::to_wstring(i); + std::wstring HitGroupName = L"HitGroupAABB" + std::to_wstring(i); + HitGroupDescs.push_back( + HitGroupDesc{ClosestHit, AnyHit, Intersection, HitGroupName}); + } + } + + // Collect required shader names from HitGroups + std::vector ShaderNames; + ShaderNames.push_back(L"raygen"); + ShaderNames.push_back(L"miss"); + for (int i = 1; i < 
Config.NumMissShaders; i++) + ShaderNames.push_back(L"miss" + std::to_wstring(i)); + for (const HitGroupDesc &HitGroupDesc : HitGroupDescs) { + ShaderNames.push_back(HitGroupDesc.ClosestHit); + ShaderNames.push_back(HitGroupDesc.AnyHit); + if (HitGroupDesc.IsProcedural()) + ShaderNames.push_back(HitGroupDesc.Intersection); + } + // Describe and create the RT pipeline state object (RTPSO). CD3DX12_STATE_OBJECT_DESC StateObjectDesc( D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE); auto Lib = StateObjectDesc.CreateSubobject(); CD3DX12_SHADER_BYTECODE ByteCode(ShaderLib); Lib->SetDXILLibrary(&ByteCode); - Lib->DefineExport(L"raygen"); - Lib->DefineExport(L"closesthit"); - Lib->DefineExport(L"anyhit"); - Lib->DefineExport(L"miss"); - if (UseProceduralGeometry) - Lib->DefineExport(L"intersection"); - if (UseMesh && UseProceduralGeometry) { - Lib->DefineExport(L"ahAABB"); - Lib->DefineExport(L"chAABB"); - } - - const int MaxRecursion = 1; + + for (std::wstring Export : ShaderNames) + Lib->DefineExport(Export.c_str()); + StateObjectDesc.CreateSubobject() - ->Config(PayloadCount * sizeof(float), AttributeCount * sizeof(float)); + ->Config(Config.PayloadCount * sizeof(float), Config.AttributeCount * sizeof(float)); StateObjectDesc .CreateSubobject() - ->Config(MaxRecursion); + ->Config(Config.MaxRecursion); // Set Global Root Signature subobject. 
auto GlobalRootSigSubObj = @@ -2529,37 +2606,21 @@ CComPtr ExecutionTest::RunDXRTest( auto Exports = StateObjectDesc.CreateSubobject< CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT>(); Exports->SetSubobjectToAssociate(*GlobalRootSigSubObj); - Exports->AddExport(L"raygen"); - Exports->AddExport(L"closesthit"); - Exports->AddExport(L"anyhit"); - Exports->AddExport(L"miss"); - if (UseProceduralGeometry) - Exports->AddExport(L"intersection"); - if (UseMesh && UseProceduralGeometry) { - Exports->AddExport(L"ahAABB"); - Exports->AddExport(L"chAABB"); - } - - auto HitGroup = - StateObjectDesc.CreateSubobject(); - HitGroup->SetClosestHitShaderImport(L"closesthit"); - HitGroup->SetAnyHitShaderImport(L"anyhit"); - if (!UseMesh && UseProceduralGeometry) { - HitGroup->SetIntersectionShaderImport(L"intersection"); - HitGroup->SetHitGroupType(D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE); - } else { - HitGroup->SetHitGroupType(D3D12_HIT_GROUP_TYPE_TRIANGLES); - } - HitGroup->SetHitGroupExport(L"HitGroup"); + for (std::wstring Export : ShaderNames) + Exports->AddExport(Export.c_str()); - if (UseMesh && UseProceduralGeometry) { - auto HitGroupAABB = + for (const HitGroupDesc &HitGroupDesc : HitGroupDescs) { + auto HitGroup = StateObjectDesc.CreateSubobject(); - HitGroupAABB->SetAnyHitShaderImport(L"ahAABB"); - HitGroupAABB->SetClosestHitShaderImport(L"chAABB"); - HitGroupAABB->SetIntersectionShaderImport(L"intersection"); - HitGroupAABB->SetHitGroupType(D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE); - HitGroupAABB->SetHitGroupExport(L"HitGroupAABB"); + HitGroup->SetClosestHitShaderImport(HitGroupDesc.ClosestHit.c_str()); + HitGroup->SetAnyHitShaderImport(HitGroupDesc.AnyHit.c_str()); + if (HitGroupDesc.IsProcedural()) { + HitGroup->SetIntersectionShaderImport(HitGroupDesc.Intersection.c_str()); + HitGroup->SetHitGroupType(D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE); + } else { + HitGroup->SetHitGroupType(D3D12_HIT_GROUP_TYPE_TRIANGLES); + } + 
HitGroup->SetHitGroupExport(HitGroupDesc.HitGroupName.c_str()); } CComPtr StateObject; @@ -2572,32 +2633,46 @@ CComPtr ExecutionTest::RunDXRTest( ShaderTable ShaderTable( Device, 1, // raygen count - 1, // miss count - UseMesh && UseProceduralGeometry ? 2 : 1, // hit group count + Config.NumMissShaders, // miss count + (int) HitGroupDescs.size(), // hit group count 1, // ray type count 4 // dwords per root table ); int LocalRootConsts[4] = {12, 34, 56, 78}; + + // raygen memcpy(ShaderTable.GetRaygenShaderIdPtr(0), StateObjectProperties->GetShaderIdentifier(L"raygen"), SHADER_ID_SIZE_IN_BYTES); memcpy(ShaderTable.GetRaygenRootTablePtr(0), LocalRootConsts, sizeof(LocalRootConsts)); + + // miss shaders memcpy(ShaderTable.GetMissShaderIdPtr(0, 0), StateObjectProperties->GetShaderIdentifier(L"miss"), SHADER_ID_SIZE_IN_BYTES); memcpy(ShaderTable.GetMissRootTablePtr(0, 0), LocalRootConsts, sizeof(LocalRootConsts)); - memcpy(ShaderTable.GetHitGroupShaderIdPtr(0, 0), - StateObjectProperties->GetShaderIdentifier(L"HitGroup"), - SHADER_ID_SIZE_IN_BYTES); - memcpy(ShaderTable.GetHitGroupRootTablePtr(0, 0), LocalRootConsts, - sizeof(LocalRootConsts)); - if (UseMesh && UseProceduralGeometry) - memcpy(ShaderTable.GetHitGroupShaderIdPtr(0, 1), - StateObjectProperties->GetShaderIdentifier(L"HitGroupAABB"), + for (int i = 1; i < Config.NumMissShaders; i++) { + std::wstring MissShaderName = L"miss" + std::to_wstring(i); + memcpy(ShaderTable.GetMissShaderIdPtr(i, 0), + StateObjectProperties->GetShaderIdentifier(MissShaderName.c_str()), SHADER_ID_SIZE_IN_BYTES); + memcpy(ShaderTable.GetMissRootTablePtr(i, 0), LocalRootConsts, + sizeof(LocalRootConsts)); + } + + // hit groups + for (int HitGroupIdx = 0; HitGroupIdx < HitGroupDescs.size(); HitGroupIdx++) { + const HitGroupDesc &HitGroupDesc = HitGroupDescs[HitGroupIdx]; + memcpy( + ShaderTable.GetHitGroupShaderIdPtr(HitGroupIdx, 0), + StateObjectProperties->GetShaderIdentifier(HitGroupDesc.HitGroupName.c_str()), + 
SHADER_ID_SIZE_IN_BYTES); + memcpy(ShaderTable.GetHitGroupRootTablePtr(HitGroupIdx, 0), LocalRootConsts, + sizeof(LocalRootConsts)); + } // Create a command allocator and list. CComPtr CommandAllocator; @@ -2622,7 +2697,7 @@ CComPtr ExecutionTest::RunDXRTest( CComPtr BLASProceduralGeometryResource; CComPtr ScratchResource; - if (UseMesh) { + if (Config.UseMesh) { CComPtr VertexBuffer; CComPtr VertexBufferUpload; CComPtr IndexBuffer; @@ -2759,7 +2834,7 @@ CComPtr ExecutionTest::RunDXRTest( VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, nullptr)); } - if (UseProceduralGeometry) { + if (Config.UseProceduralGeometry) { // Define procedural geometry AABB for a plane CComPtr AabbBuffer; CComPtr AabbBufferUpload; @@ -2841,7 +2916,7 @@ CComPtr ExecutionTest::RunDXRTest( { D3D12_RAYTRACING_INSTANCE_DESC CPUInstanceDescs[2] = {}; const int MeshIdx = 0; - const int ProcGeoIdx = UseMesh && UseProceduralGeometry ? 1 : 0; + const int ProcGeoIdx = Config.UseMesh && Config.UseProceduralGeometry ? 
1 : 0; const int NumInstanceDescs = ProcGeoIdx + 1; for (int i = 0; i < NumInstanceDescs; ++i) { @@ -2849,15 +2924,16 @@ CComPtr ExecutionTest::RunDXRTest( InstanceDesc.Transform[0][0] = InstanceDesc.Transform[1][1] = InstanceDesc.Transform[2][2] = 1; InstanceDesc.InstanceID = i; - InstanceDesc.InstanceContributionToHitGroupIndex = i; + InstanceDesc.InstanceContributionToHitGroupIndex = + i * Config.NumHitGroups; InstanceDesc.InstanceMask = 1; InstanceDesc.Flags = D3D12_RAYTRACING_INSTANCE_FLAG_NONE; } - if (UseMesh) + if (Config.UseMesh) CPUInstanceDescs[MeshIdx].AccelerationStructure = BLASMeshResource->GetGPUVirtualAddress(); - if (UseProceduralGeometry) + if (Config.UseProceduralGeometry) CPUInstanceDescs[ProcGeoIdx].AccelerationStructure = BLASProceduralGeometryResource->GetGPUVirtualAddress(); @@ -2934,8 +3010,8 @@ CComPtr ExecutionTest::RunDXRTest( ShaderTable.GetHitGroupRangeInBytes(); DispatchDesc.HitGroupTable.StrideInBytes = ShaderTable.GetShaderRecordSizeInBytes(); - DispatchDesc.Width = WindowWidth; - DispatchDesc.Height = WindowHeight; + DispatchDesc.Width = Config.WindowWidth; + DispatchDesc.Height = Config.WindowHeight; DispatchDesc.Depth = 1; CommandList->SetPipelineState1(StateObject); CommandList->DispatchRays(&DispatchDesc); diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index 8bffa410d5..64b1793b2f 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -12,9 +12,257 @@ #pragma once -TEST_F(ExecutionTest, SERScalarGetterTest) { - // SER: Test basic function of HitObject getters. 
-  static const char *ShaderSrc = R"(
+// Describes one value getter exercised by the SER tests: the element type and
+// shape of the value, plus the expressions that are spliced into the shader
+// (as -D macro definitions) to obtain that value from a HitObject, inside a
+// closesthit shader, inside a miss shader, and for a NOP HitObject.
+//
+// Must stay a plain aggregate: instances are brace-initialized in a constexpr
+// table.
+struct SERAccessor {
+  enum ScalarTypes {
+    UINT = 0,
+    FLOAT = 1,
+  };
+
+  // Element type of the value returned by the getter.
+  ScalarTypes ScalarType;
+  // Shape of the value: 0/0 is a scalar, rows > 0 with cols == 0 is a vector
+  // of `ValRows` elements, rows > 0 with cols > 0 is a matrix.
+  int ValRows;
+  int ValCols;
+
+  LPCWSTR HitObjectGetter; // Expands HITOBJECT_GET_RESULT
+  LPCWSTR CHGetter;        // Expands CH_GET_RESULT (closesthit shaders)
+  LPCWSTR MSGetter;        // Expands MS_GET_RESULT (miss shader)
+  LPCWSTR NOPGetter;       // Expands NOP_GET_RESULT (NOP HitObject)
+
+  // Returns the HLSL spelling of the element type.
+  LPCWSTR getScalarTypeName() const {
+    switch (ScalarType) {
+    case UINT:
+      return L"uint";
+    case FLOAT:
+      return L"float";
+    default:
+      return L"UNKOWN TYPE";
+    }
+  }
+
+  // Appends the -D compile arguments describing this accessor.
+  //
+  // OwnedArgs receives the argument strings and keeps them alive; ArgVec
+  // receives one pointer per new string, in the same order. The pointers refer
+  // into OwnedArgs, so OwnedArgs must outlive every use of ArgVec and must not
+  // be modified after this call returns.
+  void addCompileArgs(std::vector<std::wstring> &OwnedArgs,
+                      std::vector<LPCWSTR> &ArgVec) const {
+    const size_t FirstNewArg = OwnedArgs.size();
+    const std::wstring ScalarTypeName = getScalarTypeName();
+
+    // Value dimensions
+    OwnedArgs.emplace_back(L"-DM_ROWS=" + std::to_wstring(ValRows));
+    OwnedArgs.emplace_back(L"-DM_COLS=" + std::to_wstring(ValCols));
+
+    OwnedArgs.emplace_back(L"-DSCALAR_TYPE=" + ScalarTypeName);
+    if (ValRows == 0 && ValCols == 0) {
+      // Scalar
+      OwnedArgs.emplace_back(L"-DRESULT_TYPE=" + ScalarTypeName);
+    } else if (ValRows >= 1 && ValCols == 0) {
+      // Vector
+      OwnedArgs.emplace_back(L"-DRESULT_TYPE=" + ScalarTypeName +
+                             std::to_wstring(ValRows));
+    } else if (ValRows > 0 && ValCols > 0) {
+      // Matrix
+      OwnedArgs.emplace_back(L"-DMATRIX_ELEMENT_TYPE=" + ScalarTypeName);
+    }
+
+    OwnedArgs.emplace_back(L"-DHITOBJECT_GET_RESULT=" +
+                           std::wstring(HitObjectGetter));
+    OwnedArgs.emplace_back(L"-DCH_GET_RESULT=" + std::wstring(CHGetter));
+    OwnedArgs.emplace_back(L"-DMS_GET_RESULT=" + std::wstring(MSGetter));
+    OwnedArgs.emplace_back(L"-DNOP_GET_RESULT=" + std::wstring(NOPGetter));
+
+    // Take the pointers only after the last append: growing OwnedArgs may
+    // reallocate and move its strings, which would leave pointers obtained
+    // earlier dangling.
+    ArgVec.reserve(ArgVec.size() + (OwnedArgs.size() - FirstNewArg));
+    for (size_t I = FirstNewArg; I < OwnedArgs.size(); ++I)
+      ArgVec.push_back(OwnedArgs[I].c_str());
+  }
+};
+
+// One permutation of the SER test shader: how the hit is produced, where the
+// result value is read from, and which shader stage hosts the code under test.
+// addCompileArgs() turns the permutation into -D macro definitions and str()
+// into a short description for the test log.
+//
+// Must stay a plain aggregate with this field order: configurations are
+// brace-initialized positionally.
+struct SERTestConfig {
+  // Source of the hit object or reference value under test.
+  enum Method {
+    TraceRay =
+        0, // Source queried in closesthit, miss shaders called by TraceRay
+    RayQuery = 1,           // Source is HitObject::FromRayQuery
+    HitObject_TraceRay = 2, // Source is HitObject::TraceRay
+    HitObject_Invoke = 3,   // [only used for recursion]
+  };
+  std::wstring getMethodStr(Method src) const {
+    switch (src) {
+    case TraceRay:
+      return L"TraceRay";
+    case RayQuery:
+      return L"RayQuery";
+    case HitObject_TraceRay:
+      return L"HitObject_TraceRay";
+    case HitObject_Invoke:
+      return L"HitObject_Invoke";
+    default:
+      return L"UNKNOWN";
+    }
+  }
+
+  enum ResultFrom {
+    FromShaders = 0,   // Call getters in CH, MS
+    FromHitObject = 1, // Call getters on HitObject
+  };
+  std::wstring getResultFromStr(ResultFrom resultFrom) const {
+    switch (resultFrom) {
+    case FromShaders:
+      return L"FromShaders";
+    case FromHitObject:
+      return L"FromHitObject";
+    default:
+      return L"UNKNOWN";
+    }
+  }
+
+  // Where the hit object code is located.
+  enum TestLocation {
+    RayGen = 0,     // In raygen shader
+    ClosestHit = 1, // In closesthit shader
+    Miss = 2,       // In miss shader
+  };
+  std::wstring getTestLocationStr(TestLocation loc) const {
+    switch (loc) {
+    case RayGen:
+      return L"RayGen";
+    case ClosestHit:
+      return L"ClosestHit";
+    case Miss:
+      return L"Miss";
+    default:
+      return L"UNKNOWN";
+    }
+  }
+
+  bool UseTriangles;
+  bool UseProceduralGeometry;
+
+  bool ReorderHitObject;
+  TestLocation TestLoc;
+
+  Method TraceMethod;
+  ResultFrom ResultSrc;
+
+  Method RecMethod; // only used if TestLoc != RayGen
+
+  // True when the code under test runs in a closesthit or miss shader, i.e.
+  // one level below a primary shading call issued from raygen.
+  bool hasRecursion() const { return TestLoc != TestLocation::RayGen; }
+
+  // Appends the -D macro definitions selecting this permutation to ArgVec.
+  // Only string literals are pushed, so the pointers stay valid for the
+  // lifetime of the program. Unsupported enum values fail the test through
+  // VERIFY_IS_TRUE(false).
+  void addCompileArgs(std::vector<LPCWSTR> &ArgVec) const {
+    // How to produce the hit object and get the value from it
+    switch (TraceMethod) {
+    case TraceRay:
+      // Reference path: plain TraceRay(), no HitObject involved
+      ArgVec.push_back(L"-DMETHOD_TRACERAY=1");
+      break;
+    case HitObject_TraceRay:
+      // Getter called on HitObject produced by HitObject::TraceRay
+      ArgVec.push_back(L"-DMETHOD_HITOBJECT_TRACERAY=1");
+      break;
+    case RayQuery:
+      // Getter called on HitObject produced by HitObject::FromRayQuery
+      ArgVec.push_back(L"-DMETHOD_HITOBJECT_FROMRQ=1");
+      break;
+    default:
+      VERIFY_IS_TRUE(false);
+      break;
+    }
+
+    switch (ResultSrc) {
+    case FromShaders:
+      ArgVec.push_back(L"-DRESULT_FROM_SHADERS=1");
+      break;
+    case FromHitObject:
+      ArgVec.push_back(L"-DRESULT_FROM_HITOBJECT=1");
+      break;
+    default:
+      VERIFY_IS_TRUE(false);
+      break;
+    }
+
+    if (ReorderHitObject)
+      ArgVec.push_back(L"-DREORDER_HITOBJECT=1");
+
+    switch (TestLoc) {
+    case TestLocation::RayGen:
+      ArgVec.push_back(L"-DTESTLOC_RAYGEN=1");
+      break;
+    case TestLocation::ClosestHit:
+      ArgVec.push_back(L"-DTESTLOC_CLOSESTHIT=1");
+      // Must not fall through: that would define TESTLOC_MISS as well.
+      break;
+    case TestLocation::Miss:
+      ArgVec.push_back(L"-DTESTLOC_MISS=1");
+      break;
+    default:
+      VERIFY_IS_TRUE(false);
+      break;
+    }
+
+    if (hasRecursion()) {
+      ArgVec.push_back(L"-DENABLE_RECURSION=1");
+
+      // Primary shading call to test HitObject in CH/MS
+      switch (RecMethod) {
+      case TraceRay:
+        ArgVec.push_back(L"-DRECMETHOD_TRACERAY=1");
+        break;
+      case HitObject_Invoke:
+        ArgVec.push_back(L"-DRECMETHOD_HITOBJECT_INVOKE=1");
+        break;
+      default:
+        VERIFY_IS_TRUE(false);
+        break;
+      }
+    }
+  }
+
+  // Returns a short ';'-separated description of this permutation for the
+  // test log.
+  std::wstring str() const {
+    std::wstring txt;
+    if (UseTriangles)
+      txt += L"tris;";
+    if (UseProceduralGeometry)
+      txt += L"aabbs;";
+    txt += L"trace=" + getMethodStr(TraceMethod) + L";";
+    txt += L"result=" + getResultFromStr(ResultSrc) + L";";
+    txt += L"loc=" + getTestLocationStr(TestLoc) + L";";
+    if (ReorderHitObject) {
+      txt += L"reorder;";
+    }
+    if (hasRecursion()) {
+      txt += L"rec;";
+    }
+    return txt;
+  }
+};
+
+// clang-format off
+static constexpr SERAccessor Accessors[] = {
+  // Scalar
+  {SERAccessor::FLOAT, 0, 0, L"GetRayTMin", L"RayTMin", L"RayTMin", L"getFloatZero"},
+  {SERAccessor::FLOAT, 0, 0, L"GetRayTCurrent", L"RayTCurrent", L"RayTCurrent", L"getFloatZero"},
+  {SERAccessor::UINT, 0,
0, L"GetRayFlags", L"RayFlags", L"RayFlags", L"getIntZero"}, + {SERAccessor::UINT, 0, 0, L"GetHitKind", L"HitKind", L"getIntZero", L"getIntZero"}, + {SERAccessor::UINT, 0, 0, L"GetGeometryIndex", L"GeometryIndex", L"getIntZero", L"getIntZero"}, + {SERAccessor::UINT, 0, 0, L"GetInstanceIndex", L"InstanceIndex", L"getIntZero", L"getIntZero"}, + {SERAccessor::UINT, 0, 0, L"GetInstanceID", L"InstanceID", L"getIntZero", L"getIntZero"}, + {SERAccessor::UINT, 0, 0, L"GetPrimitiveIndex", L"PrimitiveIndex", L"getIntZero", L"getIntZero"}, + {SERAccessor::UINT, 0, 0, L"IsHit", L"getIntOne", L"getIntZero", L"getIntZero"}, + {SERAccessor::UINT, 0, 0, L"IsNop", L"getIntZero", L"getIntZero", L"getIntOne"}, + {SERAccessor::UINT, 0, 0, L"IsMiss", L"getIntZero", L"getIntOne", L"getIntZero"}, + // Vector + {SERAccessor::FLOAT, 3, 0, L"GetWorldRayOrigin", L"WorldRayOrigin", L"WorldRayOrigin", L"getVec3Zero"}, + {SERAccessor::FLOAT, 3, 0, L"GetWorldRayDirection", L"WorldRayDirection", L"WorldRayDirection", L"getVec3Zero"}, + {SERAccessor::FLOAT, 3, 0, L"GetObjectRayOrigin", L"ObjectRayOrigin", L"WorldRayOrigin", L"getVec3Zero"}, + {SERAccessor::FLOAT, 3, 0, L"GetObjectRayDirection", L"ObjectRayDirection", L"WorldRayDirection", L"getVec3Zero"}, + // Matrix + {SERAccessor::FLOAT, 3, 4, L"GetWorldToObject3x4", L"WorldToObject3x4", L"getOneDiagonalMat", L"getOneDiagonalMat"}, + {SERAccessor::FLOAT, 4, 3, L"GetWorldToObject4x3", L"WorldToObject4x3", L"getOneDiagonalMat", L"getOneDiagonalMat"}, + {SERAccessor::FLOAT, 3, 4, L"GetObjectToWorld3x4", L"ObjectToWorld3x4", L"getOneDiagonalMat", L"getOneDiagonalMat"}, + {SERAccessor::FLOAT, 4, 3, L"GetObjectToWorld4x3", L"ObjectToWorld4x3", L"getOneDiagonalMat", L"getOneDiagonalMat"}, +}; +// clang-format on + +static const char *SERPermutationTestShaderSrc = R"( + struct SceneConstants { float4 eye; @@ -26,17 +274,23 @@ struct SceneConstants int rayFlags; }; -struct[raypayload] PerRayData +#ifdef MATRIX_ELEMENT_TYPE +typedef matrix ValueType; 
+#else +typedef RESULT_TYPE ValueType; +#endif + +struct [raypayload] PerRayData { - VALTYPE value : read(anyhit,closesthit,miss,caller) : write(anyhit,miss,closesthit); + int recursionDepth : read(caller,closesthit,miss) : write(caller,closesthit,miss); }; -struct Attrs +struct TriangleAttrs { - float2 barycentrics : BARYCENTRICS; + float2 barycentrics; }; -RWStructuredBuffer testBuffer : register(u0); +RWStructuredBuffer testBuffer : register(u0); RaytracingAccelerationStructure topObject : register(t0); ConstantBuffer sceneConstants : register(b0); @@ -55,264 +309,578 @@ RayDesc ComputeRay() return ray; } -[shader("raygeneration")] -void raygen() +#ifdef MATRIX_ELEMENT_TYPE +typedef matrix MatrixType; + +matrix getOneDiagonalMat() { + matrix mat = 0; + mat[0][0] = 1.f; + mat[1][1] = 1.f; + mat[2][2] = 1.f; + return mat; +} +#endif + +void StoreResult(ValueType result) { + const int numRows = M_ROWS > 0 ? M_ROWS : 1; + const int numCols = M_COLS > 0 ? M_COLS : 1; + const int numResultElements = numRows * numCols; + + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + const int id = numResultElements * (launchIndex.x + launchIndex.y * launchDim.x); + +#ifdef MATRIX_ELEMENT_TYPE +#if M_ROWS == 0 || M_COLS == 0 +#error "Zero-sized matrix dimension" +#endif + + // Matrix + for (int r = 0; r < M_ROWS; r++) { + for (int c = 0; c < M_COLS; c++) { + testBuffer[id + (r * M_COLS + c)] = result[r][c]; + } + } + +#elif M_ROWS +#if M_COLS +#error "Rows specified for vector" +#endif + // Vector + for (int r = 0; r < M_ROWS; r++) { + testBuffer[id + r] = result[r]; + } +#else + testBuffer[id] = result; +#endif +} + +// Procedural geometry for use by RayQuery and intersection shader +static const int ProceduralHitKind = 11; + +struct CustomAttrs { - uint2 launchIndex = DispatchRaysIndex().xy; - uint2 launchDim = DispatchRaysDimensions().xy; - int id = 2 * (launchIndex.x + launchIndex.y * launchDim.x); + float dist; +}; - RayDesc ray = 
ComputeRay(); +bool evalIntersection(float3 objRayOrigin, float3 objRayDir, float rayTMax, float rayTMin, out CustomAttrs attrs, out float rayT) +{ + rayT = 0; + // Intersection with circle on a plane (base, n, radius) + // hitPos is intersection point with plane (base, n) + float3 base = {0.0f,0.0f,0.5f}; + float3 n = normalize(float3(0.0f,0.5f,0.5f)); + float radius = 500.f; + // Plane hit + float t = dot(n, base - objRayOrigin) / dot(n, objRayDir); + if (t > rayTMax || t < rayTMin) { + return false; + } + float3 hitPos = objRayOrigin + t * objRayDir; + float3 relHitPos = hitPos - base; + // Circle hit + float hitDist = length(relHitPos); + if (hitDist > radius) + return false; + + attrs.dist = hitDist; + rayT = t; + return true; +} + +#if ATTRIBUTES_TEST +void StoreTriangleAttributes(TriangleAttrs attrs) { + float2 resValue = attrs.barycentrics; + StoreResult(resValue); +} + +void StoreProceduralAttributes(CustomAttrs attrs) { + float2 resValue = {attrs.dist, 0}; + StoreResult(resValue); +} +#endif - // Fetch reference value - PerRayData refPayload; - TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, refPayload); - testBuffer[id] = refPayload.value; + +static dx::HitObject hitObjectTraceFromRQ(RayDesc ray) { + RayQuery rayQ; + rayQ.TraceRayInline(topObject, RAY_FLAG_NONE, 0xFF, ray); + + float tHit = 0; + CustomAttrs customAttrs = {0}; + + while (rayQ.Proceed()) { + switch (rayQ.CandidateType()) { - PerRayData serPayload; - dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, serPayload); - dx::MaybeReorderThread(hitObject); - VALTYPE serVal = hitObject.SER_GET_SCALAR(); - testBuffer[id + 1] = serVal; + // Acccept all triangle hits + case CANDIDATE_NON_OPAQUE_TRIANGLE: { + rayQ.CommitNonOpaqueTriangleHit(); + break; + } + + // Use same decision logic as intersection shader + case CANDIDATE_PROCEDURAL_PRIMITIVE: { + if (evalIntersection(rayQ.CandidateObjectRayOrigin(), rayQ.CandidateObjectRayDirection(), 
rayQ.CommittedRayT(), rayQ.RayTMin(), customAttrs, tHit)) { + rayQ.CommitProceduralPrimitiveHit(tHit); + } + break; + } + + default: + break; + } + } + + switch (rayQ.CommittedStatus()) { + case COMMITTED_NOTHING: + return dx::HitObject::MakeMiss(RAY_FLAG_NONE, 0, ray); + case COMMITTED_TRIANGLE_HIT: { + TriangleAttrs attrs; + attrs.barycentrics = rayQ.CommittedTriangleBarycentrics(); + uint HitKind = rayQ.CommittedTriangleFrontFace() ? HIT_KIND_TRIANGLE_FRONT_FACE : HIT_KIND_TRIANGLE_BACK_FACE; + dx::HitObject hitObject = dx::HitObject::FromRayQuery(rayQ, HitKind, attrs); + hitObject.SetShaderTableIndex(0); + return hitObject; + } + case COMMITTED_PROCEDURAL_PRIMITIVE_HIT: { + dx::HitObject hitObject = dx::HitObject::FromRayQuery(rayQ, ProceduralHitKind, customAttrs); + hitObject.SetShaderTableIndex(0); + return hitObject; + } + default: + return dx::HitObject(); + } +} + +void CallTraceMethod(int recursionDepth) { + const int numRows = M_ROWS > 0 ? M_ROWS : 1; + const int numCols = M_COLS > 0 ? 
M_COLS : 1; + const int numResultElements = numRows * numCols; + + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + const int id = numResultElements * (launchIndex.x + launchIndex.y * launchDim.x); + + RayDesc ray = ComputeRay(); + + PerRayData payload; +#ifdef ENABLE_RECURSION + payload.recursionDepth = recursionDepth; +#endif + + +#if METHOD_TRACERAY +///// Reference result + TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); +#if !RESULT_FROM_SHADERS + #error "TraceRay() implicitly gets results from shaders" +#endif + +///// Produce hit object +#elif METHOD_HITOBJECT_TRACERAY + dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, payload); + +#elif METHOD_HITOBJECT_FROMRQ + dx::HitObject hitObject = hitObjectTraceFromRQ(ray); +#endif + +#if REORDER_HITOBJECT + dx::MaybeReorderThread(hitObject); +#endif + +///// Query hit object getter directly +#if RESULT_FROM_HITOBJECT +#if ATTRIBUTES_TEST + // TODO: Update GetAttributes API + if (hitObject.IsMiss()) { + // Test for zero-init of miss + StoreTriangleAttributes(hitObject.GetAttributes()); + } else if (hitObject.GetHitKind() == ProceduralHitKind) { + StoreProceduralAttributes(hitObject.GetAttributes()); + } else { + StoreTriangleAttributes(hitObject.GetAttributes()); + } +#else + StoreResult(hitObject.HITOBJECT_GET_RESULT()); +#endif + +#elif RESULT_FROM_SHADERS +#if !METHOD_TRACERAY + // Already invoked in TraceRay() + dx::HitObject::Invoke(hitObject, payload); +#endif +#endif +} + +[shader("raygeneration")] +void raygen() +{ +#if ENABLE_RECURSION + RayDesc ray = ComputeRay(); + PerRayData recPayload; + recPayload.recursionDepth = 1; +#if RECMETHOD_TRACERAY + TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, recPayload); +#elif RECMETHOD_HITOBJECT_INVOKE + dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, recPayload); + dx::HitObject::Invoke(hitObject, 
recPayload); +#else +#error "Unsupported shading method in recursive tests" +#endif + return; +#endif + +#if TESTLOC_RAYGEN + CallTraceMethod(1); + return; +#if ENABLE_RECURSION +#error "Must disable recursion when testing in raygen" +#endif +#endif } float getFloatZero() { return 0.0f; } int getIntZero() { return 0; } +int getIntOne() { return 1; } [shader("miss")] void miss(inout PerRayData payload) { - payload.value = MISS_GET_SCALAR(); +#if TESTLOC_MISS + if (payload.recursionDepth == 1) + { + CallTraceMethod(payload.recursionDepth + 1); + return; + } +#endif + +#if ATTRIBUTES_TEST + StoreResult(float2(0,0)); +#else + StoreResult(MS_GET_RESULT()); +#endif } +///// Triangle hit group [shader("anyhit")] -void anyhit(inout PerRayData payload, in Attrs attrs) +void anyhit(inout PerRayData payload, in TriangleAttrs attrs) { - // UNUSED + // UNUSED } [shader("closesthit")] -void closesthit(inout PerRayData payload, in Attrs attrs) +void closesthit(inout PerRayData payload, in TriangleAttrs attrs) { - payload.value = HIT_GET_SCALAR(); +#if TESTLOC_CLOSESTHIT + if (payload.recursionDepth == 1) + { + CallTraceMethod(payload.recursionDepth + 1); + return; + } +#endif + +#if ATTRIBUTES_TEST + StoreTriangleAttributes(attrs); +#else + StoreResult(CH_GET_RESULT()); +#endif } -)"; - CComPtr Device; - if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) +///// AABB hit group +[shader("closesthit")] +void chAABB(inout PerRayData payload, in CustomAttrs customAttrs) +{ +#if TESTLOC_CLOSESTHIT + if (payload.recursionDepth == 1) + { + CallTraceMethod(payload.recursionDepth + 1); return; + } +#endif - // Initialize test data. 
- const int WindowSize = 64; +#if ATTRIBUTES_TEST + StoreProceduralAttributes(customAttrs); +#else + StoreResult(CH_GET_RESULT()); +#endif +} - // RayTMin - { - WEX::Logging::Log::Comment(L"Testing HitObject::GetRayTMin()"); - std::vector TestData(WindowSize * WindowSize * 2, 0); - LPCWSTR Args[] = {L"-HV 2021", - L"-Vd", - L"-DVALTYPE=float", - L"-DHIT_GET_SCALAR=RayTMin", - L"-DMISS_GET_SCALAR=RayTMin", - L"-DSER_GET_SCALAR=GetRayTMin"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 1 /*payloadCount*/, - 2 /*attributeCount*/); - for (int Id = 0; Id < TestData.size(); Id += 2) { - float *ResArray = (float *)(TestData.data() + Id); - float RefVal = ResArray[0]; - float SerVal = ResArray[1]; - const bool PassRayTMin = CompareFloatEpsilon(SerVal, RefVal, 0.0008f); - if (!PassRayTMin) { - VERIFY_IS_TRUE(PassRayTMin); - return; - } - } - WEX::Logging::Log::Comment(L"HitObject::GetRayTMin() PASSED"); +[shader("intersection")] +void intersection() +{ + CustomAttrs attrs = {0}; + float rayT; + if (evalIntersection(ObjectRayOrigin(), ObjectRayDirection(), RayTCurrent(), RayTMin(), attrs, rayT)) { + ReportHit(rayT, ProceduralHitKind, attrs); } +} - // RayTCurrent - { - WEX::Logging::Log::Comment(L"Testing HitObject::GetRayTCurrent()"); - std::vector TestData(WindowSize * WindowSize * 2, 0); - LPCWSTR Args[] = {L"-HV 2021", - L"-Vd", - L"-DVALTYPE=float", - L"-DHIT_GET_SCALAR=RayTCurrent", - L"-DMISS_GET_SCALAR=RayTCurrent", - L"-DSER_GET_SCALAR=GetRayTCurrent"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 1 /*payloadCount*/, - 2 /*attributeCount*/); - for (int Id = 0; Id < TestData.size(); Id += 2) { - float *ResArray = (float *)(TestData.data() + Id); - float RefVal = ResArray[0]; - float SerVal = ResArray[1]; - const bool PassRayTCurrent = 
CompareFloatEpsilon(SerVal, RefVal, 0.0008f); - if (!PassRayTCurrent) { - VERIFY_IS_TRUE(PassRayTCurrent); - return; - } +[shader("anyhit")] +void ahAABB(inout PerRayData payload, in CustomAttrs attrs) +{ + // UNUSED +} + +)"; + +template +static void VerifyTestArray(const T* RefData, const T* TestData, int NumElements); + +template<> +void VerifyTestArray(const int* RefData, const int* TestData, int NumElements) { + for (int i = 0; i < NumElements; i++) { + if (RefData[i] != TestData[i]) { + VERIFY_ARE_EQUAL(RefData[i], TestData[i]); } - WEX::Logging::Log::Comment(L"HitObject::GetRayTCurrent() PASSED"); } +} - // RayFlags - { - WEX::Logging::Log::Comment(L"Testing HitObject::GetRayFlags()"); - std::vector TestData(WindowSize * WindowSize * 2, 0); - LPCWSTR Args[] = {L"-HV 2021", - L"-Vd", - L"-DVALTYPE=uint", - L"-DHIT_GET_SCALAR=RayFlags", - L"-DMISS_GET_SCALAR=RayFlags", - L"-DSER_GET_SCALAR=GetRayFlags"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 1 /*payloadCount*/, - 2 /*attributeCount*/); - for (int Id = 0; Id < TestData.size(); Id += 2) { - const int RefVal = TestData[Id]; - const int SerVal = TestData[Id + 1]; - if (RefVal != SerVal) { - VERIFY_ARE_EQUAL(RefVal, SerVal); - return; - } +template<> +void VerifyTestArray(const float* RefData, const float* TestData, int NumElements) { + for (int i = 0; i < NumElements; i++) { + const float RefVal = RefData[i]; + const float TestVal = TestData[i]; + if (!CompareFloatEpsilon(TestVal, RefVal, 0.0008f)) { + VERIFY_ARE_EQUAL(TestVal, RefVal); } - WEX::Logging::Log::Comment(L"HitObject::GetRayFlags() PASSED"); } +} - // HitKind - { - WEX::Logging::Log::Comment(L"Testing HitObject::GetHitKind()"); - std::vector TestData(WindowSize * WindowSize * 2, 0); - LPCWSTR Args[] = {L"-HV 2021", - L"-Vd", - L"-DVALTYPE=uint", - L"-DHIT_GET_SCALAR=HitKind", - L"-DMISS_GET_SCALAR=getIntZero", - 
L"-DSER_GET_SCALAR=GetHitKind"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 1 /*payloadCount*/, - 2 /*attributeCount*/); - for (int Id = 0; Id < TestData.size(); Id += 2) { - const int RefVal = TestData[Id]; - const int SerVal = TestData[Id + 1]; - if (RefVal != SerVal) { - VERIFY_ARE_EQUAL(RefVal, SerVal); - return; +TEST_F(ExecutionTest, SERGetterPermutationTest) { + // SER: Test basic function of HitObject getters. + + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) + return; + + SERTestConfig RefConfig = {true, + true, + false, + SERTestConfig::RayGen, + SERTestConfig::TraceRay, + SERTestConfig::FromShaders, + SERTestConfig::TraceRay}; + + std::vector TestConfigs; + for (SERTestConfig::TestLocation TestLoc : + {SERTestConfig::RayGen, SERTestConfig::Miss, + SERTestConfig::ClosestHit}) { + for (bool Reorder : {true, false}) { + // MaybeReorderThreads only supported in RayGens + if (TestLoc != SERTestConfig::RayGen && Reorder) + continue; + + for (SERTestConfig::Method TestMethod : + {SERTestConfig::HitObject_TraceRay, SERTestConfig::RayQuery}) { + for (SERTestConfig::ResultFrom ResultSrc : + {SERTestConfig::FromShaders, SERTestConfig::FromHitObject}) { + SERTestConfig TestConfig = RefConfig; + TestConfig.TestLoc = TestLoc; + TestConfig.TraceMethod = TestMethod; + TestConfig.ReorderHitObject = Reorder; + TestConfig.ResultSrc = ResultSrc; + + if (TestLoc == SERTestConfig::RayGen) { + TestConfigs.push_back(TestConfig); + continue; + } + + // Variations on primary shading call to test HitObject in CH/MS + for (SERTestConfig::Method RecMethod : + {SERTestConfig::TraceRay, SERTestConfig::HitObject_Invoke}) { + TestConfig.RecMethod = RecMethod; + TestConfigs.push_back(TestConfig); + } + } } } - WEX::Logging::Log::Comment(L"HitObject::GetHitKind() PASSED"); } - // GeometryIndex - { - WEX::Logging::Log::Comment(L"Testing 
HitObject::GetGeometryIndex()"); - std::vector TestData(WindowSize * WindowSize * 2, 0); - LPCWSTR Args[] = {L"-HV 2021", - L"-Vd", - L"-DVALTYPE=uint", - L"-DHIT_GET_SCALAR=GeometryIndex", - L"-DMISS_GET_SCALAR=getIntZero", - L"-DSER_GET_SCALAR=GetGeometryIndex"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 1 /*payloadCount*/, - 2 /*attributeCount*/); - for (int Id = 0; Id < TestData.size(); Id += 2) { - const int RefVal = TestData[Id]; - const int SerVal = TestData[Id + 1]; - if (RefVal != SerVal) { - VERIFY_ARE_EQUAL(RefVal, SerVal); - return; + // 64 x 64 test window size + const int WindowSize = 64; + + for (const auto &Accessor : Accessors) { + const int NumResultRows = Accessor.ValRows > 0 ? Accessor.ValRows : 1; + const int NumResultCols = Accessor.ValCols > 0 ? Accessor.ValCols : 1; + const int NumResultElements = NumResultRows * NumResultCols; + const int RefMaxRecursion = RefConfig.hasRecursion() ? 2 : 1; + + // Query reference result + std::vector RefData(WindowSize * WindowSize * NumResultElements); + std::vector RefArgs; + std::vector OwnedRefArgs; + RefArgs.push_back(L"-HV 2021"); + RefArgs.push_back(L"-Vd"); + Accessor.addCompileArgs(OwnedRefArgs, RefArgs); + RefConfig.addCompileArgs(RefArgs); + + const int ExtraRec = 0; + DXRRunConfig RefRunConfig = { + WindowSize, + WindowSize, + RefConfig.UseTriangles, + RefConfig.UseProceduralGeometry, + RefMaxRecursion + ExtraRec, + }; + RunDXRTest(Device, SERPermutationTestShaderSrc, L"lib_6_9", RefArgs.data(), + (int)RefArgs.size(), RefData, RefRunConfig); + + // Test permutations + for (const auto &TestConfig : TestConfigs) { + DXRRunConfig TestRunConfig(RefRunConfig); + TestRunConfig.MaxRecursion = + ExtraRec + (TestConfig.hasRecursion() ? 
2 : 1); + + std::wstring TestConfigTxt = L"HitObject::"; + TestConfigTxt += Accessor.HitObjectGetter; + TestConfigTxt += L"() with config " + TestConfig.str(); + + { + std::wstring TestingMsg = L"Testing " + TestConfigTxt; + WEX::Logging::Log::Comment(TestingMsg.c_str()); } - } - WEX::Logging::Log::Comment(L"HitObject::GetGeometryIndex() PASSED"); - } - // InstanceIndex - { - WEX::Logging::Log::Comment(L"Testing HitObject::GetInstanceIndex()"); - std::vector TestData(WindowSize * WindowSize * 2, 0); - LPCWSTR Args[] = {L"-HV 2021", - L"-Vd", - L"-DVALTYPE=uint", - L"-DHIT_GET_SCALAR=InstanceIndex", - L"-DMISS_GET_SCALAR=getIntZero", - L"-DSER_GET_SCALAR=GetInstanceIndex"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 1 /*payloadCount*/, - 2 /*attributeCount*/); - for (int Id = 0; Id < TestData.size(); Id += 2) { - const int RefVal = TestData[Id]; - const int SerVal = TestData[Id + 1]; - if (RefVal != SerVal) { - VERIFY_ARE_EQUAL(RefVal, SerVal); - return; + std::vector Args; + std::vector OwnedArgs; + Args.push_back(L"-HV 2021"); + Args.push_back(L"-Vd"); + Accessor.addCompileArgs(OwnedArgs, Args); + TestConfig.addCompileArgs(Args); + + std::vector TestData(WindowSize * WindowSize * NumResultElements, 0); + + RunDXRTest(Device, SERPermutationTestShaderSrc, L"lib_6_9", Args.data(), + (int)Args.size(), TestData, TestRunConfig); + + const int NumArrayElems = WindowSize * WindowSize * NumResultElements; + switch (Accessor.ScalarType) { + case SERAccessor::FLOAT: + VerifyTestArray(reinterpret_cast(RefData.data()), + reinterpret_cast(TestData.data()), + NumArrayElems); + break; + case SERAccessor::UINT: + VerifyTestArray(reinterpret_cast(RefData.data()), + reinterpret_cast(TestData.data()), + NumArrayElems); + break; } } - WEX::Logging::Log::Comment(L"HitObject::GetInstanceIndex() PASSED"); } +} - // InstanceID - { - WEX::Logging::Log::Comment(L"Testing 
HitObject::GetInstanceID()"); - std::vector TestData(WindowSize * WindowSize * 2, 0); - LPCWSTR Args[] = {L"-HV 2021", - L"-Vd", - L"-DVALTYPE=uint", - L"-DHIT_GET_SCALAR=InstanceID", - L"-DMISS_GET_SCALAR=getIntZero", - L"-DSER_GET_SCALAR=GetInstanceID"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 1 /*payloadCount*/, - 2 /*attributeCount*/); - for (int Id = 0; Id < TestData.size(); Id += 2) { - const int RefVal = TestData[Id]; - const int SerVal = TestData[Id + 1]; - if (RefVal != SerVal) { - VERIFY_ARE_EQUAL(RefVal, SerVal); - return; +TEST_F(ExecutionTest, SERAttributesPermutationTest) { + // SER: Test basic function of HitObject getters. + + CComPtr Device; + if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) + return; + + // All test variatinos + SERTestConfig RefConfig = {true, + true, + false, + SERTestConfig::RayGen, + SERTestConfig::TraceRay, + SERTestConfig::FromShaders, + SERTestConfig::TraceRay}; + + std::vector TestConfigs; + for (SERTestConfig::TestLocation TestLoc : + {SERTestConfig::RayGen, SERTestConfig::Miss, + SERTestConfig::ClosestHit}) { + for (bool Reorder : {true, false}) { + // MaybeReorderThreads only supported in RayGens + if (TestLoc != SERTestConfig::RayGen && Reorder) + continue; + + for (SERTestConfig::Method TestMethod : + {SERTestConfig::HitObject_TraceRay, SERTestConfig::RayQuery}) { + for (SERTestConfig::ResultFrom ResultSrc : + {SERTestConfig::FromShaders, SERTestConfig::FromHitObject}) { + SERTestConfig TestConfig = RefConfig; + TestConfig.TestLoc = TestLoc; + TestConfig.TraceMethod = TestMethod; + TestConfig.ReorderHitObject = Reorder; + TestConfig.ResultSrc = ResultSrc; + + if (TestLoc == SERTestConfig::RayGen) { + TestConfigs.push_back(TestConfig); + continue; + } + + // Variations on primary shading call to test HitObject in CH/MS + for (SERTestConfig::Method RecMethod : + {SERTestConfig::TraceRay, 
SERTestConfig::HitObject_Invoke}) { + TestConfig.RecMethod = RecMethod; + TestConfigs.push_back(TestConfig); + } + } } } - WEX::Logging::Log::Comment(L"HitObject::GetInstanceID() PASSED"); } - // PrimitiveIndex - { - WEX::Logging::Log::Comment(L"Testing HitObject::GetPrimitiveIndex()"); - std::vector TestData(WindowSize * WindowSize * 2, 0); - LPCWSTR Args[] = {L"-HV 2021", - L"-Vd", - L"-DVALTYPE=uint", - L"-DHIT_GET_SCALAR=PrimitiveIndex", - L"-DMISS_GET_SCALAR=getIntZero", - L"-DSER_GET_SCALAR=GetPrimitiveIndex"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 1 /*payloadCount*/, - 2 /*attributeCount*/); - for (int Id = 0; Id < TestData.size(); Id += 2) { - const int RefVal = TestData[Id]; - const int SerVal = TestData[Id + 1]; - if (RefVal != SerVal) { - VERIFY_ARE_EQUAL(RefVal, SerVal); - return; - } + // 64 x 64 test window size + const int WindowSize = 64; + + const int NumResultElements = 2; // Just for Attrs + const int RefMaxRecursion = RefConfig.hasRecursion() ? 
2 : 1; + + std::vector BaseArgs; + BaseArgs.push_back(L"-HV 2021"); + BaseArgs.push_back(L"-Vd"); + BaseArgs.push_back(L"-DSCALAR_TYPE=float"); + BaseArgs.push_back(L"-DRESULT_TYPE=float2"); + BaseArgs.push_back(L"-DM_ROWS=2"); + BaseArgs.push_back(L"-DM_COLS=0"); + BaseArgs.push_back(L"-DATTRIBUTES_TEST=1"); + + // Query reference result + std::vector RefData(WindowSize * WindowSize * NumResultElements); + std::vector RefArgs(BaseArgs); + RefConfig.addCompileArgs(RefArgs); + + DXRRunConfig RunConfig = { + WindowSize, + WindowSize, + RefConfig.UseTriangles, + RefConfig.UseProceduralGeometry, + RefMaxRecursion, + }; + RunDXRTest(Device, SERPermutationTestShaderSrc, L"lib_6_9", RefArgs.data(), + (int)RefArgs.size(), RefData, RunConfig); + + // Test permutations + for (const auto &TestConfig : TestConfigs) { + DXRRunConfig TestRunConfig(RunConfig); + TestRunConfig.MaxRecursion = TestConfig.hasRecursion() ? 2 : 1; + + std::wstring TestConfigTxt = + L"HitObject attributes with config " + TestConfig.str(); + + { + std::wstring TestingMsg = L"Testing " + TestConfigTxt; + WEX::Logging::Log::Comment(TestingMsg.c_str()); } - WEX::Logging::Log::Comment(L"HitObject::GetPrimitiveIndex() PASSED"); + + std::vector Args(BaseArgs); + TestConfig.addCompileArgs(Args); + + std::vector TestData(WindowSize * WindowSize * NumResultElements, 0); + + RunDXRTest(Device, SERPermutationTestShaderSrc, L"lib_6_9", Args.data(), + (int)Args.size(), TestData, TestRunConfig); + + const int NumArrayElems = WindowSize * WindowSize * NumResultElements; + VerifyTestArray(reinterpret_cast(RefData.data()), + reinterpret_cast(TestData.data()), + NumArrayElems); } } -TEST_F(ExecutionTest, SERVectorGetterTest) { - // SER: Test basic function of HitObject getters. 
+TEST_F(ExecutionTest, SERNOPValuesTest) { + // SER: Test NOP HitObject default values static const char *ShaderSrc = R"( + struct SceneConstants { float4 eye; @@ -321,225 +889,138 @@ struct SceneConstants float4 W; float sceneScale; uint2 WindowSize; - int rayFlags; -}; - -struct[raypayload] PerRayData -{ - float3 value : read(caller) : write(miss,closesthit); -}; - -struct Attrs -{ - float2 barycentrics : BARYCENTRICS; + int rayFlags; }; -RWStructuredBuffer testBuffer : register(u0); +RWStructuredBuffer testBuffer : register(u0); RaytracingAccelerationStructure topObject : register(t0); ConstantBuffer sceneConstants : register(b0); -RayDesc ComputeRay() -{ - uint2 launchIndex = DispatchRaysIndex().xy; - uint2 launchDim = DispatchRaysDimensions().xy; +float getFloatZero() { return 0.0f; } +int getIntZero() { return 0; } +int getIntOne() { return 1; } +float3 getVec3Zero() { return (float3)0; } - float2 d = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy) * 2.0f - 1.0f; - RayDesc ray; - ray.Origin = sceneConstants.eye.xyz; - ray.Direction = normalize(d.x*sceneConstants.U.xyz + d.y*sceneConstants.V.xyz + sceneConstants.W.xyz); - ray.TMin = 0; - ray.TMax = 1e18; +struct [raypayload] PerRayData +{ + int unused : read() : write(); +}; - return ray; +#ifdef MATRIX_ELEMENT_TYPE +matrix getOneDiagonalMat() { + matrix mat = 0; + mat[0][0] = 1.f; + mat[1][1] = 1.f; + mat[2][2] = 1.f; + return mat; } +#endif + +#if TEST_ATTRIBUTES +struct CustomAttrs { + uint x; + uint y; + uint z; + uint w; +}; +#endif [shader("raygeneration")] void raygen() { - uint2 launchIndex = DispatchRaysIndex().xy; - uint2 launchDim = DispatchRaysDimensions().xy; - int id = 6 * (launchIndex.x + launchIndex.y * launchDim.x); - - RayDesc ray = ComputeRay(); - - // Fetch reference value - PerRayData refPayload; - TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, refPayload); - testBuffer[id] = refPayload.value.x; - testBuffer[id + 2] = refPayload.value.y; - testBuffer[id + 4] = 
refPayload.value.z; - - PerRayData serPayload; - dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, serPayload); - dx::MaybeReorderThread(hitObject); - float3 serVal = hitObject.SER_GET_VECTOR(); - testBuffer[id + 1] = serVal.x; - testBuffer[id + 3] = serVal.y; - testBuffer[id + 5] = serVal.z; + dx::HitObject hitObject = dx::HitObject::MakeNop(); +#if TEST_ATTRIBUTES + CustomAttrs attrs = hitObject.GetAttributes(); + testBuffer[0] = attrs.x; + testBuffer[1] = attrs.y; + testBuffer[2] = attrs.z; + testBuffer[3] = attrs.w; +#else + const bool pass = hitObject.HITOBJECT_GET_RESULT() == NOP_GET_RESULT(); + testBuffer[0] = pass ? 1 : 0; + PerRayData pld; + dx::HitObject::Invoke(hitObject, pld); +#endif } -float3 getVecZero() { return 0.0f; } [shader("miss")] void miss(inout PerRayData payload) { - payload.value = MISS_GET_VECTOR(); + testBuffer[1] = 1; } [shader("anyhit")] -void anyhit(inout PerRayData payload, in Attrs attrs) +void anyhit(inout PerRayData payload, in BuiltInTriangleIntersectionAttributes attrs) { - // UNUSED + testBuffer[3] = 1; } [shader("closesthit")] -void closesthit(inout PerRayData payload, in Attrs attrs) +void closesthit(inout PerRayData payload, in BuiltInTriangleIntersectionAttributes attrs) { - payload.value = HIT_GET_VECTOR(); + testBuffer[2] = 1; } -)"; + +)"; CComPtr Device; if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) return; - // Initialize test data. 
- const int WindowSize = 64; - - // WorldRayOrigin + // Test GetAttributes<> on NOP HitObject { - WEX::Logging::Log::Comment(L"Testing HitObject::GetWorldRayOrigin()"); - std::vector TestData(WindowSize * WindowSize * 6, 0); - LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DHIT_GET_VECTOR=WorldRayOrigin", - L"-DMISS_GET_VECTOR=WorldRayOrigin", - L"-DSER_GET_VECTOR=GetWorldRayOrigin"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 3 /*payloadCount*/, - 2 /*attributeCount*/); - for (int Id = 0; Id < TestData.size(); Id += 6) { - float *ResArray = (float *)(TestData.data() + Id); - float RefX = ResArray[0]; - float SerX = ResArray[1]; - float RefY = ResArray[2]; - float SerY = ResArray[3]; - float RefZ = ResArray[4]; - float SerZ = ResArray[5]; - const bool PassX = CompareFloatEpsilon(SerX, RefX, 0.0008f); - const bool PassY = CompareFloatEpsilon(SerY, RefY, 0.0008f); - const bool PassZ = CompareFloatEpsilon(SerZ, RefZ, 0.0008f); - if (!PassX || !PassY || !PassZ) { - VERIFY_ARE_EQUAL(SerX, RefX); - VERIFY_ARE_EQUAL(SerY, RefY); - VERIFY_ARE_EQUAL(SerZ, RefZ); - break; - } - } - WEX::Logging::Log::Comment(L"HitObject::GetWorldRayOrigin() PASSED"); - } + WEX::Logging::Log::Comment(L"Testing NOPHitObject::GetAttributes"); - // WorldRayDirection - { - WEX::Logging::Log::Comment(L"Testing HitObject::GetWorldRayDirection()"); - std::vector TestData(WindowSize * WindowSize * 6, 0); - LPCWSTR Args[] = {L"-HV 2021", L"-Vd", - L"-DHIT_GET_VECTOR=WorldRayDirection", - L"-DMISS_GET_VECTOR=WorldRayDirection", - L"-DSER_GET_VECTOR=GetWorldRayDirection"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 3 /*payloadCount*/, - 2 /*attributeCount*/); - for (int Id = 0; Id < TestData.size(); Id += 6) { - float *ResArray = (float *)(TestData.data() + Id); - float RefX = 
ResArray[0]; - float SerX = ResArray[1]; - float RefY = ResArray[2]; - float SerY = ResArray[3]; - float RefZ = ResArray[4]; - float SerZ = ResArray[5]; - const bool PassX = CompareFloatEpsilon(SerX, RefX, 0.0008f); - const bool PassY = CompareFloatEpsilon(SerY, RefY, 0.0008f); - const bool PassZ = CompareFloatEpsilon(SerZ, RefZ, 0.0008f); - if (!PassX || !PassY || !PassZ) { - VERIFY_ARE_EQUAL(SerX, RefX); - VERIFY_ARE_EQUAL(SerY, RefY); - VERIFY_ARE_EQUAL(SerZ, RefZ); - break; - } - } - WEX::Logging::Log::Comment(L"HitObject::GetWorldRayDirection() PASSED"); - } + LPCWSTR Args[] = { + L"-HV 2021", + L"-Vd", + L"-DTEST_ATTRIBUTES=1", + }; - // ObjectRayOrigin - { - WEX::Logging::Log::Comment(L"Testing HitObject::GetObjectRayOrigin()"); - std::vector TestData(WindowSize * WindowSize * 6, 0); - LPCWSTR Args[] = {L"-HV 2021", L"-Vd", L"-DHIT_GET_VECTOR=ObjectRayOrigin", - L"-DMISS_GET_VECTOR=WorldRayOrigin", - L"-DSER_GET_VECTOR=GetObjectRayOrigin"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 3 /*payloadCount*/, - 2 /*attributeCount*/); - for (int Id = 0; Id < TestData.size(); Id += 6) { - float *ResArray = (float *)(TestData.data() + Id); - float RefX = ResArray[0]; - float SerX = ResArray[1]; - float RefY = ResArray[2]; - float SerY = ResArray[3]; - float RefZ = ResArray[4]; - float SerZ = ResArray[5]; - const bool PassX = CompareFloatEpsilon(SerX, RefX, 0.0008f); - const bool PassY = CompareFloatEpsilon(SerY, RefY, 0.0008f); - const bool PassZ = CompareFloatEpsilon(SerZ, RefZ, 0.0008f); - if (!PassX || !PassY || !PassZ) { - VERIFY_ARE_EQUAL(SerX, RefX); - VERIFY_ARE_EQUAL(SerY, RefY); - VERIFY_ARE_EQUAL(SerZ, RefZ); - break; - } - } - WEX::Logging::Log::Comment(L"HitObject::GetObjectRayOrigin() PASSED"); + std::vector TestData(4, 0); + DXRRunConfig RunConfig = {1, 1, true, false, 1}; + RunConfig.AttributeCount = 4; + RunDXRTest(Device, ShaderSrc, 
L"lib_6_9", Args, (int)std::size(Args), + TestData, RunConfig); + + // Expect zero-init of attribute structure + VERIFY_ARE_EQUAL(TestData[0], 0); + VERIFY_ARE_EQUAL(TestData[1], 0); + VERIFY_ARE_EQUAL(TestData[2], 0); + VERIFY_ARE_EQUAL(TestData[3], 0); } - // ObjectRayDirection - { - WEX::Logging::Log::Comment(L"Testing HitObject::GetObjectRayDirection()"); - std::vector TestData(WindowSize * WindowSize * 6, 0); - LPCWSTR Args[] = {L"-HV 2021", L"-Vd", - L"-DHIT_GET_VECTOR=ObjectRayDirection", - L"-DMISS_GET_VECTOR=WorldRayDirection", - L"-DSER_GET_VECTOR=GetObjectRayDirection"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 3 /*payloadCount*/, - 2 /*attributeCount*/); - for (int Id = 0; Id < TestData.size(); Id += 6) { - float *ResArray = (float *)(TestData.data() + Id); - float RefX = ResArray[0]; - float SerX = ResArray[1]; - float RefY = ResArray[2]; - float SerY = ResArray[3]; - float RefZ = ResArray[4]; - float SerZ = ResArray[5]; - const bool PassX = CompareFloatEpsilon(SerX, RefX, 0.0008f); - const bool PassY = CompareFloatEpsilon(SerY, RefY, 0.0008f); - const bool PassZ = CompareFloatEpsilon(SerZ, RefZ, 0.0008f); - if (!PassX || !PassY || !PassZ) { - VERIFY_ARE_EQUAL(SerX, RefX); - VERIFY_ARE_EQUAL(SerY, RefY); - VERIFY_ARE_EQUAL(SerZ, RefZ); - break; - } + for (const auto &Accessor : Accessors) { + std::wstring TestConfigTxt = L"NOPHitObject::"; + TestConfigTxt += Accessor.HitObjectGetter; + + { + std::wstring TestingMsg = L"Testing " + TestConfigTxt; + WEX::Logging::Log::Comment(TestingMsg.c_str()); } - WEX::Logging::Log::Comment(L"HitObject::GetObjectRayDirection() PASSED"); + + std::vector Args; + std::vector OwnedArgs; + Args.push_back(L"-HV 2021"); + Args.push_back(L"-Vd"); + Accessor.addCompileArgs(OwnedArgs, Args); + + std::vector TestData(4, 0); + DXRRunConfig RunConfig = {1, 1, true, false, 1}; + RunDXRTest(Device, ShaderSrc, L"lib_6_9", 
Args.data(), (int)Args.size(), + TestData, RunConfig); + + VERIFY_ARE_EQUAL(TestData[0], 1); // hitObject.GET == expected nop value + VERIFY_ARE_EQUAL(TestData[1], 0); // miss NOT called + VERIFY_ARE_EQUAL(TestData[2], 0); // closesthit NOT called + VERIFY_ARE_EQUAL(TestData[3], 0); // anyhit NOT called } } -TEST_F(ExecutionTest, SERMatrixGetterTest) { - // SER: Test basic function of HitObject getters. +TEST_F(ExecutionTest, SERMultiPayloadTest) { static const char *ShaderSrc = R"( + struct SceneConstants { float4 eye; @@ -551,17 +1032,7 @@ struct SceneConstants int rayFlags; }; -struct[raypayload] PerRayData -{ - float elems[ROWS*COLS] : read(caller) : write(miss,closesthit); -}; - -struct Attrs -{ - float2 barycentrics : BARYCENTRICS; -}; - -RWStructuredBuffer testBuffer : register(u0); +RWStructuredBuffer testBuffer : register(u0); RaytracingAccelerationStructure topObject : register(t0); ConstantBuffer sceneConstants : register(b0); @@ -580,209 +1051,366 @@ RayDesc ComputeRay() return ray; } +// Procedural geometry for use by RayQuery and intersection shader +static const int ProceduralHitKind = 11; + +struct CustomAttrs +{ + float dist; +}; + +bool evalIntersection(float3 objRayOrigin, float3 objRayDir, float rayTMax, float rayTMin, out CustomAttrs attrs, out float rayT) +{ + rayT = 0; + // Intersection with circle on a plane (base, n, radius) + // hitPos is intersection point with plane (base, n) + float3 base = {0.0f,0.0f,0.5f}; + float3 n = normalize(float3(0.0f,0.5f,0.5f)); + float radius = 500.f; + // Plane hit + float t = dot(n, base - objRayOrigin) / dot(n, objRayDir); + if (t > rayTMax || t < rayTMin) { + return false; + } + float3 hitPos = objRayOrigin + t * objRayDir; + float3 relHitPos = hitPos - base; + // Circle hit + float hitDist = length(relHitPos); + if (hitDist > radius) + return false; + + attrs.dist = hitDist; + rayT = t; + return true; +} + +#if ENABLE_PAQS +#define READ_PAQS(X, ...) : read(X, __VA_ARGS__) +#define WRITE_PAQS(X, ...) 
: write(X, __VA_ARGS__) +#else +#define READ_PAQS(X, ...) +#define WRITE_PAQS(X, ...) +#endif + +struct +#if ENABLE_PAQS +[raypayload] +#endif +PayloadA +{ + float unusedPad READ_PAQS(caller, anyhit, closesthit) WRITE_PAQS(anyhit, closesthit, caller); + uint ahCounter READ_PAQS(caller,anyhit) WRITE_PAQS(anyhit,caller); + float unusedPad2 READ_PAQS(caller) WRITE_PAQS(closesthit,miss); + uint chCounter READ_PAQS(caller,closesthit) WRITE_PAQS(closesthit,caller); +#if ENABLE_RECURSION + int recursionDepth READ_PAQS(caller,miss,closesthit) WRITE_PAQS(caller); +#endif + uint aabbCHCounter READ_PAQS(caller,closesthit) WRITE_PAQS(closesthit,caller); + uint aabbAHCounter READ_PAQS(caller,anyhit) WRITE_PAQS(anyhit,caller); + uint missCounter READ_PAQS(caller,miss) WRITE_PAQS(miss,caller); +}; + +struct +#if ENABLE_PAQS +[raypayload] +#endif +PayloadB +{ + uint ahCounter READ_PAQS(caller,anyhit) WRITE_PAQS(anyhit,caller); + float unusedPad READ_PAQS(caller, anyhit, closesthit) WRITE_PAQS(anyhit, closesthit, caller); + float unusedPad2 READ_PAQS(caller) WRITE_PAQS(closesthit,miss); + uint chCounter READ_PAQS(caller,closesthit) WRITE_PAQS(closesthit,caller); + uint aabbCHCounter READ_PAQS(caller,closesthit) WRITE_PAQS(closesthit,caller); + uint aabbAHCounter READ_PAQS(caller,anyhit) WRITE_PAQS(anyhit,caller); + uint missCounter READ_PAQS(caller,miss) WRITE_PAQS(miss,caller); +#if ENABLE_RECURSION + int recursionDepth READ_PAQS(caller,miss,closesthit) WRITE_PAQS(caller); +#endif +}; + +/// Result tracking +static const uint NumRayResults = 10; +static void storeRayResult(int resIdx, uint value) { + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + int baseIdx = NumRayResults * (launchIndex.x + launchIndex.y * launchDim.x); + testBuffer[baseIdx + resIdx] += value; +} + +void RunTest(int recursionDepth) +{ + RayDesc baseRay = ComputeRay(); + + PayloadA pldA; +#if ENABLE_RECURSION + pldA.recursionDepth = recursionDepth; +#endif + 
pldA.ahCounter = 0; + pldA.chCounter = 0; + pldA.aabbCHCounter = 0; + pldA.aabbAHCounter = 0; + pldA.missCounter = 0; + + // First HitObject::TraceRay() + dx::HitObject hitA = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, baseRay, pldA); + + // Second HitObject::TraceRay() while other HitObject is live + PayloadB pldB; +#if ENABLE_RECURSION + pldB.recursionDepth = recursionDepth; +#endif + pldB.ahCounter = 0; + pldB.chCounter = 0; + pldB.aabbCHCounter = 0; + pldB.aabbAHCounter = 0; + pldB.missCounter = 0; + RayDesc rayB = baseRay; + rayB.Origin.x += 0.1f; + dx::HitObject hitB = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 1, 1, 1, rayB, pldB); + + // TraceRay() while HitObject is live + TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, baseRay, pldA); + + // Concurrent HitObject with complex control flow + dx::HitObject loopHit; + int dynamicBound = hitA.GetGeometryIndex(); + for (int i = 0; i < dynamicBound + 5; ++i) { + RayDesc loopRay = baseRay; + loopRay.Origin.y += 0.001f * i; + loopHit = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 1, 1, 1, loopRay, pldB); +#if !ENABLE_RECURSION + dx::MaybeReorderThread(loopHit); +#endif + } + + // Invoke all HitObject (repeatedly) + loopHit.SetShaderTableIndex(0); // pldA <- pldB + dx::HitObject::Invoke(loopHit, pldA); + hitA.SetShaderTableIndex(1); // pldB <- pldA + int differentDynamicBound = hitA.GetInstanceIndex(); + for (int i = 0; i < differentDynamicBound + 3; ++i) { + dx::HitObject::Invoke(hitA, pldB); + } + dx::HitObject::Invoke(hitB, pldB); + + // Write individual counters to distinct result slots + // PayloadA + storeRayResult(0, pldA.ahCounter); + storeRayResult(1, pldA.chCounter); + storeRayResult(2, pldA.aabbCHCounter); + storeRayResult(3, pldA.aabbAHCounter); + storeRayResult(4, pldA.missCounter); + // PayloadB + storeRayResult(5, pldB.ahCounter); + storeRayResult(6, pldB.chCounter); + storeRayResult(7, pldB.aabbCHCounter); + storeRayResult(8, pldB.aabbAHCounter); + 
storeRayResult(9, pldB.missCounter); +} + [shader("raygeneration")] void raygen() { - uint2 launchIndex = DispatchRaysIndex().xy; - uint2 launchDim = DispatchRaysDimensions().xy; - int id = 2 * ROWS * COLS * (launchIndex.x + launchIndex.y * launchDim.x); +#if ENABLE_RECURSION + RayDesc ray = ComputeRay(); + PayloadA recPayload; + recPayload.recursionDepth = 1; + dx::HitObject missObject = dx::HitObject::MakeMiss(RAY_FLAG_NONE, 2, ray); + dx::HitObject::Invoke(missObject, recPayload); - RayDesc ray = ComputeRay(); +#else + RunTest(1); +#endif +} - // Fetch reference value - PerRayData refPayload; - TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, refPayload); - for (int r = 0; r < ROWS; r++) { - for (int c = 0; c < COLS; c++) { - testBuffer[id + 2 * (r * COLS + c)] = refPayload.elems[r*COLS + c]; - } - } - PerRayData serPayload; - dx::HitObject hitObject = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, ray, serPayload); - dx::MaybeReorderThread(hitObject); - matrix serVal = hitObject.SER_GET_MATRIX(); - for (int r = 0; r < ROWS; r++) { - for (int c = 0; c < COLS; c++) { - testBuffer[1 + id + 2 * (r * COLS + c)] = serVal[r][c]; - } - } +///// Miss shaders +[shader("miss")] +void miss(inout PayloadA payload) +{ + payload.missCounter++; } -matrix getMatIdentity() { - matrix mat = 0; - mat[0][0] = 1.f; - mat[1][1] = 1.f; - mat[2][2] = 1.f; - return mat; +[shader("miss")] +void miss1(inout PayloadB payload) +{ + payload.missCounter++; } +#if ENABLE_RECURSION [shader("miss")] -void miss(inout PerRayData payload) +void miss2(inout PayloadA payload) { - matrix mat = MISS_GET_MATRIX(); - for (int r = 0; r < ROWS; r++) { - for (int c = 0; c < COLS; c++) { - payload.elems[r*COLS + c] = mat[r][c]; - } - } + if (payload.recursionDepth == 1) + { + RunTest(payload.recursionDepth + 1); + return; + } } +#endif +///// Triangle HitGroup 0 [shader("anyhit")] -void anyhit(inout PerRayData payload, in Attrs attrs) +void anyhit(inout PayloadA payload, in 
BuiltInTriangleIntersectionAttributes attrs) { - // UNUSED + payload.ahCounter++; } [shader("closesthit")] -void closesthit(inout PerRayData payload, in Attrs attrs) +void closesthit(inout PayloadA payload, in BuiltInTriangleIntersectionAttributes attrs) { - matrix mat = HIT_GET_MATRIX(); - for (int r = 0; r < ROWS; r++) { - for (int c = 0; c < COLS; c++) { - payload.elems[r*COLS + c] = mat[r][c]; - } - } + payload.chCounter++; } -)"; +///// Triangle HitGroup 1 +[shader("anyhit")] +void anyhit1(inout PayloadB payload, in BuiltInTriangleIntersectionAttributes attrs) +{ + payload.ahCounter++; +} + +[shader("closesthit")] +void closesthit1(inout PayloadB payload, in BuiltInTriangleIntersectionAttributes attrs) +{ + payload.chCounter++; +} + + +///// Procedural HitGroup 0 +[shader("closesthit")] +void chAABB(inout PayloadA payload, in CustomAttrs customAttrs) +{ + payload.aabbCHCounter++; +} + +[shader("anyhit")] +void ahAABB(inout PayloadA payload, in CustomAttrs attrs) +{ + payload.aabbAHCounter++; +} + +[shader("intersection")] +void intersection() +{ + CustomAttrs attrs = {0}; + float rayT; + if (evalIntersection(ObjectRayOrigin(), ObjectRayDirection(), RayTCurrent(), RayTMin(), attrs, rayT)) { + ReportHit(rayT, ProceduralHitKind, attrs); + } +} + + +///// Procedural HitGroup 1 +[shader("closesthit")] +void chAABB1(inout PayloadB payload, in CustomAttrs customAttrs) +{ + payload.aabbCHCounter++; +} + +[shader("anyhit")] +void ahAABB1(inout PayloadB payload, in CustomAttrs attrs) +{ + payload.aabbAHCounter++; +} + +[shader("intersection")] +void intersection1() +{ + CustomAttrs attrs = {0}; + float rayT; + if (evalIntersection(ObjectRayOrigin(), ObjectRayDirection(), RayTCurrent(), RayTMin(), attrs, rayT)) { + ReportHit(rayT, ProceduralHitKind, attrs); + } +} + +)"; CComPtr Device; if (!CreateDXRDevice(&Device, D3D_SHADER_MODEL_6_9, true)) return; - const int WindowSize = 64; - - // WorldToObject3x4 - { - WEX::Logging::Log::Comment(L"Testing 
HitObject::GetWorldToObject3x4()"); - std::vector TestData(WindowSize * WindowSize * 24, 0); - LPCWSTR Args[] = {L"-HV 2021", - L"-Vd", - L"-DHIT_GET_MATRIX=WorldToObject3x4", - L"-DMISS_GET_MATRIX=getMatIdentity", - L"-DSER_GET_MATRIX=GetWorldToObject3x4", - L"-DROWS=3", - L"-DCOLS=4"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 12 /*payloadCount*/, - 2 /*attributeCount*/); - const int ROWS = 3; - const int COLS = 4; - for (int Id = 0; Id < TestData.size(); Id += 24) { - float *ResArray = (float *)(TestData.data() + Id); - for (int RowIdx = 0; RowIdx < ROWS; RowIdx++) { - for (int ColIdx = 0; ColIdx < COLS; ColIdx++) { - int RefIdx = 2 * (RowIdx * COLS + ColIdx); - float Ref = ResArray[RefIdx]; - float Ser = ResArray[1 + RefIdx]; - if (!CompareFloatEpsilon(Ser, Ref, 0.0008f)) { - VERIFY_ARE_EQUAL(Ser, Ref); - } - } + struct PayloadTestConfig { + bool EnablePAQs; + bool EnableRecursion; + + void addCompileArgs(std::vector &OwnedArgs, + std::vector &ArgVec) const { + (void) OwnedArgs; + if (EnablePAQs) { + ArgVec.push_back(L"-DENABLE_PAQS=1"); + } else { + ArgVec.push_back(L"-disable-payload-qualifiers"); } - } - WEX::Logging::Log::Comment(L"HitObject::GetWorldToObject3x4() PASSED"); - } - - // WorldToObject4x3 - { - WEX::Logging::Log::Comment(L"Testing HitObject::GetWorldToObject4x3()"); - const int ROWS = 4; - const int COLS = 3; - std::vector TestData(WindowSize * WindowSize * 2 * ROWS * COLS, 0); - LPCWSTR Args[] = {L"-HV 2021", - L"-Vd", - L"-DHIT_GET_MATRIX=WorldToObject4x3", - L"-DMISS_GET_MATRIX=getMatIdentity", - L"-DSER_GET_MATRIX=GetWorldToObject4x3", - L"-DROWS=4", - L"-DCOLS=3"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 12 /*payloadCount*/, - 2 /*attributeCount*/); - for (int Id = 0; Id < TestData.size(); Id += 2 * ROWS * COLS) { - 
float *ResArray = (float *)(TestData.data() + Id); - for (int RowIdx = 0; RowIdx < ROWS; RowIdx++) { - for (int ColIdx = 0; ColIdx < COLS; ColIdx++) { - int RefIdx = 2 * (RowIdx * COLS + ColIdx); - float Ref = ResArray[RefIdx]; - float Ser = ResArray[1 + RefIdx]; - if (!CompareFloatEpsilon(Ser, Ref, 0.0008f)) { - VERIFY_ARE_EQUAL(Ser, Ref); - } - } + if (EnableRecursion) { + ArgVec.push_back(L"-DENABLE_RECURSION=1"); } } - WEX::Logging::Log::Comment(L"HitObject::GetWorldToObject4x3() PASSED"); - } + }; + + // Expected histogram results for each result key, as {value, count} pairs. + static const std::map ExpectedResults[10] = { + // result key 0 + {{0, 4060}, {2, 36}}, + // result key 1 + {{0, 847}, {1, 3213}, {2, 36}}, + // result key 2 + {{0, 883}, {1, 3213}}, + // result key 3 + {{0, 847}, {2, 3249}}, + // result key 4 + {{0, 3249}, {2, 847}}, + // result key 5 + {{0, 4060}, {6, 36}}, + // result key 6 + {{0, 847}, {4, 3249}}, + // result key 7 + {{0, 883}, {1, 3213}}, + // result key 8 + {{0, 847}, {6, 3249}}, + // result key 9 + {{0, 3249}, {4, 847}}}; - // ObjectToWorld3x4 - { - WEX::Logging::Log::Comment(L"Testing HitObject::GetObjectToWorld3x4()"); - std::vector TestData(WindowSize * WindowSize * 24, 0); - LPCWSTR Args[] = {L"-HV 2021", - L"-Vd", - L"-DHIT_GET_MATRIX=ObjectToWorld3x4", - L"-DMISS_GET_MATRIX=getMatIdentity", - L"-DSER_GET_MATRIX=GetObjectToWorld3x4", - L"-DROWS=3", - L"-DCOLS=4"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 12 /*payloadCount*/, - 2 /*attributeCount*/); - const int ROWS = 3; - const int COLS = 4; - for (int Id = 0; Id < TestData.size(); Id += 24) { - float *ResArray = (float *)(TestData.data() + Id); - for (int RowIdx = 0; RowIdx < ROWS; RowIdx++) { - for (int ColIdx = 0; ColIdx < COLS; ColIdx++) { - int RefIdx = 2 * (RowIdx * COLS + ColIdx); - float Ref = ResArray[RefIdx]; - float Ser = ResArray[1 + RefIdx]; - if 
(!CompareFloatEpsilon(Ser, Ref, 0.0008f)) { - VERIFY_ARE_EQUAL(Ser, Ref); - } - } - } + const int WindowSize = 64; + const int NumRayResults = 10; + + std::vector TestConfigs; + for (bool EnablePAQs : {false, true}) { + for (bool EnableRecursion : {false, true}) { + PayloadTestConfig TestConfig; + TestConfig.EnablePAQs = EnablePAQs; + TestConfig.EnableRecursion = EnableRecursion; + TestConfigs.push_back(TestConfig); } - WEX::Logging::Log::Comment(L"HitObject::GetObjectToWorld3x4() PASSED"); } - // ObjectToWorld4x3 - { - WEX::Logging::Log::Comment(L"Testing HitObject::GetObjectToWorld4x3()"); - std::vector TestData(WindowSize * WindowSize * 24, 0); - LPCWSTR Args[] = {L"-HV 2021", - L"-Vd", - L"-DHIT_GET_MATRIX=ObjectToWorld4x3", - L"-DMISS_GET_MATRIX=getMatIdentity", - L"-DSER_GET_MATRIX=GetObjectToWorld4x3", - L"-DROWS=4", - L"-DCOLS=3"}; - RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args, _countof(Args), TestData, - WindowSize, WindowSize, true /*useMesh*/, - false /*useProceduralGeometry*/, 12 /*payloadCount*/, - 2 /*attributeCount*/); - const int ROWS = 4; - const int COLS = 3; - for (int Id = 0; Id < TestData.size(); Id += 24) { - float *ResArray = (float *)(TestData.data() + Id); - for (int RowIdx = 0; RowIdx < ROWS; RowIdx++) { - for (int ColIdx = 0; ColIdx < COLS; ColIdx++) { - int RefIdx = 2 * (RowIdx * COLS + ColIdx); - float Ref = ResArray[RefIdx]; - float Ser = ResArray[1 + RefIdx]; - if (!CompareFloatEpsilon(Ser, Ref, 0.0008f)) { - VERIFY_ARE_EQUAL(Ser, Ref); - break; - } - } + for (const auto &TestConfig : TestConfigs) { + std::vector TestData(WindowSize * WindowSize * NumRayResults, 0); + DXRRunConfig RunConfig = {WindowSize, WindowSize, true, true, 1}; + RunConfig.PayloadCount = 7 + TestConfig.EnableRecursion; + RunConfig.NumMissShaders = 2 + TestConfig.EnableRecursion; + RunConfig.NumHitGroups = 2; + RunConfig.MaxRecursion = 1 + TestConfig.EnableRecursion; + + std::vector Args; + std::vector OwnedArgs; + Args.push_back(L"-HV 2021"); + 
Args.push_back(L"-Vd"); + TestConfig.addCompileArgs(OwnedArgs, Args); + + RunDXRTest(Device, ShaderSrc, L"lib_6_9", Args.data(), (int)Args.size(), + TestData, RunConfig); + + for (int ResIdx = 0; ResIdx < NumRayResults; ++ResIdx) { + std::map Histo; + for (int RayIdx = 0; RayIdx < WindowSize * WindowSize; ++RayIdx) { + int Val = TestData[ResIdx + (NumRayResults * RayIdx)]; + ++Histo[Val]; + } + for (auto [Key, Value] : Histo) { + VERIFY_IS_TRUE(ExpectedResults[ResIdx].count(Key)); + const int ExpectedValue = ExpectedResults[ResIdx].at(Key); + VERIFY_ARE_EQUAL(Value, ExpectedValue); } } - WEX::Logging::Log::Comment(L"HitObject::GetObjectToWorld4x3() PASSED"); } } From c9ff450aa71fd7e8160b55f86f353cbddcaaea2b Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Wed, 3 Sep 2025 11:54:54 +0200 Subject: [PATCH 28/31] [SER] HitObject::GetAttributes change (off by default) Issue: https://github.com/microsoft/hlsl-specs/issues/612 Adds code paths with new HitObject::GetAttributes API (guarded by #if NEW_GETATTRIBUTES_API to help transition). 
--- .../unittests/HLSLExec/ExecutionTest_SER.h | 39 ++++++++++++++++--- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index 64b1793b2f..b0d7877351 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -490,14 +490,31 @@ void CallTraceMethod(int recursionDepth) { ///// Query hit object getter directly #if RESULT_FROM_HITOBJECT #if ATTRIBUTES_TEST - // TODO: Update GetAttributes API if (hitObject.IsMiss()) { // Test for zero-init of miss - StoreTriangleAttributes(hitObject.GetAttributes()); + TriangleAttrs attrs; +#if NEW_GETATTRIBUTES_API + hitObject.GetAttributes(attrs); +#else + attrs = hitObject.GetAttributes(); +#endif + StoreTriangleAttributes(attrs); } else if (hitObject.GetHitKind() == ProceduralHitKind) { - StoreProceduralAttributes(hitObject.GetAttributes()); + CustomAttrs attrs; +#if NEW_GETATTRIBUTES_API + hitObject.GetAttributes(attrs); +#else + attrs = hitObject.GetAttributes(); +#endif + StoreProceduralAttributes(attrs); } else { - StoreTriangleAttributes(hitObject.GetAttributes()); + TriangleAttrs attrs; +#if NEW_GETATTRIBUTES_API + hitObject.GetAttributes(attrs); +#else + attrs = hitObject.GetAttributes(); +#endif + StoreTriangleAttributes(attrs); } #else StoreResult(hitObject.HITOBJECT_GET_RESULT()); @@ -930,7 +947,12 @@ void raygen() { dx::HitObject hitObject = dx::HitObject::MakeNop(); #if TEST_ATTRIBUTES - CustomAttrs attrs = hitObject.GetAttributes(); + CustomAttrs attrs; +#if NEW_GETATTRIBUTES_API + hitObject.GetAttributes(attrs); +#else + attrs = hitObject.GetAttributes(); +#endif testBuffer[0] = attrs.x; testBuffer[1] = attrs.y; testBuffer[2] = attrs.z; @@ -2127,7 +2149,12 @@ void raygen() dx::MaybeReorderThread(hitObject); // Check Attributes for hit detection. 
- CustomAttrs customAttrs = hitObject.GetAttributes(); + CustomAttrs customAttrs; +#if NEW_GETATTRIBUTES_API + hitObject.GetAttributes(customAttrs); +#else + customAttrs = hitObject.GetAttributes(); +#endif bool isHit = hitObject.IsHit(); int testVal = 0; From 156d9c53eec13e4f68992560b5da48efff1c3552 Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Wed, 10 Sep 2025 07:50:49 +0200 Subject: [PATCH 29/31] Set AABB/Tri exclusive ray flags to fix AH sequencing issues in SERMultiPayloadTest --- .../unittests/HLSLExec/ExecutionTest_SER.h | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index b0d7877351..13f0c5fc85 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -1174,7 +1174,7 @@ void RunTest(int recursionDepth) pldA.missCounter = 0; // First HitObject::TraceRay() - dx::HitObject hitA = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, baseRay, pldA); + dx::HitObject hitA = dx::HitObject::TraceRay(topObject, RAY_FLAG_SKIP_TRIANGLES, 0xFF, 0, 1, 0, baseRay, pldA); // Second HitObject::TraceRay() while other HitObject is live PayloadB pldB; @@ -1188,7 +1188,7 @@ void RunTest(int recursionDepth) pldB.missCounter = 0; RayDesc rayB = baseRay; rayB.Origin.x += 0.1f; - dx::HitObject hitB = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 1, 1, 1, rayB, pldB); + dx::HitObject hitB = dx::HitObject::TraceRay(topObject, RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES, 0xFF, 1, 1, 1, rayB, pldB); // TraceRay() while HitObject is live TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, baseRay, pldA); @@ -1199,7 +1199,7 @@ void RunTest(int recursionDepth) for (int i = 0; i < dynamicBound + 5; ++i) { RayDesc loopRay = baseRay; loopRay.Origin.y += 0.001f * i; - loopHit = dx::HitObject::TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 1, 1, 1, loopRay, pldB); + loopHit = 
dx::HitObject::TraceRay(topObject, RAY_FLAG_SKIP_TRIANGLES, 0xFF, 1, 1, 1, loopRay, pldB); #if !ENABLE_RECURSION dx::MaybeReorderThread(loopHit); #endif @@ -1371,7 +1371,7 @@ void intersection1() // Expected histogram results for each result key, as {value, count} pairs. static const std::map ExpectedResults[10] = { // result key 0 - {{0, 4060}, {2, 36}}, + {{0, 4060}, {1, 36}}, // result key 1 {{0, 847}, {1, 3213}, {2, 36}}, // result key 2 @@ -1381,21 +1381,21 @@ void intersection1() // result key 4 {{0, 3249}, {2, 847}}, // result key 5 - {{0, 4060}, {6, 36}}, + {{0, 4030}, {1, 66}}, // result key 6 - {{0, 847}, {4, 3249}}, + {{0, 847}, {4, 3183}, {5, 66}}, // result key 7 - {{0, 883}, {1, 3213}}, + {{0, 4096}}, // result key 8 - {{0, 847}, {6, 3249}}, + {{0, 847}, {5, 3249}}, // result key 9 - {{0, 3249}, {4, 847}}}; + {{0, 66}, {1, 3183}, {4, 847}}}; const int WindowSize = 64; const int NumRayResults = 10; std::vector TestConfigs; - for (bool EnablePAQs : {false, true}) { + for (bool EnablePAQs : {false}) { for (bool EnableRecursion : {false, true}) { PayloadTestConfig TestConfig; TestConfig.EnablePAQs = EnablePAQs; @@ -1428,6 +1428,7 @@ void intersection1() ++Histo[Val]; } for (auto [Key, Value] : Histo) { + LogCommentFmt(L"Result %d.Expected key: %d, value: %d", ResIdx, Key, Value); VERIFY_IS_TRUE(ExpectedResults[ResIdx].count(Key)); const int ExpectedValue = ExpectedResults[ResIdx].at(Key); VERIFY_ARE_EQUAL(Value, ExpectedValue); From 876b2b4158abdb0a2b1e5eaadeea0c669d4e7fef Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Wed, 10 Sep 2025 23:19:14 +0200 Subject: [PATCH 30/31] Remove dbg messages, restore PAQ variations --- tools/clang/unittests/HLSLExec/ExecutionTest_SER.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index 13f0c5fc85..e6bfcc77e5 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ 
b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -1395,7 +1395,7 @@ void intersection1() const int NumRayResults = 10; std::vector TestConfigs; - for (bool EnablePAQs : {false}) { + for (bool EnablePAQs : {false, true}) { for (bool EnableRecursion : {false, true}) { PayloadTestConfig TestConfig; TestConfig.EnablePAQs = EnablePAQs; @@ -1428,7 +1428,6 @@ void intersection1() ++Histo[Val]; } for (auto [Key, Value] : Histo) { - LogCommentFmt(L"Result %d.Expected key: %d, value: %d", ResIdx, Key, Value); VERIFY_IS_TRUE(ExpectedResults[ResIdx].count(Key)); const int ExpectedValue = ExpectedResults[ResIdx].at(Key); VERIFY_ARE_EQUAL(Value, ExpectedValue); From 4df10b9bac57b5375f01bca0a181e97ba933ffa2 Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Wed, 10 Sep 2025 23:23:49 +0200 Subject: [PATCH 31/31] SERMultiPayloadTest: Add SKIP_TRIANGLES ray flag to remaining TraceCall --- tools/clang/unittests/HLSLExec/ExecutionTest_SER.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h index e6bfcc77e5..553a913fa5 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h +++ b/tools/clang/unittests/HLSLExec/ExecutionTest_SER.h @@ -1191,7 +1191,7 @@ void RunTest(int recursionDepth) dx::HitObject hitB = dx::HitObject::TraceRay(topObject, RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES, 0xFF, 1, 1, 1, rayB, pldB); // TraceRay() while HitObject is live - TraceRay(topObject, RAY_FLAG_NONE, 0xFF, 0, 1, 0, baseRay, pldA); + TraceRay(topObject, RAY_FLAG_SKIP_TRIANGLES, 0xFF, 0, 1, 0, baseRay, pldA); // Concurrent HitObject with complex control flow dx::HitObject loopHit; @@ -1371,11 +1371,11 @@ void intersection1() // Expected histogram results for each result key, as {value, count} pairs. 
static const std::map ExpectedResults[10] = { // result key 0 - {{0, 4060}, {1, 36}}, + {{0, 4096}}, // result key 1 - {{0, 847}, {1, 3213}, {2, 36}}, + {{0, 847}, {1, 3249}}, // result key 2 - {{0, 883}, {1, 3213}}, + {{0, 847}, {1, 3249}}, // result key 3 {{0, 847}, {2, 3249}}, // result key 4