Skip to content

[DXIL] Define and generate DXILAttribute and DXILProperty #117072

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -266,18 +266,29 @@ def miss : DXILShaderStage;
def all_stages : DXILShaderStage;
// Denote support for DXIL Op to have been removed
def removed : DXILShaderStage;

// DXIL Op attributes

// A function attribute denotes that there is a corresponding LLVM function
// attribute that will be set when building the DXIL op. The mapping for
// non-trivial cases is defined by setDXILAttributes in DXILOpBuilder.cpp
class DXILAttribute;

def ReadOnly : DXILAttribute;
def ReadNone : DXILAttribute;
def IsDerivative : DXILAttribute;
def IsGradient : DXILAttribute;
def IsFeedback : DXILAttribute;
def IsWave : DXILAttribute;
def NeedsUniformInputs : DXILAttribute;
def IsBarrier : DXILAttribute;
def ReadOnly : DXILAttribute;
def NoDuplicate : DXILAttribute;
def NoReturn : DXILAttribute;

// A property is simply used to mark that a DXIL op belongs to a sub-group of
// DXIL ops, and it is used to query whether a particular op holds this property.
// This is used for static analysis of DXIL ops.
class DXILProperty;

def IsBarrier : DXILProperty;
def IsGradient : DXILProperty;
def IsFeedback : DXILProperty;
def IsWave : DXILProperty;
def RequiresUniformInputs : DXILProperty;

