Skip to content

Commit 30bfd82

Browse files
tex3dGreg Roth
andauthored
NFC: Infrastructure changes for DXIL op vector and multi-dim overloads (microsoft#7259)
This change adds vector and multi-dimensional overload support for DXIL operations. Multi-dimensional (or "extended") overloads are added, where two or more types in a DXIL Op function signature may vary independently, such as both the return type and a parameter type. Until now, only one overload dimension has been necessary. For single-dim overloads, any number of parameters in a DXIL op may refer to this single overload type. For multi-dim overloads, each type that can vary must have a unique overload dimension, even when two or more types must be the same. This follows a pattern from llvm intrinsics. If two or more of the types need to be the same, this constraint must be handled manually, outside the automatic overload constraints defined by the DXIL op definitions. Vector overloads are also added, requiring an additional set of scalar overload types to define the allowed vector element types, on top of the original set describing the allowed scalar overloads for an operation, since both scalar and vector overloads may be allowed on the same operation. There are several components involved in handling DXIL operation overloads, with some changes: - DXIL Op definitions in `hctdb.py` use a string of characters to define the allowed overloads, and special type names used in parameter definitions that refer to the overload type. - Overload string syntax updated and more heavily validated. - `','` may separate dimensions for multi-dim overloads - `'<'` indicates that a vector overload is allowed, in which case, scalar components on the left indicate normal scalar overloads allowed, and scalar components on the right indicate the allowed vector element overloads. - If scalar overloads are present to the left, and omitted to the right, the scalar components are replicated to the right automatically. For instance: `"hf<"` is equivalent to `"hf<hf"`. - `dxil_max_overload_dims = 2` is introduced to define the maximum number of overload dimensions currently supported. - This is used to generate the `DXIL::kDxilMaxOloadDims` definition in `DxilConstants.h`. - `"$x0"` and `"$x1"` are used to reference each overloaded dxil type in parameter definitions when more than one overload dimension is defined for a DXIL op. Other special overload types are not allowed for multi-dim overloads, which means you cannot (currently) describe a multi-dim overload where a returned overload type is wrapped in a resource return struct along with residency status. This could be changed in the future if necessary. - Enforced rules for multi-dim overloads keep them compatible with the llvm intrinsic overloading scheme. - `hctdb_instrhelp.py` translates overload and param type info from DXIL operation definitions into code inserted into `DxilOperations.cpp`. - `DxilOperations.h|cpp` encodes allowed overloads inside `OpCodeProperty` state for each operation in the `m_OpCodeProps` table. It uses this information, along with generated code, to enforce overload rules on DXIL ops. - The allowed overload definition in `OpCodeProperty` has been rewritten to use a more compact `OverloadMask` type, support multi-dim overloads, and add a second layer `AllowedVectorElements` for vector overloads for each dimension. - There are assumptions that one `llvm::Type*` describes the overload type, such as with: `GetOpFunc`, `GetOpFuncList`, `GetOverloadType`, `IsOverloadLegal`, and `m_OpCodeClassCache`. The scheme used for multi-dim overloads is to encode each of the overload types in a single unnamed `StructType`, like `type {i32, <4 x float>}`. This makes it compatible with all these existing mechanisms without requiring an API overhaul impacting the broader code base. `GetExtendedOverloadType` is used to construct this type from multiple types. While updating `DxilOperations.h|cpp`, I noticed and removed some unused methods: `IsDxilOpTypeName`, `IsDxilOpType`, `IsDupDxilOpType`, `GetOriginalDxilOpType`. --------- Co-authored-by: Greg Roth <[email protected]>
1 parent 3035d31 commit 30bfd82

File tree

6 files changed

+3162
-3159
lines changed

6 files changed

+3162
-3159
lines changed

include/dxc/DXIL/DxilConstants.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,11 @@ const float kMinMipLodBias = -16.0f;
155155

156156
const unsigned kResRetStatusIndex = 4;
157157

158+
/* <py::lines('OLOAD_DIMS-TEXT')>hctdb_instrhelp.get_max_oload_dims()</py>*/
159+
// OLOAD_DIMS-TEXT:BEGIN
160+
const unsigned kDxilMaxOloadDims = 2;
161+
// OLOAD_DIMS-TEXT:END
162+
158163
enum class ComponentType : uint32_t {
159164
Invalid = 0,
160165
I1,

include/dxc/DXIL/DxilOperations.h

Lines changed: 84 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,31 @@ class OP {
5757
// caches.
5858
void RefreshCache();
5959

60+
// The single llvm::Type * "OverloadType" has one of these forms:
61+
// No overloads (NumOverloadDims == 0):
62+
// - TS_Void: VoidTy
63+
// For single overload dimension (NumOverloadDims == 1):
64+
// - TS_F*, TS_I*: a scalar numeric type (half, float, i1, i64, etc.),
65+
// - TS_UDT: a pointer to a StructType representing a User Defined Type,
66+
// - TS_Object: a named StructType representing a built-in object, or
67+
// - TS_Vector: a vector type (<4 x float>, <16 x i16>, etc.)
68+
// For multiple overload dimensions (TS_Extended, NumOverloadDims > 1):
69+
// - an unnamed StructType containing each type for the corresponding
70+
// dimension, such as: type { i32, <2 x float> }
71+
// - contained type options are the same as for single dimension.
72+
6073
llvm::Function *GetOpFunc(OpCode OpCode, llvm::Type *pOverloadType);
74+
75+
// N-dimension convenience version of GetOpFunc:
76+
llvm::Function *GetOpFunc(OpCode OpCode,
77+
llvm::ArrayRef<llvm::Type *> OverloadTypes);
78+
6179
const llvm::SmallMapVector<llvm::Type *, llvm::Function *, 8> &
6280
GetOpFuncList(OpCode OpCode) const;
6381
bool IsDxilOpUsed(OpCode opcode) const;
6482
void RemoveFunction(llvm::Function *F);
6583
llvm::LLVMContext &GetCtx() { return m_Ctx; }
84+
llvm::Module *GetModule() { return m_pModule; }
6685
llvm::Type *GetHandleType() const;
6786
llvm::Type *GetHitObjectType() const;
6887
llvm::Type *GetNodeHandleType() const;
@@ -81,9 +100,14 @@ class OP {
81100

82101
llvm::Type *GetResRetType(llvm::Type *pOverloadType);
83102
llvm::Type *GetCBufferRetType(llvm::Type *pOverloadType);
84-
llvm::Type *GetVectorType(unsigned numElements, llvm::Type *pOverloadType);
103+
llvm::Type *GetStructVectorType(unsigned numElements,
104+
llvm::Type *pOverloadType);
85105
bool IsResRetType(llvm::Type *Ty);
86106

107+
// Construct an unnamed struct type containing the set of member types.
108+
llvm::StructType *
109+
GetExtendedOverloadType(llvm::ArrayRef<llvm::Type *> OverloadTypes);
110+
87111
// Try to get the opcode class for a function.
88112
// Return true and set `opClass` if the given function is a dxil function.
89113
// Return false if the given function is not a dxil function.
@@ -128,11 +152,6 @@ class OP {
128152
static bool BarrierRequiresGroup(const llvm::CallInst *CI);
129153
static bool BarrierRequiresNode(const llvm::CallInst *CI);
130154
static DXIL::BarrierMode TranslateToBarrierMode(const llvm::CallInst *CI);
131-
static bool IsDxilOpTypeName(llvm::StringRef name);
132-
static bool IsDxilOpType(llvm::StructType *ST);
133-
static bool IsDupDxilOpType(llvm::StructType *ST);
134-
static llvm::StructType *GetOriginalDxilOpType(llvm::StructType *ST,
135-
llvm::Module &M);
136155
static void GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
137156
unsigned &major, unsigned &minor,
138157
unsigned &mask);
@@ -141,6 +160,13 @@ class OP {
141160
unsigned valMinor, unsigned &major,
142161
unsigned &minor, unsigned &mask);
143162

163+
static bool IsDxilOpExtendedOverload(OpCode C);
164+
165+
// Return true if the overload name suffix for this operation may be
166+
// constructed based on a user-defined or user-influenced type name
167+
// that may not represent the same type in different linked modules.
168+
static bool MayHaveNonCanonicalOverload(OpCode OC);
169+
144170
private:
145171
// Per-module properties.
146172
llvm::LLVMContext &m_Ctx;
@@ -164,13 +190,33 @@ class OP {
164190

165191
DXIL::LowPrecisionMode m_LowPrecisionMode;
166192

167-
static const unsigned kUserDefineTypeSlot = 9;
168-
static const unsigned kObjectTypeSlot = 10;
169-
static const unsigned kNumTypeOverloads =
170-
11; // void, h,f,d, i1, i8,i16,i32,i64, udt, obj
193+
// Overload types are split into "basic" overload types and special types
194+
// Basic: void, half, float, double, i1, i8, i16, i32, i64
195+
// - These have one canonical overload per TypeSlot
196+
// Special: udt, obj, vec, extended
197+
// - These may have many overloads per type slot
198+
enum TypeSlot : unsigned {
199+
TS_F16 = 0,
200+
TS_F32 = 1,
201+
TS_F64 = 2,
202+
TS_I1 = 3,
203+
TS_I8 = 4,
204+
TS_I16 = 5,
205+
TS_I32 = 6,
206+
TS_I64 = 7,
207+
TS_BasicCount,
208+
TS_UDT = 8, // Ex: %"struct.MyStruct" *
209+
TS_Object = 9, // Ex: %"class.StructuredBuffer<Foo>"
210+
TS_Vector = 10, // Ex: <8 x i16>
211+
TS_MaskBitCount, // Types used in Mask end here
212+
// TS_Extended is only used to identify the unnamed struct type used to wrap
213+
// multiple overloads when using GetTypeSlot.
214+
TS_Extended, // Ex: type { float, <16 x i32> }
215+
TS_Invalid = UINT_MAX,
216+
};
171217

172-
llvm::Type *m_pResRetType[kNumTypeOverloads];
173-
llvm::Type *m_pCBufferRetType[kNumTypeOverloads];
218+
llvm::Type *m_pResRetType[TS_BasicCount];
219+
llvm::Type *m_pCBufferRetType[TS_BasicCount];
174220

175221
struct OpCodeCacheItem {
176222
llvm::SmallMapVector<llvm::Type *, llvm::Function *, 8> pOverloads;
@@ -181,27 +227,46 @@ class OP {
181227

182228
private:
183229
// Static properties.
230+
struct OverloadMask {
231+
// mask of type slot bits as (1 << TypeSlot)
232+
uint16_t SlotMask;
233+
static_assert(TS_MaskBitCount <= (sizeof(SlotMask) * 8));
234+
bool operator[](unsigned TypeSlot) const {
235+
return (TypeSlot < TS_MaskBitCount) ? (bool)(SlotMask & (1 << TypeSlot))
236+
: 0;
237+
}
238+
operator bool() const { return SlotMask != 0; }
239+
};
184240
struct OpCodeProperty {
185241
OpCode opCode;
186242
const char *pOpCodeName;
187243
OpCodeClass opCodeClass;
188244
const char *pOpCodeClassName;
189-
bool bAllowOverload[kNumTypeOverloads]; // void, h,f,d, i1, i8,i16,i32,i64,
190-
// udt
191245
llvm::Attribute::AttrKind FuncAttr;
246+
247+
// Number of overload dimensions used by the operation.
248+
unsigned int NumOverloadDims;
249+
250+
// Mask of supported overload types for each overload dimension.
251+
OverloadMask AllowedOverloads[DXIL::kDxilMaxOloadDims];
252+
253+
// Mask of scalar components allowed for each demension where
254+
// AllowedOverloads[n][TS_Vector] is true.
255+
OverloadMask AllowedVectorElements[DXIL::kDxilMaxOloadDims];
192256
};
193257
static const OpCodeProperty m_OpCodeProps[(unsigned)OpCode::NumOpCodes];
194258

195-
static const char *m_OverloadTypeName[kNumTypeOverloads];
259+
static const char *m_OverloadTypeName[TS_BasicCount];
196260
static const char *m_NamePrefix;
197261
static const char *m_TypePrefix;
198262
static const char *m_MatrixTypePrefix;
199263
static unsigned GetTypeSlot(llvm::Type *pType);
200264
static const char *GetOverloadTypeName(unsigned TypeSlot);
201-
static llvm::StringRef GetTypeName(llvm::Type *Ty, std::string &str);
202-
static llvm::StringRef ConstructOverloadName(llvm::Type *Ty,
203-
DXIL::OpCode opCode,
204-
std::string &funcNameStorage);
265+
static llvm::StringRef GetTypeName(llvm::Type *Ty,
266+
llvm::SmallVectorImpl<char> &Storage);
267+
static llvm::StringRef
268+
ConstructOverloadName(llvm::Type *Ty, DXIL::OpCode opCode,
269+
llvm::SmallVectorImpl<char> &Storage);
205270
};
206271

207272
} // namespace hlsl

0 commit comments

Comments
 (0)