Skip to content

Commit 1eea545

Browse files
authored
Fixed a crash with cbuffer lowering with new poison value mechanism (microsoft#4684)
There was a crash when CBuffer vector subscript uses a non-literal index but is replaced with Constant in the optimizer before CBuffer load is lowerd. The CBuffer lowering code expects the constant to be within bounds, since it's translating the GEP to extractvalue. The out of bounds subscript is not correct and should cause an error, but ONLY if it actually exists in the final DXIL. This change adds a way to emit error that only get emitted if the associated code is not removed by the end of compilation. Instead of emitting an error right away, emit a poison value with line number and error message instead and have it used by the problematic code. If the problematic code is not removed by the end of compilation, then this poison value would also be there, and it's safe to emit a real error based on it.
1 parent 69e6a84 commit 1eea545

File tree

10 files changed

+283
-14
lines changed

10 files changed

+283
-14
lines changed

include/dxc/HLSL/DxilPoisonValues.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
///////////////////////////////////////////////////////////////////////////////
2+
// //
3+
// DxilPoisonValues.h //
4+
// Copyright (C) Microsoft Corporation. All rights reserved. //
5+
// This file is distributed under the University of Illinois Open Source //
6+
// License. See LICENSE.TXT for details. //
7+
// //
8+
// Allows insertion of poisoned values with error messages that get //
9+
// cleaned up late in the compiler. //
10+
// //
11+
///////////////////////////////////////////////////////////////////////////////
12+
#pragma once
13+
14+
#include "llvm/IR/DebugLoc.h"
15+
16+
namespace llvm {
17+
class Instruction;
18+
class Type;
19+
class Value;
20+
class Module;
21+
}
22+
23+
namespace hlsl {
24+
// Create a special dx.poison.* instruction with debug location and an error message.
25+
// The reason for this is certain invalid code is allowed in the code as long as it is
26+
// removed by optimization at the end of compilation. We only want to emit the error
27+
// for real if we are sure the code with the problem is in the final DXIL.
28+
//
29+
// This "emits" an error message with the specified type. If by the end the compilation
30+
// it is still used, then FinalizePoisonValues will emit the correct error for real.
31+
llvm::Value *CreatePoisonValue(llvm::Type *ty, const llvm::Twine &errMsg, llvm::DebugLoc DL, llvm::Instruction *InsertPt);
32+
33+
// If there's any dx.poison.* values present in the module, emit them. Returns true
34+
// M was modified in any way.
35+
bool FinalizePoisonValues(llvm::Module &M);
36+
}
37+

lib/HLSL/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ add_llvm_library(LLVMHLSL
5858
PauseResumePasses.cpp
5959
WaveSensitivityAnalysis.cpp
6060
DxilNoOptLegalize.cpp
61+
DxilPoisonValues.cpp
6162
DxilDeleteRedundantDebugValues.cpp
6263

6364
ADDITIONAL_HEADER_DIRS

lib/HLSL/DxilPoisonValues.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
///////////////////////////////////////////////////////////////////////////////
2+
// //
3+
// DxilPoisonValues.cpp //
4+
// Copyright (C) Microsoft Corporation. All rights reserved. //
5+
// This file is distributed under the University of Illinois Open Source //
6+
// License. See LICENSE.TXT for details. //
7+
// //
8+
// Allows insertion of poisoned values with error messages that get //
9+
// cleaned up late in the compiler. //
10+
// //
11+
///////////////////////////////////////////////////////////////////////////////
12+
13+
#include "dxc/HLSL/DxilPoisonValues.h"
14+
15+
#include "llvm/Support/raw_ostream.h"
16+
#include "llvm/IR/DiagnosticInfo.h"
17+
#include "llvm/IR/Instructions.h"
18+
#include "llvm/IR/Constants.h"
19+
#include "llvm/IR/LLVMContext.h"
20+
21+
using namespace llvm;
22+
23+
constexpr const char kPoisonPrefix[] = "dx.poison.";
24+
25+
namespace hlsl {
26+
Value *CreatePoisonValue(Type *ty, const Twine &errMsg, DebugLoc DL, Instruction *InsertPt) {
27+
std::string functionName;
28+
{
29+
llvm::raw_string_ostream os(functionName);
30+
os << kPoisonPrefix;
31+
os << *ty;
32+
os.flush();
33+
}
34+
35+
Module &M = *InsertPt->getModule();
36+
37+
LLVMContext &C = M.getContext();
38+
Type *argTypes[] = { Type::getMetadataTy(C) };
39+
FunctionType *ft = FunctionType::get(ty, argTypes, false);
40+
Constant *f = M.getOrInsertFunction(functionName, ft);
41+
42+
std::string errMsgStr = errMsg.str();
43+
Value *args[] = {
44+
MetadataAsValue::get(C, MDString::get(C, errMsgStr))
45+
};
46+
CallInst *ret = CallInst::Create(f, ArrayRef<Value *>(args), "err", InsertPt);
47+
ret->setDebugLoc(DL);
48+
return ret;
49+
}
50+
51+
bool FinalizePoisonValues(Module &M) {
52+
bool changed = false;
53+
LLVMContext &Ctx = M.getContext();
54+
for (auto it = M.begin(); it != M.end();) {
55+
Function *F = &*(it++);
56+
if (F->getName().startswith(kPoisonPrefix)) {
57+
for (auto it = F->user_begin(); it != F->user_end();) {
58+
User *U = *(it++);
59+
CallInst *call = cast<CallInst>(U);
60+
MDString *errMsgMD = cast<MDString>(cast<MetadataAsValue>(call->getArgOperand(0))->getMetadata());
61+
StringRef errMsg = errMsgMD->getString();
62+
63+
Ctx.diagnose(DiagnosticInfoDxil(F, call->getDebugLoc(), errMsg, DS_Error));
64+
if (!call->getType()->isVoidTy())
65+
call->replaceAllUsesWith(UndefValue::get(call->getType()));
66+
call->eraseFromParent();
67+
}
68+
F->eraseFromParent();
69+
changed = true;
70+
}
71+
}
72+
return changed;
73+
}
74+
}
75+

lib/HLSL/DxilPreparePasses.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "dxc/DXIL/DxilInstructions.h"
2222
#include "dxc/DXIL/DxilConstants.h"
2323
#include "dxc/HlslIntrinsicOp.h"
24+
#include "dxc/HLSL/DxilPoisonValues.h"
2425
#include "llvm/IR/GetElementPtrTypeIterator.h"
2526
#include "llvm/IR/IRBuilder.h"
2627
#include "llvm/IR/Instructions.h"
@@ -757,6 +758,10 @@ class DxilFinalizeModule : public ModulePass {
757758
}
758759

759760
bool runOnModule(Module &M) override {
761+
762+
// Remove all the poisoned values and emit errors if necessary.
763+
(void)hlsl::FinalizePoisonValues(M);
764+
760765
if (M.HasDxilModule()) {
761766
DxilModule &DM = M.GetDxilModule();
762767
unsigned ValMajor = 0;

lib/HLSL/HLOperationLower.cpp

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "dxc/HLSL/HLOperations.h"
2727
#include "dxc/HlslIntrinsicOp.h"
2828
#include "dxc/DXIL/DxilResourceProperties.h"
29+
#include "dxc/HLSL/DxilPoisonValues.h"
2930

3031
#include "llvm/IR/GetElementPtrTypeIterator.h"
3132
#include "llvm/IR/IRBuilder.h"
@@ -135,6 +136,34 @@ struct HLObjectOperationLowerHelper {
135136
MarkHasCounterOnCreateHandle(counterHandle, resSet);
136137
}
137138

139+
DxilResourceBase *FindCBufferResourceFromHandle(Value *handle) {
140+
if (CallInst *CI = dyn_cast<CallInst>(handle)) {
141+
hlsl::HLOpcodeGroup group =
142+
hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
143+
if (group == HLOpcodeGroup::HLAnnotateHandle) {
144+
handle = CI->getArgOperand(HLOperandIndex::kAnnotateHandleHandleOpIdx);
145+
}
146+
}
147+
148+
Constant *symbol = nullptr;
149+
if (CallInst *CI = dyn_cast<CallInst>(handle)) {
150+
hlsl::HLOpcodeGroup group =
151+
hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
152+
if (group == HLOpcodeGroup::HLCreateHandle) {
153+
symbol = dyn_cast<Constant>(CI->getArgOperand(HLOperandIndex::kCreateHandleResourceOpIdx));
154+
}
155+
}
156+
157+
if (!symbol)
158+
return nullptr;
159+
160+
for (const std::unique_ptr<DxilCBuffer> &res : HLM.GetCBuffers()) {
161+
if (res->GetGlobalSymbol() == symbol)
162+
return res.get();
163+
}
164+
return nullptr;
165+
}
166+
138167
Value *GetOrCreateResourceForCbPtr(GetElementPtrInst *CbPtr,
139168
GlobalVariable *CbGV,
140169
DxilResourceProperties &RP) {
@@ -6816,31 +6845,33 @@ void TranslateCBGepLegacy(GetElementPtrInst *GEP, Value *handle,
68166845
// Array always start from x channel.
68176846
channel = 0;
68186847
} else if (GEPIt->isVectorTy()) {
6819-
unsigned size = DL.getTypeAllocSize(GEPIt->getVectorElementType());
68206848
// Indexing on vector.
68216849
if (bImmIdx) {
6822-
unsigned tempOffset = size * immIdx;
6823-
if (size == 2) { // 16-bit types
6824-
unsigned channelInc = tempOffset >> 1;
6825-
DXASSERT((channel + channelInc) <= 8, "vector should not cross cb register (8x16bit)");
6850+
if (immIdx < GEPIt->getVectorNumElements()) {
6851+
const unsigned vectorElmSize = DL.getTypeAllocSize(GEPIt->getVectorElementType());
6852+
const bool bIs16bitType = vectorElmSize == 2;
6853+
const unsigned tempOffset = vectorElmSize * immIdx;
6854+
const unsigned numChannelsPerRow = bIs16bitType ? 8 : 4;
6855+
const unsigned channelInc = bIs16bitType ? tempOffset >> 1 : tempOffset >> 2;
6856+
6857+
DXASSERT((channel + channelInc) < numChannelsPerRow, "vector should not cross cb register");
68266858
channel += channelInc;
6827-
if (channel == 8) {
6859+
if (channel == numChannelsPerRow) {
68286860
// Get to another row.
68296861
// Update index and channel.
68306862
channel = 0;
68316863
legacyIndex = Builder.CreateAdd(legacyIndex, Builder.getInt32(1));
68326864
}
68336865
}
68346866
else {
6835-
unsigned channelInc = tempOffset >> 2;
6836-
DXASSERT((channel + channelInc) <= 4, "vector should not cross cb register (8x32bit)");
6837-
channel += channelInc;
6838-
if (channel == 4) {
6839-
// Get to another row.
6840-
// Update index and channel.
6841-
channel = 0;
6842-
legacyIndex = Builder.CreateAdd(legacyIndex, Builder.getInt32(1));
6867+
StringRef resName = "(unknown)";
6868+
if (DxilResourceBase *Res = pObjHelper->FindCBufferResourceFromHandle(handle)) {
6869+
resName = Res->GetGlobalName();
68436870
}
6871+
legacyIndex = hlsl::CreatePoisonValue(legacyIndex->getType(),
6872+
Twine("Out of bounds index (") + Twine(immIdx) + Twine(") in CBuffer '") + Twine(resName) + ("'"),
6873+
GEP->getDebugLoc(), GEP);
6874+
channel = 0;
68446875
}
68456876
} else {
68466877
Type *EltTy = GEPIt->getVectorElementType();
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %dxc -E main -T ps_6_0 -Od %s | FileCheck %s
2+
3+
cbuffer constants : register(b0)
4+
{
5+
float4 foo;
6+
}
7+
8+
float main() : SV_TARGET
9+
{
10+
float ret = 0;
11+
uint index = 5;
12+
if (index < 4) {
13+
ret += foo[index];
14+
}
15+
return ret;
16+
}
17+
18+
// Regression test for Od when an out of bound access happens in a dead block.
19+
// We want to make sure it doesn't crash, compiles successfully, and the
20+
// cbuffer load doesn't actually get generated.
21+
22+
// CHECK: @main() {
23+
// CHECK-NOT: @dx.op.cbufferLoadLegacy.f32(i32 59
24+
// CHECK: call void @dx.op.storeOutput.f32(
25+
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: %dxc -E main -T ps_6_2 -enable-16bit-types -Od %s | FileCheck %s
2+
3+
cbuffer constants : register(b0)
4+
{
5+
half4 foo;
6+
}
7+
8+
float main() : SV_TARGET
9+
{
10+
float ret = 0;
11+
uint index = 9;
12+
if (index < 8) {
13+
ret += foo[index];
14+
}
15+
return ret;
16+
}
17+
18+
// Regression test for Od when an out of bound access happens in a dead block.
19+
// We want to make sure it doesn't crash, compiles successfully, and the
20+
// cbuffer load doesn't actually get generated.
21+
22+
// CHECK: @main() {
23+
// CHECK-NOT: @dx.op.cbufferLoadLegacy.f32(i32 59
24+
// CHECK: call void @dx.op.storeOutput.f32(
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: %dxc -E main -T ps_6_2 -enable-16bit-types -Od %s | FileCheck %s
2+
3+
cbuffer constants : register(b0)
4+
{
5+
half2 foo;
6+
}
7+
8+
float main() : SV_TARGET
9+
{
10+
float ret = 0;
11+
uint index = 2;
12+
if (index < 10) {
13+
ret += foo[index];
14+
}
15+
return ret;
16+
}
17+
18+
// Regression test for Od when an out of bound access happens but:
19+
// 1) the front-end can't immediate figure it out,
20+
// 2) when lowering the cbuffer load, the index is resolved to be a constant
21+
// 3) lowering code crashed because it assumed OOB access isn't possible when the index is constant.
22+
23+
// CHECK: 13:{{[0-9]+}}: error: Out of bounds index (2) in CBuffer 'constants'
24+
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %dxc -E main -T ps_6_0 -Od %s | FileCheck %s
2+
3+
cbuffer constants : register(b0)
4+
{
5+
float2 foo;
6+
}
7+
8+
float main() : SV_TARGET
9+
{
10+
float ret = 0;
11+
uint index = 2;
12+
if (index < 10) {
13+
ret += foo[index];
14+
}
15+
return ret;
16+
}
17+
18+
// Regression test for Od when an out of bound access happens but:
19+
// 1) the front-end can't immediate figure it out,
20+
// 2) when lowering the cbuffer load, the index is resolved to be a constant
21+
// 3) lowering code crashed because it assumed OOB access isn't possible when the index is constant.
22+
23+
// CHECK: 13:{{[0-9]+}}: error: Out of bounds index (2) in CBuffer 'constants'
24+
25+
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %dxc -E main -T ps_6_0 -Od %s | FileCheck %s
2+
3+
const float4 foo;
4+
5+
float main() : SV_TARGET
6+
{
7+
float ret = 0;
8+
uint index = 5;
9+
if (index < 10) {
10+
ret += foo[index];
11+
}
12+
return ret;
13+
}
14+
15+
// Regression test for Od when an out of bound access happens but:
16+
// 1) the front-end can't immediate figure it out,
17+
// 2) when lowering the cbuffer load, the index is resolved to be a constant
18+
// 3) lowering code crashed because it assumed OOB access isn't possible when the index is constant.
19+
20+
// CHECK: 10:{{[0-9]+}}: error: Out of bounds index (5) in CBuffer '$Globals'
21+
22+

0 commit comments

Comments
 (0)