Skip to content

Fully flatten RayDesc for all HL ops in ScalarReplHLSL #7441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/dxc/HLSL/HLOperations.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,8 @@ const unsigned kHitObjectMakeMissRayDescOpIdx = 4;

// HitObject::TraceRay
const unsigned kHitObjectTraceRay_RayDescOpIdx = 8;
const unsigned kHitObjectTraceRay_NumOp = 10;
const unsigned kHitObjectTraceRay_PayloadOpIdx = 9;
const unsigned kHitObjectTraceRay_NumOp = 13;

// HitObject::FromRayQuery
const unsigned kHitObjectFromRayQuery_WithAttrs_AttributeOpIdx = 4;
Expand Down
117 changes: 34 additions & 83 deletions lib/HLSL/HLOperationLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5720,37 +5720,26 @@ Value *TranslateCallShader(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
return Builder.CreateCall(F, {opArg, ShaderIndex, Parameter});
}

static unsigned LoadRayDescElementsIntoArgs(Value **Args, hlsl::OP *OP,
IRBuilder<> &Builder,
Value *RayDescPtr, unsigned Index) {
static void ExtractRayDescElementsIntoArgs(Value **Args, IRBuilder<> &Builder,
CallInst *CI, unsigned &DestIndex,
unsigned &SrcIndex) {
// struct RayDesc
//{
// float3 Origin;
Value *Origin = CI->getArgOperand(SrcIndex++);
Args[DestIndex++] = Builder.CreateExtractElement(Origin, (uint64_t)0);
Args[DestIndex++] = Builder.CreateExtractElement(Origin, 1);
Args[DestIndex++] = Builder.CreateExtractElement(Origin, 2);
// float TMin;
Args[DestIndex++] = CI->getArgOperand(SrcIndex++);
// float3 Direction;
Value *Direction = CI->getArgOperand(SrcIndex++);
Args[DestIndex++] = Builder.CreateExtractElement(Direction, (uint64_t)0);
Args[DestIndex++] = Builder.CreateExtractElement(Direction, 1);
Args[DestIndex++] = Builder.CreateExtractElement(Direction, 2);
// float TMax;
Args[DestIndex++] = CI->getArgOperand(SrcIndex++);
//};
Value *ZeroIdx = OP->GetU32Const(0);
Value *Origin = Builder.CreateGEP(RayDescPtr, {ZeroIdx, ZeroIdx});
Origin = Builder.CreateLoad(Origin);
Args[Index++] = Builder.CreateExtractElement(Origin, (uint64_t)0);
Args[Index++] = Builder.CreateExtractElement(Origin, 1);
Args[Index++] = Builder.CreateExtractElement(Origin, 2);

Value *TMinPtr = Builder.CreateGEP(RayDescPtr, {ZeroIdx, OP->GetU32Const(1)});
Args[Index++] = Builder.CreateLoad(TMinPtr);

Value *DirectionPtr =
Builder.CreateGEP(RayDescPtr, {ZeroIdx, OP->GetU32Const(2)});
Value *Direction = Builder.CreateLoad(DirectionPtr);

Args[Index++] = Builder.CreateExtractElement(Direction, (uint64_t)0);
Args[Index++] = Builder.CreateExtractElement(Direction, 1);
Args[Index++] = Builder.CreateExtractElement(Direction, 2);

Value *TMaxPtr = Builder.CreateGEP(RayDescPtr, {ZeroIdx, OP->GetU32Const(3)});
Args[Index++] = Builder.CreateLoad(TMaxPtr);
return Index;
}

Value *TranslateTraceRay(CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
Expand All @@ -5759,18 +5748,18 @@ Value *TranslateTraceRay(CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
bool &Translated) {
hlsl::OP *OP = &Helper.hlslOP;

Value *RayDesc = CI->getArgOperand(HLOperandIndex::kTraceRayRayDescOpIdx);
Value *PayLoad = CI->getArgOperand(HLOperandIndex::kTraceRayPayLoadOpIdx);

Value *Args[DXIL::OperandIndex::kTraceRayNumOp];
Args[0] = OP->GetU32Const(static_cast<unsigned>(OpCode));
for (unsigned i = 1; i < HLOperandIndex::kTraceRayRayDescOpIdx; i++)
Args[i] = CI->getArgOperand(i);
unsigned SrcIndex = 1;
unsigned DestIndex = 1;
for (; DestIndex < HLOperandIndex::kTraceRayRayDescOpIdx;
++DestIndex, ++SrcIndex)
Args[DestIndex] = CI->getArgOperand(SrcIndex);

IRBuilder<> Builder(CI);
LoadRayDescElementsIntoArgs(Args, OP, Builder, RayDesc,
DXIL::OperandIndex::kTraceRayRayDescOpIdx);
ExtractRayDescElementsIntoArgs(Args, Builder, CI, DestIndex, SrcIndex);

Value *PayLoad = CI->getArgOperand(SrcIndex);
Args[DXIL::OperandIndex::kTraceRayPayloadOpIdx] = PayLoad;

Type *Ty = PayLoad->getType();
Expand Down Expand Up @@ -5825,25 +5814,7 @@ Value *TranslateTraceRayInline(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
unsigned hlIndex = HLOperandIndex::kTraceRayInlineRayDescOpIdx;
unsigned index = DXIL::OperandIndex::kTraceRayInlineRayDescOpIdx;

// struct RayDesc
//{
// float3 Origin;
Value *origin = CI->getArgOperand(hlIndex++);
Args[index++] = Builder.CreateExtractElement(origin, (uint64_t)0);
Args[index++] = Builder.CreateExtractElement(origin, 1);
Args[index++] = Builder.CreateExtractElement(origin, 2);
// float TMin;
Args[index++] = CI->getArgOperand(hlIndex++);
// float3 Direction;
Value *direction = CI->getArgOperand(hlIndex++);
Args[index++] = Builder.CreateExtractElement(direction, (uint64_t)0);
Args[index++] = Builder.CreateExtractElement(direction, 1);
Args[index++] = Builder.CreateExtractElement(direction, 2);
// float TMax;
Args[index++] = CI->getArgOperand(hlIndex++);
//};

DXASSERT_NOMSG(index == DXIL::OperandIndex::kTraceRayInlineNumOp);
ExtractRayDescElementsIntoArgs(Args, Builder, CI, index, hlIndex);

Function *F = hlslOP->GetOpFunc(opcode, Builder.getVoidTy());

Expand Down Expand Up @@ -6201,13 +6172,13 @@ Value *TranslateHitObjectMake(CallInst *CI, IntrinsicOp IOP, OP::OpCode Opcode,
HLOperationLowerHelper &Helper,
HLObjectOperationLowerHelper *ObjHelper,
bool &Translated) {
hlsl::OP *HlslOP = &Helper.hlslOP;
hlsl::OP *OP = &Helper.hlslOP;
IRBuilder<> Builder(CI);
unsigned SrcIdx = 1;
Value *HitObjectPtr = CI->getArgOperand(SrcIdx++);
if (Opcode == OP::OpCode::HitObject_MakeNop) {
Value *HitObject = TrivialDxilOperation(
Opcode, {nullptr}, Type::getVoidTy(CI->getContext()), CI, HlslOP);
Opcode, {nullptr}, Type::getVoidTy(CI->getContext()), CI, OP);
Builder.CreateStore(HitObject, HitObjectPtr);
DXASSERT(
CI->use_empty(),
Expand All @@ -6217,35 +6188,17 @@ Value *TranslateHitObjectMake(CallInst *CI, IntrinsicOp IOP, OP::OpCode Opcode,

DXASSERT_NOMSG(CI->getNumArgOperands() ==
HLOperandIndex::kHitObjectMakeMiss_NumOp);
Value *RayFlags = CI->getArgOperand(SrcIdx++);
Value *MissShaderIdx = CI->getArgOperand(SrcIdx++);
DXASSERT_NOMSG(SrcIdx == HLOperandIndex::kHitObjectMakeMissRayDescOpIdx);
Value *RayDescOrigin = CI->getArgOperand(SrcIdx++);
Value *RayDescOriginX =
Builder.CreateExtractElement(RayDescOrigin, (uint64_t)0);
Value *RayDescOriginY =
Builder.CreateExtractElement(RayDescOrigin, (uint64_t)1);
Value *RayDescOriginZ =
Builder.CreateExtractElement(RayDescOrigin, (uint64_t)2);

Value *RayDescTMin = CI->getArgOperand(SrcIdx++);
Value *RayDescDirection = CI->getArgOperand(SrcIdx++);
Value *RayDescDirectionX =
Builder.CreateExtractElement(RayDescDirection, (uint64_t)0);
Value *RayDescDirectionY =
Builder.CreateExtractElement(RayDescDirection, (uint64_t)1);
Value *RayDescDirectionZ =
Builder.CreateExtractElement(RayDescDirection, (uint64_t)2);

Value *RayDescTMax = CI->getArgOperand(SrcIdx++);
DXASSERT_NOMSG(SrcIdx == CI->getNumArgOperands());
const unsigned DxilNumArgs = DxilInst_HitObject_MakeMiss::arg_TMax + 1;
unsigned DestIdx = 1;
Value *Args[DxilNumArgs];
Args[0] = nullptr; // OpCode
Args[DestIdx++] = CI->getArgOperand(SrcIdx++); // RayFlags
Args[DestIdx++] = CI->getArgOperand(SrcIdx++); // MissShaderIndex
ExtractRayDescElementsIntoArgs(Args, Builder, CI, DestIdx, SrcIdx);
DXASSERT_NOMSG(DestIdx == DxilNumArgs);

Value *OutHitObject = TrivialDxilOperation(
Opcode,
{nullptr, RayFlags, MissShaderIdx, RayDescOriginX, RayDescOriginY,
RayDescOriginZ, RayDescTMin, RayDescDirectionX, RayDescDirectionY,
RayDescDirectionZ, RayDescTMax},
Helper.voidTy, CI, HlslOP);
Value *OutHitObject =
TrivialDxilOperation(Opcode, Args, Helper.voidTy, CI, OP);
Builder.CreateStore(OutHitObject, HitObjectPtr);
return nullptr;
}
Expand Down Expand Up @@ -6363,12 +6316,10 @@ Value *TranslateHitObjectTraceRay(CallInst *CI, IntrinsicOp IOP,
Args[DestIdx] = CI->getArgOperand(SrcIdx);
}

Value *RayDescPtr = CI->getArgOperand(SrcIdx++);
DestIdx = LoadRayDescElementsIntoArgs(Args, OP, Builder, RayDescPtr, DestIdx);
ExtractRayDescElementsIntoArgs(Args, Builder, CI, DestIdx, SrcIdx);
Value *Payload = CI->getArgOperand(SrcIdx++);
Args[DestIdx++] = Payload;

DXASSERT_NOMSG(SrcIdx == CI->getNumArgOperands());
DXASSERT_NOMSG(DestIdx == DxilNumArgs);

Function *F = OP->GetOpFunc(OpCode, Payload->getType());
Expand Down
53 changes: 28 additions & 25 deletions lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1539,9 +1539,7 @@ void isSafeForScalarRepl(Instruction *I, uint64_t Offset, AllocaInfo &Info) {
// TODO: should we check HL parameter type for UDT overload instead of
// basing on IOP?
IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(CI));
if (IntrinsicOp::IOP_TraceRay == opcode ||
IntrinsicOp::MOP_DxHitObject_TraceRay == opcode ||
IntrinsicOp::MOP_DxHitObject_Invoke == opcode ||
if (IntrinsicOp::MOP_DxHitObject_Invoke == opcode ||
IntrinsicOp::IOP_ReportHit == opcode ||
IntrinsicOp::IOP_CallShader == opcode) {
return MarkUnsafe(Info, User);
Expand Down Expand Up @@ -2756,18 +2754,25 @@ void SROA_Helper::RewriteCall(CallInst *CI) {
/*loadElts*/ false);
DeadInsts.push_back(CI);
} break;
case IntrinsicOp::IOP_TraceRay: {
if (OldVal ==
CI->getArgOperand(HLOperandIndex::kTraceRayRayDescOpIdx)) {
RewriteCallArg(CI, HLOperandIndex::kTraceRayRayDescOpIdx,
/*bIn*/ true, /*bOut*/ false);
} else {
DXASSERT(OldVal ==
CI->getArgOperand(HLOperandIndex::kTraceRayPayLoadOpIdx),
"else invalid TraceRay");
RewriteCallArg(CI, HLOperandIndex::kTraceRayPayLoadOpIdx,
/*bIn*/ true, /*bOut*/ true);
case IntrinsicOp::IOP_TraceRay:
case IntrinsicOp::MOP_DxHitObject_TraceRay: {
const int RayDescIdx =
IOP == IntrinsicOp::IOP_TraceRay
? HLOperandIndex::kTraceRayRayDescOpIdx
: HLOperandIndex::kHitObjectTraceRay_RayDescOpIdx;
if (OldVal == CI->getArgOperand(RayDescIdx)) {
RewriteWithFlattenedHLIntrinsicCall(CI, OldVal, NewElts,
/*loadElts*/ true);
DeadInsts.push_back(CI);
break;
}
// Payload will always be the last argument. Actual param index depends
// on whether RayDesc was flattened before.
const int PayloadIdx = CI->getNumArgOperands() - 1;
DXASSERT(OldVal == CI->getArgOperand(PayloadIdx),
"else invalid HitObject::TraceRay");
RewriteCallArg(CI, PayloadIdx,
/*bIn*/ true, /*bOut*/ true);
} break;
case IntrinsicOp::IOP_ReportHit: {
RewriteCallArg(CI, HLOperandIndex::kReportIntersectionAttributeOpIdx,
Expand All @@ -2785,16 +2790,6 @@ void SROA_Helper::RewriteCall(CallInst *CI) {
DeadInsts.push_back(CI);
}
} break;
case IntrinsicOp::MOP_TraceRayInline: {
if (OldVal ==
CI->getArgOperand(HLOperandIndex::kTraceRayInlineRayDescOpIdx)) {
RewriteWithFlattenedHLIntrinsicCall(CI, OldVal, NewElts,
/*loadElts*/ true);
DeadInsts.push_back(CI);
break;
}
}
LLVM_FALLTHROUGH;
case IntrinsicOp::MOP_DxHitObject_FromRayQuery: {
const bool IsWithAttrs =
CI->getNumArgOperands() ==
Expand All @@ -2817,7 +2812,15 @@ void SROA_Helper::RewriteCall(CallInst *CI) {
RewriteWithFlattenedHLIntrinsicCall(CI, OldVal, NewElts,
/*loadElts*/ true);
DeadInsts.push_back(CI);
break;
} break;
case IntrinsicOp::MOP_TraceRayInline: {
if (OldVal ==
CI->getArgOperand(HLOperandIndex::kTraceRayInlineRayDescOpIdx)) {
RewriteWithFlattenedHLIntrinsicCall(CI, OldVal, NewElts,
/*loadElts*/ true);
DeadInsts.push_back(CI);
break;
}
}
default:
// RayQuery this pointer replacement.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: %dxc %s -T lib_6_9 -DHIT1=1 -DHIT2=1 -DHIT3=1 -DHIT4=1
// RUN: %dxc %s -T lib_6_9 -DHIT1=0 -DHIT2=1 -DHIT3=1 -DHIT4=0
// RUN: %dxc %s -T lib_6_9 -DHIT1=1 -DHIT2=0 -DHIT3=0 -DHIT4=1
// RUN: %dxc %s -T lib_6_9 -DHIT1=1 -DHIT2=0 -DHIT3=1 -DHIT4=0
// RUN: %dxc %s -T lib_6_9 -DHIT1=0 -DHIT2=1 -DHIT3=0 -DHIT4=1

// Ray payload type used by the TraceRay calls in raygen. It carries a single
// uint whose payload access qualifiers grant both read and write access to
// anyhit, closesthit, miss and caller.
struct[raypayload] PerRayData
{
uint dummy : read(anyhit,closesthit,miss,caller) : write(anyhit,miss,closesthit,caller);
};

// Attribute struct holding a float2 of barycentrics. It is declared here but
// not referenced by the raygen shader in this file.
struct Attrs
{
float2 barycentrics : BARYCENTRICS;
};

// Acceleration structure bound to register t0; passed as the first argument
// to the TraceRay / TraceRayInline calls in raygen.
RaytracingAccelerationStructure topObject : register(t0);

// Ray generation entry point. A single RayDesc is handed to up to four
// different ray intrinsics. Each use is gated by its own HITn macro so the
// RUN lines at the top of the file can enable different combinations; those
// RUN lines only invoke the compiler and do not check the output.
[shader("raygeneration")]
void raygen()
{
// Brace-initialized in RayDesc field order: Origin, TMin, Direction, TMax.
RayDesc ray = {{0, 1, 2}, 3, {4, 5, 6}, 7};

// Declared without an initializer; only used by the calls under HIT2 and HIT4.
PerRayData payload;
// HIT1: HitObject::MakeMiss takes the RayDesc directly.
#if HIT1
dx::HitObject hit1 = dx::HitObject::MakeMiss(RAY_FLAG_NONE, 0, ray);
dx::MaybeReorderThread(hit1);
#endif
// HIT2: HitObject::TraceRay takes both the RayDesc and the payload.
#if HIT2
dx::HitObject hit2 = dx::HitObject::TraceRay(topObject, RAY_FLAG_SKIP_TRIANGLES, 0xFF, 0, 0, 0, ray, payload);
dx::MaybeReorderThread(hit2);
#endif
// HIT3: RayQuery::TraceRayInline consumes the RayDesc; the query is then
// turned into a HitObject with HitObject::FromRayQuery.
#if HIT3
RayQuery<RAY_FLAG_NONE> rayQuery;
rayQuery.TraceRayInline(topObject, RAY_FLAG_NONE, 0xFF, ray);
dx::HitObject hit3 = dx::HitObject::FromRayQuery(rayQuery);
dx::MaybeReorderThread(hit3);
#endif
// HIT4: the plain TraceRay intrinsic with the same RayDesc and payload.
#if HIT4
TraceRay(topObject, RAY_FLAG_SKIP_TRIANGLES, 0xFF, 0, 0, 0, ray, payload);
#endif
}
Loading
Loading