class Overloads<Version ver, list<DXILOpParamType> ols> {
Version dxil_version = ver;
Expand All @@ -291,7 +302,7 @@ class Stages<Version ver, list<DXILShaderStage> st> {

class Attributes<Version ver = DXIL1_0, list<DXILAttribute> attrs> {
Version dxil_version = ver;
list<DXILAttribute> op_attrs = attrs;
list<DXILAttribute> fn_attrs = attrs;
}

defvar BarrierMode_DeviceMemoryBarrier = 2;
Expand Down Expand Up @@ -376,6 +387,9 @@ class DXILOp<int opcode, DXILOpClass opclass> {

// Versioned attributes of operation
list<Attributes> attributes = [];

// List of properties. Default to no properties.
list<DXILProperty> properties = [];
}

// Concrete definitions of DXIL Operations
Expand Down Expand Up @@ -783,6 +797,7 @@ def CreateHandle : DXILOp<57, createHandle> {
let arguments = [Int8Ty, Int32Ty, Int32Ty, Int1Ty];
let result = HandleTy;
let stages = [Stages<DXIL1_0, [all_stages]>, Stages<DXIL1_6, [removed]>];
let attributes = [Attributes<DXIL1_0, [ReadOnly]>];
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is consistent with hctdb.py, but I'm not certain it's right. I can't really rationalize why this is RO while both CreateHandleFromBinding and even CreateHandleFromHeap are RN. I'm not suggesting changing it. At worst, if you're unsure, a comment here might give future generations a chance to revisit this when we might be better suited to answer

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment to note this and will mention it in the new issue.

}

def BufferLoad : DXILOp<68, bufferLoad> {
Expand All @@ -794,6 +809,7 @@ def BufferLoad : DXILOp<68, bufferLoad> {
[Overloads<DXIL1_0,
[ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadOnly]>];
}

def BufferStore : DXILOp<69, bufferStore> {
Expand Down Expand Up @@ -822,6 +838,7 @@ def CheckAccessFullyMapped : DXILOp<71, checkAccessFullyMapped> {
let result = Int1Ty;
let overloads = [Overloads<DXIL1_0, [Int32Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadOnly]>];
}

def Discard : DXILOp<82, discard> {
Expand Down Expand Up @@ -896,8 +913,8 @@ def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
let intrinsics = [ IntrinSelect<int_dx_dot4add_i8packed> ];
let arguments = [Int32Ty, Int32Ty, Int32Ty];
let result = Int32Ty;
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

def Dot4AddU8Packed : DXILOp<164, dot4AddPacked> {
Expand All @@ -906,22 +923,24 @@ def Dot4AddU8Packed : DXILOp<164, dot4AddPacked> {
let intrinsics = [ IntrinSelect<int_dx_dot4add_u8packed> ];
let arguments = [Int32Ty, Int32Ty, Int32Ty];
let result = Int32Ty;
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

// DXIL op 216 (op class annotateHandle): takes a handle plus resource
// properties and returns a handle. Declared for all shader stages starting at
// DXIL 1.6 and tagged with the ReadNone function attribute.
def AnnotateHandle : DXILOp<216, annotateHandle> {
let Doc = "annotate handle with resource properties";
let arguments = [HandleTy, ResPropsTy];
let result = HandleTy;
let stages = [Stages<DXIL1_6, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

// DXIL op 217 (op class createHandleFromBinding): takes a resource binding, an
// i32 and an i1, and returns a handle. Declared for all shader stages starting
// at DXIL 1.6 and tagged with the ReadNone function attribute.
def CreateHandleFromBinding : DXILOp<217, createHandleFromBinding> {
let Doc = "create resource handle from binding";
let arguments = [ResBindTy, Int32Ty, Int1Ty];
let result = HandleTy;
let stages = [Stages<DXIL1_6, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

def WaveActiveAllTrue : DXILOp<114, waveAllTrue> {
Expand All @@ -930,6 +949,7 @@ def WaveActiveAllTrue : DXILOp<114, waveAllTrue> {
let arguments = [Int1Ty];
let result = Int1Ty;
let stages = [Stages<DXIL1_0, [all_stages]>];
let properties = [IsWave];
}

def WaveActiveAnyTrue : DXILOp<113, waveAnyTrue> {
Expand All @@ -938,6 +958,7 @@ def WaveActiveAnyTrue : DXILOp<113, waveAnyTrue> {
let arguments = [Int1Ty];
let result = Int1Ty;
let stages = [Stages<DXIL1_0, [all_stages]>];
let properties = [IsWave];
}

def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> {
Expand All @@ -946,7 +967,7 @@ def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> {
let arguments = [];
let result = Int1Ty;
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
let properties = [IsWave];
}

def WaveReadLaneAt: DXILOp<117, waveReadLaneAt> {
Expand All @@ -956,7 +977,7 @@ def WaveReadLaneAt: DXILOp<117, waveReadLaneAt> {
let result = OverloadTy;
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int1Ty, Int16Ty, Int32Ty, Int64Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
let properties = [IsWave];
}

def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
Expand All @@ -965,7 +986,8 @@ def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
let arguments = [];
let result = Int32Ty;
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
let attributes = [Attributes<DXIL1_0, [ReadOnly]>];
let properties = [IsWave];
}

def WaveAllBitCount : DXILOp<135, waveAllOp> {
Expand All @@ -974,7 +996,7 @@ def WaveAllBitCount : DXILOp<135, waveAllOp> {
let arguments = [Int1Ty];
let result = Int32Ty;
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
let properties = [IsWave];
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all these wave ops could be readonly at the least, maybe we want to investigate that later though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree that we could further investigate the correctness of the attributes. The goal of this PR was to get consistent behaviour with DXC, so I will file an issue to look into that.

}

def Barrier : DXILOp<80, barrier> {
Expand All @@ -989,4 +1011,5 @@ def Barrier : DXILOp<80, barrier> {
let result = VoidTy;
let stages = [Stages<DXIL1_0, [compute, library]>];
let attributes = [Attributes<DXIL1_0, []>];
let properties = [IsBarrier];
}
22 changes: 22 additions & 0 deletions llvm/lib/Target/DirectX/DXILConstants.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,28 @@ enum class OpParamType : unsigned {
#include "DXILOperation.inc"
};

// Set of DXIL op function attributes, stored as one boolean flag per
// DXIL_ATTRIBUTE(Name) entry expanded from DXILOperation.inc. Every flag
// defaults to false, so a default-constructed value is the empty set.
struct Attributes {
#define DXIL_ATTRIBUTE(Name) bool Name = false;
#include "DXILOperation.inc"
};

// Returns the member-wise union of two attribute sets: each flag in the result
// is set when it is set in either operand.
inline Attributes operator|(Attributes a, Attributes b) {
  Attributes Merged;
#define DXIL_ATTRIBUTE(Name) Merged.Name = a.Name || b.Name;
#include "DXILOperation.inc"
  return Merged;
}

// Merges the flags of `b` into `a` in place and returns `a`.
// The right-hand operand is taken by const reference so that temporaries and
// const values can be merged directly; it is never modified.
inline Attributes &operator|=(Attributes &a, const Attributes &b) {
  a = a | b;
  return a;
}

// Set of DXIL op properties, stored as one boolean flag per
// DXIL_PROPERTY(Name) entry expanded from DXILOperation.inc. Every flag
// defaults to false.
struct Properties {
#define DXIL_PROPERTY(Name) bool Name = false;
#include "DXILOperation.inc"
};

} // namespace dxil
} // namespace llvm

Expand Down
59 changes: 52 additions & 7 deletions llvm/lib/Target/DirectX/DXILOpBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@ struct OpStage {
uint32_t ValidStages;
};

struct OpAttribute {
Version DXILVersion;
uint32_t ValidAttrs;
};

static const char *getOverloadTypeName(OverloadKind Kind) {
switch (Kind) {
case OverloadKind::HALF:
Expand Down Expand Up @@ -158,7 +153,6 @@ struct OpCodeProperty {
unsigned OpCodeClassNameOffset;
llvm::SmallVector<OpOverload> Overloads;
llvm::SmallVector<OpStage> Stages;
llvm::SmallVector<OpAttribute> Attributes;
int OverloadParamIndex; // parameter index which control the overload.
// When < 0, should be only 1 overload type.
};
Expand Down Expand Up @@ -367,6 +361,51 @@ static std::optional<size_t> getPropIndex(ArrayRef<T> PropList,
return std::nullopt;
}

// Packs an opcode and a DXIL version into a single 64-bit key usable as a
// switch case label: the opcode occupies the bits from 32 upwards, the major
// version bits 16-31 and the minor version bits 0-15.
constexpr static uint64_t computeSwitchEnum(dxil::OpCode OpCode,
                                            uint16_t VersionMajor,
                                            uint16_t VersionMinor) {
  uint64_t OpCodePack = (uint64_t)OpCode;
  // Widen before shifting: a uint16_t operand is promoted to int, so shifting
  // a major version >= 0x8000 left by 16 would overflow int and sign-extend
  // into the opcode bits when converted to uint64_t.
  uint64_t MajorPack = (uint64_t)VersionMajor;
  return (OpCodePack << 32) | (MajorPack << 16) | VersionMinor;
}

// Collects the function attributes that apply to OpCode when targeting
// DXILVersion. Walks every version listed via DXIL_VERSION in
// DXILOperation.inc and ORs together the attribute sets declared (via
// DXIL_OP_ATTRIBUTES) for this opcode at each version that is not newer than
// DXILVersion. Returns a default (all-false) set when nothing matches.
static dxil::Attributes getDXILAttributes(dxil::OpCode OpCode,
VersionTuple DXILVersion) {
// List of versions, populated from the DXIL_VERSION entries of the .inc file.
SmallVector<Version> Versions = {
#define DXIL_VERSION(MAJOR, MINOR) {MAJOR, MINOR},
#include "DXILOperation.inc"
};

dxil::Attributes Attributes;
for (auto Version : Versions) {
// Skip versions that are newer than the requested target version.
if (DXILVersion < VersionTuple(Version.Major, Version.Minor))
continue;
// Each DXIL_OP_ATTRIBUTES entry expands to one case keyed on
// (opcode, major, minor); a matching case merges its attributes into the
// accumulated set. Keys without an entry fall through the switch untouched.
switch (computeSwitchEnum(OpCode, Version.Major, Version.Minor)) {
#define DXIL_OP_ATTRIBUTES(OpCode, VersionMajor, VersionMinor, ...) \
case computeSwitchEnum(OpCode, VersionMajor, VersionMinor): { \
auto Other = dxil::Attributes{__VA_ARGS__}; \
Attributes |= Other; \
break; \
};
#include "DXILOperation.inc"
}
}
return Attributes;
}

// Looks up the DXIL attributes that apply to OpCode at DXILVersion and applies
// the corresponding LLVM attributes to the call instruction CI. Only ReadNone,
// ReadOnly, NoReturn and NoDuplicate are mapped here.
static void setDXILAttributes(CallInst *CI, dxil::OpCode OpCode,
                              VersionTuple DXILVersion) {
  const dxil::Attributes Attrs = getDXILAttributes(OpCode, DXILVersion);

  // Memory-effect attributes.
  if (Attrs.ReadNone)
    CI->setDoesNotAccessMemory();
  if (Attrs.ReadOnly)
    CI->setOnlyReadsMemory();

  // Control-flow related attributes.
  if (Attrs.NoReturn)
    CI->setDoesNotReturn();
  if (Attrs.NoDuplicate)
    CI->setCannotDuplicate();
}

namespace llvm {
namespace dxil {

Expand Down Expand Up @@ -461,7 +500,13 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
OpArgs.push_back(IRB.getInt32(llvm::to_underlying(OpCode)));
OpArgs.append(Args.begin(), Args.end());

return IRB.CreateCall(DXILFn, OpArgs, Name);
// Create the function call instruction
CallInst *CI = IRB.CreateCall(DXILFn, OpArgs, Name);

// We then need to attach available function attributes
setDXILAttributes(CI, OpCode, DXILVersion);

return CI;
}

CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
Expand Down
24 changes: 13 additions & 11 deletions llvm/test/CodeGen/DirectX/BufferLoad.ll
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ define void @loadv4f32() {
; The temporary casts should all have been cleaned up
; CHECK-NOT: %dx.resource.casthandle

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) #[[#ATTR:]]
%load0 = call {<4 x float>, i1} @llvm.dx.resource.load.typedbuffer(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
%data0 = extractvalue {<4 x float>, i1} %load0, 0
Expand All @@ -34,7 +34,7 @@ define void @loadv4f32() {
call void @scalar_user(float %data0_0)
call void @scalar_user(float %data0_2)

; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef)
; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef) #[[#ATTR]]
%load4 = call {<4 x float>, i1} @llvm.dx.resource.load.typedbuffer(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 4)
%data4 = extractvalue {<4 x float>, i1} %load4, 0
Expand All @@ -49,7 +49,7 @@ define void @loadv4f32() {
; CHECK: insertelement <4 x float>
call void @vector_user(<4 x float> %data4)

; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef)
; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef) #[[#ATTR]]
%load12 = call {<4 x float>, i1} @llvm.dx.resource.load.typedbuffer(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12)
%data12 = extractvalue {<4 x float>, i1} %load12, 0
Expand All @@ -72,7 +72,7 @@ define void @index_dynamic(i32 %bufindex, i32 %elemindex) {
@llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4f32_0_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK: [[LOAD:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 %bufindex, i32 undef)
; CHECK: [[LOAD:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 %bufindex, i32 undef) #[[#ATTR]]
%load = call {<4 x float>, i1} @llvm.dx.resource.load.typedbuffer(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 %bufindex)
%data = extractvalue {<4 x float>, i1} %load, 0
Expand Down Expand Up @@ -108,7 +108,7 @@ define void @loadf32() {
@llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_f32_0_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) #[[#ATTR]]
%load0 = call {float, i1} @llvm.dx.resource.load.typedbuffer(
target("dx.TypedBuffer", float, 0, 0, 0) %buffer, i32 0)
%data0 = extractvalue {float, i1} %load0, 0
Expand All @@ -127,7 +127,7 @@ define void @loadv2f32() {
@llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2f32_0_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) #[[#ATTR]]
%data0 = call {<2 x float>, i1} @llvm.dx.resource.load.typedbuffer(
target("dx.TypedBuffer", <2 x float>, 0, 0, 0) %buffer, i32 0)

Expand All @@ -141,12 +141,12 @@ define void @loadv4f32_checkbit() {
@llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4f32_0_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) #[[#ATTR]]
%data0 = call {<4 x float>, i1} @llvm.dx.resource.load.typedbuffer.f32(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)

; CHECK: [[STATUS:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 4
; CHECK: [[MAPPED:%.*]] = call i1 @dx.op.checkAccessFullyMapped.i32(i32 71, i32 [[STATUS]]
; CHECK: [[MAPPED:%.*]] = call i1 @dx.op.checkAccessFullyMapped.i32(i32 71, i32 [[STATUS]]) #[[#ATTR]]
%check = extractvalue {<4 x float>, i1} %data0, 1

; CHECK: call void @check_user(i1 [[MAPPED]])
Expand All @@ -162,7 +162,7 @@ define void @loadv4i32() {
@llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4i32_0_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) #[[#ATTR]]
%data0 = call {<4 x i32>, i1} @llvm.dx.resource.load.typedbuffer(
target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0)

Expand All @@ -176,7 +176,7 @@ define void @loadv4f16() {
@llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4f16_0_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) #[[#ATTR]]
%data0 = call {<4 x half>, i1} @llvm.dx.resource.load.typedbuffer(
target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0)

Expand All @@ -190,9 +190,11 @@ define void @loadv4i16() {
@llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4i16_0_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) #[[#ATTR]]
%data0 = call {<4 x i16>, i1} @llvm.dx.resource.load.typedbuffer(
target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0)

ret void
}

; CHECK: attributes #[[#ATTR]] = {{{.*}} memory(read) {{.*}}}
Loading
Loading