Skip to content

Commit 1028410

Browse files
authored
[SPIR-V] Fix invalid isnan codegen (microsoft#6754)
IsNan returns a boolean, even is the input-type is a float. This was working in most cases except: - if the layout was not Void - if the input type was not a matrix The first bug is because a bool memory layout/representation is not specified, and shall never be exposed to externaly-accessible memory. Hence, if we saw a layout rule != Void, we converted it to a UINT. When calling isnan, the layout rule should not be propagated as we loose any layout info. The second is because our codegen assumed matrix operations returned a matrix with the same type as the input parameters. In the case of isnan, this was just wrong. Fixes microsoft#6712 Signed-off-by: Nathan Gauër <[email protected]>
1 parent 6f1c8e2 commit 1028410

File tree

3 files changed

+148
-62
lines changed

3 files changed

+148
-62
lines changed

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 90 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6286,9 +6286,10 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
62866286
if (isMxNMatrix(subType)) {
62876287
// For matrices, we can only increment/decrement each vector of it.
62886288
const auto actOnEachVec = [this, spvOp, one, expr,
6289-
range](uint32_t /*index*/, QualType vecType,
6289+
range](uint32_t /*index*/, QualType inType,
6290+
QualType outType,
62906291
SpirvInstruction *lhsVec) {
6291-
auto *val = spvBuilder.createBinaryOp(spvOp, vecType, lhsVec, one,
6292+
auto *val = spvBuilder.createBinaryOp(spvOp, outType, lhsVec, one,
62926293
expr->getOperatorLoc(), range);
62936294
if (val)
62946295
val->setRValue();
@@ -6356,9 +6357,10 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
63566357
if (isMxNMatrix(subType)) {
63576358
// For matrices, we can only negate each vector of it.
63586359
const auto actOnEachVec = [this, spvOp, expr,
6359-
range](uint32_t /*index*/, QualType vecType,
6360+
range](uint32_t /*index*/, QualType inType,
6361+
QualType outType,
63606362
SpirvInstruction *lhsVec) {
6361-
return spvBuilder.createUnaryOp(spvOp, vecType, lhsVec,
6363+
return spvBuilder.createUnaryOp(spvOp, outType, lhsVec,
63626364
expr->getOperatorLoc(), range);
63636365
};
63646366
return processEachVectorInMatrix(subExpr, subValue, actOnEachVec,
@@ -7929,27 +7931,39 @@ void SpirvEmitter::assignToMSOutIndices(
79297931

79307932
SpirvInstruction *SpirvEmitter::processEachVectorInMatrix(
79317933
const Expr *matrix, SpirvInstruction *matrixVal,
7932-
llvm::function_ref<SpirvInstruction *(uint32_t, QualType,
7934+
llvm::function_ref<SpirvInstruction *(uint32_t, QualType, QualType,
7935+
SpirvInstruction *)>
7936+
actOnEachVector,
7937+
SourceLocation loc, SourceRange range) {
7938+
return processEachVectorInMatrix(matrix, matrix->getType(), matrixVal,
7939+
actOnEachVector, loc, range);
7940+
}
7941+
7942+
SpirvInstruction *SpirvEmitter::processEachVectorInMatrix(
7943+
const Expr *matrix, QualType outputType, SpirvInstruction *matrixVal,
7944+
llvm::function_ref<SpirvInstruction *(uint32_t, QualType, QualType,
79337945
SpirvInstruction *)>
79347946
actOnEachVector,
79357947
SourceLocation loc, SourceRange range) {
79367948
const auto matType = matrix->getType();
7937-
assert(isMxNMatrix(matType));
7938-
const QualType vecType = getComponentVectorType(astContext, matType);
7949+
assert(isMxNMatrix(matType) && isMxNMatrix(outputType));
7950+
const QualType inVecType = getComponentVectorType(astContext, matType);
7951+
const QualType outVecType = getComponentVectorType(astContext, outputType);
79397952

79407953
uint32_t rowCount = 0, colCount = 0;
79417954
hlsl::GetHLSLMatRowColCount(matType, rowCount, colCount);
79427955

79437956
llvm::SmallVector<SpirvInstruction *, 4> vectors;
79447957
// Extract each component vector and do operation on it
79457958
for (uint32_t i = 0; i < rowCount; ++i) {
7946-
auto *lhsVec = spvBuilder.createCompositeExtract(vecType, matrixVal, {i},
7959+
auto *lhsVec = spvBuilder.createCompositeExtract(inVecType, matrixVal, {i},
79477960
matrix->getLocStart());
7948-
vectors.push_back(actOnEachVector(i, vecType, lhsVec));
7961+
vectors.push_back(actOnEachVector(i, inVecType, outVecType, lhsVec));
79497962
}
79507963

79517964
// Construct the result matrix
7952-
auto *val = spvBuilder.createCompositeConstruct(matType, vectors, loc, range);
7965+
auto *val =
7966+
spvBuilder.createCompositeConstruct(outputType, vectors, loc, range);
79537967
if (!val)
79547968
return nullptr;
79557969
val->setRValue();
@@ -8056,15 +8070,15 @@ SpirvEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
80568070
case BO_MulAssign:
80578071
case BO_DivAssign:
80588072
case BO_RemAssign: {
8059-
const auto actOnEachVec = [this, spvOp, rhsVal, rhs, loc,
8060-
range](uint32_t index, QualType vecType,
8061-
SpirvInstruction *lhsVec) {
8073+
const auto actOnEachVec = [this, spvOp, rhsVal, rhs, loc, range](
8074+
uint32_t index, QualType inType,
8075+
QualType outType, SpirvInstruction *lhsVec) {
80628076
// For each vector of lhs, we need to load the corresponding vector of
80638077
// rhs and do the operation on them.
8064-
auto *rhsVec = spvBuilder.createCompositeExtract(vecType, rhsVal, {index},
8078+
auto *rhsVec = spvBuilder.createCompositeExtract(inType, rhsVal, {index},
80658079
rhs->getLocStart());
80668080
auto *val =
8067-
spvBuilder.createBinaryOp(spvOp, vecType, lhsVec, rhsVec, loc, range);
8081+
spvBuilder.createBinaryOp(spvOp, outType, lhsVec, rhsVec, loc, range);
80688082
if (val)
80698083
val->setRValue();
80708084
return val;
@@ -9066,6 +9080,15 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
90669080
case hlsl::IntrinsicOp::IOP_firstbitlow: {
90679081
retVal = processIntrinsicFirstbit(callExpr, GLSLstd450::GLSLstd450FindILsb);
90689082
break;
9083+
}
9084+
case hlsl::IntrinsicOp::IOP_isnan: {
9085+
retVal = processIntrinsicUsingSpirvInst(callExpr, spv::Op::OpIsNan,
9086+
/* doEachVec= */ true);
9087+
// OpIsNan returns a bool/vec<bool>, so the only valid layout is void. It
9088+
// will be the responsibility of the store to do an OpSelect and correctly
9089+
// convert this type to an externally storable type.
9090+
retVal->setLayoutRule(SpirvLayoutRule::Void);
9091+
break;
90699092
}
90709093
INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
90719094
INTRINSIC_SPIRV_OP_CASE(ddx_coarse, DPdxCoarse, false);
@@ -9075,7 +9098,6 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
90759098
INTRINSIC_SPIRV_OP_CASE(ddy_fine, DPdyFine, false);
90769099
INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false);
90779100
INTRINSIC_SPIRV_OP_CASE(isinf, IsInf, true);
9078-
INTRINSIC_SPIRV_OP_CASE(isnan, IsNan, true);
90799101
INTRINSIC_SPIRV_OP_CASE(fmod, FRem, true);
90809102
INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true);
90819103
INTRINSIC_SPIRV_OP_CASE(reversebits, BitReverse, false);
@@ -10030,14 +10052,15 @@ SpirvInstruction *SpirvEmitter::processIntrinsicMad(const CallExpr *callExpr) {
1003010052
if (isMxNMatrix(arg0->getType())) {
1003110053
const auto actOnEachVec = [this, loc, arg1Instr, arg2Instr, arg1Loc,
1003210054
arg2Loc,
10033-
range](uint32_t index, QualType vecType,
10055+
range](uint32_t index, QualType inType,
10056+
QualType outType,
1003410057
SpirvInstruction *arg0Row) {
1003510058
auto *arg1Row = spvBuilder.createCompositeExtract(
10036-
vecType, arg1Instr, {index}, arg1Loc, range);
10059+
inType, arg1Instr, {index}, arg1Loc, range);
1003710060
auto *arg2Row = spvBuilder.createCompositeExtract(
10038-
vecType, arg2Instr, {index}, arg2Loc, range);
10061+
inType, arg2Instr, {index}, arg2Loc, range);
1003910062
auto *fma = spvBuilder.createGLSLExtInst(
10040-
vecType, GLSLstd450Fma, {arg0Row, arg1Row, arg2Row}, loc, range);
10063+
outType, GLSLstd450Fma, {arg0Row, arg1Row, arg2Row}, loc, range);
1004110064
spvBuilder.decorateNoContraction(fma, loc);
1004210065
return fma;
1004310066
};
@@ -10257,13 +10280,14 @@ SpirvEmitter::processIntrinsicLdexp(const CallExpr *callExpr) {
1025710280
uint32_t rowCount = 0, colCount = 0;
1025810281
if (isMxNMatrix(paramType, nullptr, &rowCount, &colCount)) {
1025910282
const auto actOnEachVec = [this, loc, expInstr, arg1Loc,
10260-
range](uint32_t index, QualType vecType,
10283+
range](uint32_t index, QualType inType,
10284+
QualType outType,
1026110285
SpirvInstruction *xRowInstr) {
1026210286
auto *expRowInstr = spvBuilder.createCompositeExtract(
10263-
vecType, expInstr, {index}, arg1Loc, range);
10287+
inType, expInstr, {index}, arg1Loc, range);
1026410288
auto *twoExp = spvBuilder.createGLSLExtInst(
10265-
vecType, GLSLstd450::GLSLstd450Exp2, {expRowInstr}, loc, range);
10266-
return spvBuilder.createBinaryOp(spv::Op::OpFMul, vecType, xRowInstr,
10289+
outType, GLSLstd450::GLSLstd450Exp2, {expRowInstr}, loc, range);
10290+
return spvBuilder.createBinaryOp(spv::Op::OpFMul, outType, xRowInstr,
1026710291
twoExp, loc, range);
1026810292
};
1026910293
return processEachVectorInMatrix(x, xInstr, actOnEachVec, loc, range);
@@ -10427,15 +10451,15 @@ SpirvEmitter::processIntrinsicClamp(const CallExpr *callExpr) {
1042710451
// the operation on each vector of the matrix.
1042810452
if (isMxNMatrix(argX->getType())) {
1042910453
const auto actOnEachVec = [this, loc, range, glslOpcode, argMinInstr,
10430-
argMaxInstr, argMinLoc,
10431-
argMaxLoc](uint32_t index, QualType vecType,
10432-
SpirvInstruction *curRow) {
10454+
argMaxInstr, argMinLoc, argMaxLoc](
10455+
uint32_t index, QualType inType,
10456+
QualType outType, SpirvInstruction *curRow) {
1043310457
auto *minRowInstr = spvBuilder.createCompositeExtract(
10434-
vecType, argMinInstr, {index}, argMinLoc, range);
10458+
inType, argMinInstr, {index}, argMinLoc, range);
1043510459
auto *maxRowInstr = spvBuilder.createCompositeExtract(
10436-
vecType, argMaxInstr, {index}, argMaxLoc, range);
10460+
inType, argMaxInstr, {index}, argMaxLoc, range);
1043710461
return spvBuilder.createGLSLExtInst(
10438-
vecType, glslOpcode, {curRow, minRowInstr, maxRowInstr}, loc, range);
10462+
outType, glslOpcode, {curRow, minRowInstr, maxRowInstr}, loc, range);
1043910463
};
1044010464
return processEachVectorInMatrix(argX, argXInstr, actOnEachVec, loc, range);
1044110465
}
@@ -11013,12 +11037,12 @@ SpirvInstruction *SpirvEmitter::processIntrinsicRcp(const CallExpr *callExpr) {
1101311037
uint32_t numRows = 0, numCols = 0;
1101411038
if (isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
1101511039
auto *vecOne = getVecValueOne(elemType, numCols);
11016-
const auto actOnEachVec = [this, vecOne, loc,
11017-
range](uint32_t /*index*/, QualType vecType,
11018-
SpirvInstruction *curRow) {
11019-
return spvBuilder.createBinaryOp(spv::Op::OpFDiv, vecType, vecOne, curRow,
11020-
loc, range);
11021-
};
11040+
const auto actOnEachVec =
11041+
[this, vecOne, loc, range](uint32_t /*index*/, QualType inType,
11042+
QualType outType, SpirvInstruction *curRow) {
11043+
return spvBuilder.createBinaryOp(spv::Op::OpFDiv, outType, vecOne,
11044+
curRow, loc, range);
11045+
};
1102211046
return processEachVectorInMatrix(arg, argId, actOnEachVec, loc, range);
1102311047
}
1102411048

@@ -11335,10 +11359,10 @@ SpirvEmitter::processIntrinsicSaturate(const CallExpr *callExpr) {
1133511359
if (isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
1133611360
auto *vecZero = getVecValueZero(elemType, numCols);
1133711361
auto *vecOne = getVecValueOne(elemType, numCols);
11338-
const auto actOnEachVec = [this, loc, vecZero, vecOne,
11339-
range](uint32_t /*index*/, QualType vecType,
11340-
SpirvInstruction *curRow) {
11341-
return spvBuilder.createGLSLExtInst(vecType, GLSLstd450::GLSLstd450FClamp,
11362+
const auto actOnEachVec = [this, loc, vecZero, vecOne, range](
11363+
uint32_t /*index*/, QualType inType,
11364+
QualType outType, SpirvInstruction *curRow) {
11365+
return spvBuilder.createGLSLExtInst(outType, GLSLstd450::GLSLstd450FClamp,
1134211366
{curRow, vecZero, vecOne}, loc,
1134311367
range);
1134411368
};
@@ -11364,10 +11388,10 @@ SpirvEmitter::processIntrinsicFloatSign(const CallExpr *callExpr) {
1136411388

1136511389
// For matrices, we can perform the instruction on each vector of the matrix.
1136611390
if (isMxNMatrix(argType)) {
11367-
const auto actOnEachVec = [this, loc, range](uint32_t /*index*/,
11368-
QualType vecType,
11369-
SpirvInstruction *curRow) {
11370-
return spvBuilder.createGLSLExtInst(vecType, GLSLstd450::GLSLstd450FSign,
11391+
const auto actOnEachVec = [this, loc, range](
11392+
uint32_t /*index*/, QualType inType,
11393+
QualType outType, SpirvInstruction *curRow) {
11394+
return spvBuilder.createGLSLExtInst(outType, GLSLstd450::GLSLstd450FSign,
1137111395
{curRow}, loc, range);
1137211396
};
1137311397
floatSign = processEachVectorInMatrix(arg, argId, actOnEachVec, loc, range);
@@ -11496,12 +11520,15 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingSpirvInst(
1149611520
// If the instruction does not operate on matrices, we can perform the
1149711521
// instruction on each vector of the matrix.
1149811522
if (actPerRowForMatrices && isMxNMatrix(arg->getType())) {
11523+
assert(isMxNMatrix(returnType));
1149911524
const auto actOnEachVec = [this, opcode, loc,
11500-
range](uint32_t /*index*/, QualType vecType,
11525+
range](uint32_t /*index*/, QualType inType,
11526+
QualType outType,
1150111527
SpirvInstruction *curRow) {
11502-
return spvBuilder.createUnaryOp(opcode, vecType, curRow, loc, range);
11528+
return spvBuilder.createUnaryOp(opcode, outType, curRow, loc, range);
1150311529
};
11504-
return processEachVectorInMatrix(arg, argId, actOnEachVec, loc, range);
11530+
return processEachVectorInMatrix(arg, returnType, argId, actOnEachVec,
11531+
loc, range);
1150511532
}
1150611533
return spvBuilder.createUnaryOp(opcode, returnType, argId, loc, range);
1150711534
} else if (callExpr->getNumArgs() == 2u) {
@@ -11514,11 +11541,12 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingSpirvInst(
1151411541
// instruction on each vector of the matrix.
1151511542
if (actPerRowForMatrices && isMxNMatrix(arg0->getType())) {
1151611543
const auto actOnEachVec = [this, opcode, arg1Id, loc, range, arg1Loc,
11517-
arg1Range](uint32_t index, QualType vecType,
11544+
arg1Range](uint32_t index, QualType inType,
11545+
QualType outType,
1151811546
SpirvInstruction *arg0Row) {
1151911547
auto *arg1Row = spvBuilder.createCompositeExtract(
11520-
vecType, arg1Id, {index}, arg1Loc, arg1Range);
11521-
return spvBuilder.createBinaryOp(opcode, vecType, arg0Row, arg1Row, loc,
11548+
inType, arg1Id, {index}, arg1Loc, arg1Range);
11549+
return spvBuilder.createBinaryOp(opcode, outType, arg0Row, arg1Row, loc,
1152211550
range);
1152311551
};
1152411552
return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec, loc, range);
@@ -11546,9 +11574,10 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingGLSLInst(
1154611574
// instruction on each vector of the matrix.
1154711575
if (actPerRowForMatrices && isMxNMatrix(arg->getType())) {
1154811576
const auto actOnEachVec = [this, loc, range,
11549-
opcode](uint32_t /*index*/, QualType vecType,
11577+
opcode](uint32_t /*index*/, QualType inType,
11578+
QualType outType,
1155011579
SpirvInstruction *curRowInstr) {
11551-
return spvBuilder.createGLSLExtInst(vecType, opcode, {curRowInstr}, loc,
11580+
return spvBuilder.createGLSLExtInst(outType, opcode, {curRowInstr}, loc,
1155211581
range);
1155311582
};
1155411583
return processEachVectorInMatrix(arg, argInstr, actOnEachVec, loc, range);
@@ -11565,12 +11594,13 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingGLSLInst(
1156511594
// instruction on each vector of the matrix.
1156611595
if (actPerRowForMatrices && isMxNMatrix(arg0->getType())) {
1156711596
const auto actOnEachVec = [this, loc, range, opcode, arg1Instr, arg1Range,
11568-
arg1Loc](uint32_t index, QualType vecType,
11597+
arg1Loc](uint32_t index, QualType inType,
11598+
QualType outType,
1156911599
SpirvInstruction *arg0RowInstr) {
1157011600
auto *arg1RowInstr = spvBuilder.createCompositeExtract(
11571-
vecType, arg1Instr, {index}, arg1Loc, arg1Range);
11601+
inType, arg1Instr, {index}, arg1Loc, arg1Range);
1157211602
return spvBuilder.createGLSLExtInst(
11573-
vecType, opcode, {arg0RowInstr, arg1RowInstr}, loc, range);
11603+
outType, opcode, {arg0RowInstr, arg1RowInstr}, loc, range);
1157411604
};
1157511605
return processEachVectorInMatrix(arg0, arg0Instr, actOnEachVec, loc,
1157611606
range);
@@ -11591,14 +11621,15 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingGLSLInst(
1159111621
if (actPerRowForMatrices && isMxNMatrix(arg0->getType())) {
1159211622
const auto actOnEachVec = [this, loc, range, opcode, arg1Instr, arg2Instr,
1159311623
arg1Loc, arg2Loc, arg1Range,
11594-
arg2Range](uint32_t index, QualType vecType,
11624+
arg2Range](uint32_t index, QualType inType,
11625+
QualType outType,
1159511626
SpirvInstruction *arg0RowInstr) {
1159611627
auto *arg1RowInstr = spvBuilder.createCompositeExtract(
11597-
vecType, arg1Instr, {index}, arg1Loc, arg1Range);
11628+
inType, arg1Instr, {index}, arg1Loc, arg1Range);
1159811629
auto *arg2RowInstr = spvBuilder.createCompositeExtract(
11599-
vecType, arg2Instr, {index}, arg2Loc, arg2Range);
11630+
inType, arg2Instr, {index}, arg2Loc, arg2Range);
1160011631
return spvBuilder.createGLSLExtInst(
11601-
vecType, opcode, {arg0RowInstr, arg1RowInstr, arg2RowInstr}, loc,
11632+
outType, opcode, {arg0RowInstr, arg1RowInstr, arg2RowInstr}, loc,
1160211633
range);
1160311634
};
1160411635
return processEachVectorInMatrix(arg0, arg0Instr, actOnEachVec, loc,

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,14 @@ class SpirvEmitter : public ASTConsumer {
361361
/// the value. It returns the <result-id> of the processed vector.
362362
SpirvInstruction *processEachVectorInMatrix(
363363
const Expr *matrix, SpirvInstruction *matrixVal,
364-
llvm::function_ref<SpirvInstruction *(uint32_t, QualType,
364+
llvm::function_ref<SpirvInstruction *(uint32_t, QualType, QualType,
365+
SpirvInstruction *)>
366+
actOnEachVector,
367+
SourceLocation loc = {}, SourceRange range = {});
368+
369+
SpirvInstruction *processEachVectorInMatrix(
370+
const Expr *matrix, QualType outputType, SpirvInstruction *matrixVal,
371+
llvm::function_ref<SpirvInstruction *(uint32_t, QualType, QualType,
365372
SpirvInstruction *)>
366373
actOnEachVector,
367374
SourceLocation loc = {}, SourceRange range = {});

0 commit comments

Comments
 (0)