Skip to content

Commit 81ef064

Browse files
authored
Fix broken ternary on matrix return type (microsoft#4434) (microsoft#4460)
HLSL ternary operators that result in vector or matrix types need special handling even if the condition is not a vector. This change allows vector and matrix result types for ternary operators even when matrix and vector conditions are not allowed. It works by generating an alloca before the ternary blocks and ending each ternary block with a store of the resulting matrix to that alloca. The result of the ternary is then a load of the matrix from the alloca. Fixes microsoft#4434.
1 parent f0b5fd2 commit 81ef064

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

tools/clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
#include "TargetInfo.h"
2121
#include "clang/AST/ASTContext.h"
2222
#include "clang/AST/DeclObjC.h"
23+
#include "clang/AST/HlslTypes.h"
2324
#include "clang/AST/RecordLayout.h"
2425
#include "clang/AST/StmtVisitor.h"
2526
#include "clang/Basic/TargetInfo.h"
2627
#include "clang/Frontend/CodeGenOptions.h"
28+
#include "dxc/DXIL/DxilUtil.h" // HLSL Change
2729
#include "llvm/IR/CFG.h"
2830
#include "llvm/IR/Constants.h"
2931
#include "llvm/IR/DataLayout.h"
@@ -3701,7 +3703,7 @@ VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
37013703
}
37023704
// HLSL Change Starts
37033705
if (CGF.getLangOpts().HLSL && !CGF.getLangOpts().EnableShortCircuit) {
3704-
// HLSL does not short circuit by default.
3706+
// HLSL does not short circuit by default before HLSL 2021
37053707
if (hlsl::IsHLSLVecType(E->getType()) || E->getType()->isArithmeticType()) {
37063708
llvm::Value *CondV = CGF.EmitScalarExpr(condExpr);
37073709
llvm::Value *LHS = Visit(lhsExpr);
@@ -3729,6 +3731,7 @@ VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
37293731
CGF, E, LHS->getType(), {Cond, LHS, RHS});
37303732
}
37313733
}
3734+
37323735
// HLSL Change Ends
37333736

37343737
// If this is a really simple expression (like x ? 4 : 5), emit this as a
@@ -3749,6 +3752,17 @@ VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
37493752
return Builder.CreateSelect(CondV, LHS, RHS, "cond");
37503753
}
37513754

3755+
// HLSL Change Begins
3756+
llvm::Instruction *ResultAlloca = nullptr;
3757+
if (CGF.getLangOpts().HLSL && CGF.getLangOpts().EnableShortCircuit &&
3758+
hlsl::IsHLSLMatType(E->getType())) {
3759+
llvm::Type *MatTy = CGF.ConvertTypeForMem(E->getType());
3760+
ResultAlloca = CGF.CreateTempAlloca(MatTy);
3761+
ResultAlloca->moveBefore(hlsl::dxilutil::FindAllocaInsertionPt(
3762+
Builder.GetInsertBlock()->getParent()));
3763+
}
3764+
// HLSL Change Ends
3765+
37523766
llvm::BasicBlock *LHSBlock = CGF.createBasicBlock("cond.true");
37533767
llvm::BasicBlock *RHSBlock = CGF.createBasicBlock("cond.false");
37543768
llvm::BasicBlock *ContBlock = CGF.createBasicBlock("cond.end");
@@ -3761,6 +3775,11 @@ VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
37613775
CGF.incrementProfileCounter(E);
37623776
eval.begin(CGF);
37633777
Value *LHS = Visit(lhsExpr);
3778+
// HLSL Change Begin - Handle matrix ternary
3779+
if (ResultAlloca)
3780+
CGF.CGM.getHLSLRuntime().EmitHLSLMatrixStore(CGF, LHS, ResultAlloca,
3781+
E->getType());
3782+
// HLSL Change End
37643783
eval.end(CGF);
37653784

37663785
LHSBlock = Builder.GetInsertBlock();
@@ -3769,11 +3788,22 @@ VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
37693788
CGF.EmitBlock(RHSBlock);
37703789
eval.begin(CGF);
37713790
Value *RHS = Visit(rhsExpr);
3791+
// HLSL Change Begin - Handle matrix ternary
3792+
if (ResultAlloca)
3793+
CGF.CGM.getHLSLRuntime().EmitHLSLMatrixStore(CGF, RHS, ResultAlloca,
3794+
E->getType());
3795+
// HLSL Change End
37723796
eval.end(CGF);
37733797

37743798
RHSBlock = Builder.GetInsertBlock();
37753799
CGF.EmitBlock(ContBlock);
37763800

3801+
// HLSL Change Begin - Handle matrix ternary
3802+
if (ResultAlloca)
3803+
return CGF.CGM.getHLSLRuntime().EmitHLSLMatrixLoad(CGF, ResultAlloca,
3804+
E->getType());
3805+
// HLSL Change End
3806+
37773807
// If the LHS or RHS is a throw expression, it will be legitimately null.
37783808
if (!LHS)
37793809
return RHS;
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %dxc -T cs_6_0 -E CSMain -HV 2021 %s -fcgl | FileCheck %s
2+
3+
// Regression case for the matrix-typed ternary: builds two zero-initialized
// float2x2 matrices and returns one of them, selected by the scalar bool `b`.
// The ternary's condition is a scalar while its result type is a matrix.
float2x2 crashingFunction(bool b) {
4+
float2x2 x = {0.0, 0.0, 0.0, 0.0};
5+
float2x2 y = {0.0, 0.0, 0.0, 0.0};
6+
return b ? x : y; // <-- this is the issue: scalar condition, float2x2 result
7+
}
8+
9+
// Compute-shader entry point (named by `-E CSMain` on the RUN line). Its only
// work is to call crashingFunction(true) and compare element [0][0] of the
// returned matrix against 0.
[numthreads(1, 1, 1)] void CSMain() {
10+
if (crashingFunction(true)[0][0] > 0)
11+
return;
12+
}
13+
14+
// CHECK: define internal %class.matrix.float.2.2 @"\01?crashingFunction@@YA?AV?$matrix@M$01$01@@_N@Z"
15+
// CHECK: [[ALLOCA:%[0-9a-z]+]] = alloca %class.matrix.float.2.2
16+
// CHECK: preds = {{%[0-9a-z]+}}
17+
// CHECK: call %class.matrix.float.2.2 @"dx.hl.matldst.colStore.%class.matrix.float.2.2 (i32, %class.matrix.float.2.2*, %class.matrix.float.2.2)"(i32 1, %class.matrix.float.2.2* [[ALLOCA]], %class.matrix.float.2.2 %{{[0-9]+}})
18+
// CHECK: preds = {{%[0-9a-z]+}}
19+
// CHECK: call %class.matrix.float.2.2 @"dx.hl.matldst.colStore.%class.matrix.float.2.2 (i32, %class.matrix.float.2.2*, %class.matrix.float.2.2)"(i32 1, %class.matrix.float.2.2* [[ALLOCA]], %class.matrix.float.2.2 %{{[0-9]+}})
20+
// CHECK: preds = {{%[0-9a-z.]+}}, {{%[0-9a-z.]+}}
21+
// CHECK: call %class.matrix.float.2.2 @"dx.hl.matldst.colLoad.%class.matrix.float.2.2 (i32, %class.matrix.float.2.2*)"(i32 0, %class.matrix.float.2.2* [[ALLOCA]])
22+

0 commit comments

Comments
 (0)