@@ -6286,9 +6286,10 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
6286
6286
if (isMxNMatrix(subType)) {
6287
6287
// For matrices, we can only increment/decrement each vector of it.
6288
6288
const auto actOnEachVec = [this, spvOp, one, expr,
6289
- range](uint32_t /*index*/, QualType vecType,
6289
+ range](uint32_t /*index*/, QualType inType,
6290
+ QualType outType,
6290
6291
SpirvInstruction *lhsVec) {
6291
- auto *val = spvBuilder.createBinaryOp(spvOp, vecType , lhsVec, one,
6292
+ auto *val = spvBuilder.createBinaryOp(spvOp, outType , lhsVec, one,
6292
6293
expr->getOperatorLoc(), range);
6293
6294
if (val)
6294
6295
val->setRValue();
@@ -6356,9 +6357,10 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
6356
6357
if (isMxNMatrix(subType)) {
6357
6358
// For matrices, we can only negate each vector of it.
6358
6359
const auto actOnEachVec = [this, spvOp, expr,
6359
- range](uint32_t /*index*/, QualType vecType,
6360
+ range](uint32_t /*index*/, QualType inType,
6361
+ QualType outType,
6360
6362
SpirvInstruction *lhsVec) {
6361
- return spvBuilder.createUnaryOp(spvOp, vecType , lhsVec,
6363
+ return spvBuilder.createUnaryOp(spvOp, outType , lhsVec,
6362
6364
expr->getOperatorLoc(), range);
6363
6365
};
6364
6366
return processEachVectorInMatrix(subExpr, subValue, actOnEachVec,
@@ -7929,27 +7931,39 @@ void SpirvEmitter::assignToMSOutIndices(
7929
7931
7930
7932
SpirvInstruction *SpirvEmitter::processEachVectorInMatrix(
7931
7933
const Expr *matrix, SpirvInstruction *matrixVal,
7932
- llvm::function_ref<SpirvInstruction *(uint32_t, QualType,
7934
+ llvm::function_ref<SpirvInstruction *(uint32_t, QualType, QualType,
7935
+ SpirvInstruction *)>
7936
+ actOnEachVector,
7937
+ SourceLocation loc, SourceRange range) {
7938
+ return processEachVectorInMatrix(matrix, matrix->getType(), matrixVal,
7939
+ actOnEachVector, loc, range);
7940
+ }
7941
+
7942
+ SpirvInstruction *SpirvEmitter::processEachVectorInMatrix(
7943
+ const Expr *matrix, QualType outputType, SpirvInstruction *matrixVal,
7944
+ llvm::function_ref<SpirvInstruction *(uint32_t, QualType, QualType,
7933
7945
SpirvInstruction *)>
7934
7946
actOnEachVector,
7935
7947
SourceLocation loc, SourceRange range) {
7936
7948
const auto matType = matrix->getType();
7937
- assert(isMxNMatrix(matType));
7938
- const QualType vecType = getComponentVectorType(astContext, matType);
7949
+ assert(isMxNMatrix(matType) && isMxNMatrix(outputType));
7950
+ const QualType inVecType = getComponentVectorType(astContext, matType);
7951
+ const QualType outVecType = getComponentVectorType(astContext, outputType);
7939
7952
7940
7953
uint32_t rowCount = 0, colCount = 0;
7941
7954
hlsl::GetHLSLMatRowColCount(matType, rowCount, colCount);
7942
7955
7943
7956
llvm::SmallVector<SpirvInstruction *, 4> vectors;
7944
7957
// Extract each component vector and do operation on it
7945
7958
for (uint32_t i = 0; i < rowCount; ++i) {
7946
- auto *lhsVec = spvBuilder.createCompositeExtract(vecType , matrixVal, {i},
7959
+ auto *lhsVec = spvBuilder.createCompositeExtract(inVecType , matrixVal, {i},
7947
7960
matrix->getLocStart());
7948
- vectors.push_back(actOnEachVector(i, vecType , lhsVec));
7961
+ vectors.push_back(actOnEachVector(i, inVecType, outVecType , lhsVec));
7949
7962
}
7950
7963
7951
7964
// Construct the result matrix
7952
- auto *val = spvBuilder.createCompositeConstruct(matType, vectors, loc, range);
7965
+ auto *val =
7966
+ spvBuilder.createCompositeConstruct(outputType, vectors, loc, range);
7953
7967
if (!val)
7954
7968
return nullptr;
7955
7969
val->setRValue();
@@ -8056,15 +8070,15 @@ SpirvEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
8056
8070
case BO_MulAssign:
8057
8071
case BO_DivAssign:
8058
8072
case BO_RemAssign: {
8059
- const auto actOnEachVec = [this, spvOp, rhsVal, rhs, loc,
8060
- range]( uint32_t index, QualType vecType ,
8061
- SpirvInstruction *lhsVec) {
8073
+ const auto actOnEachVec = [this, spvOp, rhsVal, rhs, loc, range](
8074
+ uint32_t index, QualType inType ,
8075
+ QualType outType, SpirvInstruction *lhsVec) {
8062
8076
// For each vector of lhs, we need to load the corresponding vector of
8063
8077
// rhs and do the operation on them.
8064
- auto *rhsVec = spvBuilder.createCompositeExtract(vecType , rhsVal, {index},
8078
+ auto *rhsVec = spvBuilder.createCompositeExtract(inType , rhsVal, {index},
8065
8079
rhs->getLocStart());
8066
8080
auto *val =
8067
- spvBuilder.createBinaryOp(spvOp, vecType , lhsVec, rhsVec, loc, range);
8081
+ spvBuilder.createBinaryOp(spvOp, outType , lhsVec, rhsVec, loc, range);
8068
8082
if (val)
8069
8083
val->setRValue();
8070
8084
return val;
@@ -9066,6 +9080,15 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
9066
9080
case hlsl::IntrinsicOp::IOP_firstbitlow: {
9067
9081
retVal = processIntrinsicFirstbit(callExpr, GLSLstd450::GLSLstd450FindILsb);
9068
9082
break;
9083
+ }
9084
+ case hlsl::IntrinsicOp::IOP_isnan: {
9085
+ retVal = processIntrinsicUsingSpirvInst(callExpr, spv::Op::OpIsNan,
9086
+ /* doEachVec= */ true);
9087
+ // OpIsNan returns a bool/vec<bool>, so the only valid layout is void. It
9088
+ // will be the responsibility of the store to do an OpSelect and correctly
9089
+ // convert this type to an externally storable type.
9090
+ retVal->setLayoutRule(SpirvLayoutRule::Void);
9091
+ break;
9069
9092
}
9070
9093
INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
9071
9094
INTRINSIC_SPIRV_OP_CASE(ddx_coarse, DPdxCoarse, false);
@@ -9075,7 +9098,6 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
9075
9098
INTRINSIC_SPIRV_OP_CASE(ddy_fine, DPdyFine, false);
9076
9099
INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false);
9077
9100
INTRINSIC_SPIRV_OP_CASE(isinf, IsInf, true);
9078
- INTRINSIC_SPIRV_OP_CASE(isnan, IsNan, true);
9079
9101
INTRINSIC_SPIRV_OP_CASE(fmod, FRem, true);
9080
9102
INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true);
9081
9103
INTRINSIC_SPIRV_OP_CASE(reversebits, BitReverse, false);
@@ -10030,14 +10052,15 @@ SpirvInstruction *SpirvEmitter::processIntrinsicMad(const CallExpr *callExpr) {
10030
10052
if (isMxNMatrix(arg0->getType())) {
10031
10053
const auto actOnEachVec = [this, loc, arg1Instr, arg2Instr, arg1Loc,
10032
10054
arg2Loc,
10033
- range](uint32_t index, QualType vecType,
10055
+ range](uint32_t index, QualType inType,
10056
+ QualType outType,
10034
10057
SpirvInstruction *arg0Row) {
10035
10058
auto *arg1Row = spvBuilder.createCompositeExtract(
10036
- vecType , arg1Instr, {index}, arg1Loc, range);
10059
+ inType , arg1Instr, {index}, arg1Loc, range);
10037
10060
auto *arg2Row = spvBuilder.createCompositeExtract(
10038
- vecType , arg2Instr, {index}, arg2Loc, range);
10061
+ inType , arg2Instr, {index}, arg2Loc, range);
10039
10062
auto *fma = spvBuilder.createGLSLExtInst(
10040
- vecType , GLSLstd450Fma, {arg0Row, arg1Row, arg2Row}, loc, range);
10063
+ outType , GLSLstd450Fma, {arg0Row, arg1Row, arg2Row}, loc, range);
10041
10064
spvBuilder.decorateNoContraction(fma, loc);
10042
10065
return fma;
10043
10066
};
@@ -10257,13 +10280,14 @@ SpirvEmitter::processIntrinsicLdexp(const CallExpr *callExpr) {
10257
10280
uint32_t rowCount = 0, colCount = 0;
10258
10281
if (isMxNMatrix(paramType, nullptr, &rowCount, &colCount)) {
10259
10282
const auto actOnEachVec = [this, loc, expInstr, arg1Loc,
10260
- range](uint32_t index, QualType vecType,
10283
+ range](uint32_t index, QualType inType,
10284
+ QualType outType,
10261
10285
SpirvInstruction *xRowInstr) {
10262
10286
auto *expRowInstr = spvBuilder.createCompositeExtract(
10263
- vecType , expInstr, {index}, arg1Loc, range);
10287
+ inType , expInstr, {index}, arg1Loc, range);
10264
10288
auto *twoExp = spvBuilder.createGLSLExtInst(
10265
- vecType , GLSLstd450::GLSLstd450Exp2, {expRowInstr}, loc, range);
10266
- return spvBuilder.createBinaryOp(spv::Op::OpFMul, vecType , xRowInstr,
10289
+ outType , GLSLstd450::GLSLstd450Exp2, {expRowInstr}, loc, range);
10290
+ return spvBuilder.createBinaryOp(spv::Op::OpFMul, outType , xRowInstr,
10267
10291
twoExp, loc, range);
10268
10292
};
10269
10293
return processEachVectorInMatrix(x, xInstr, actOnEachVec, loc, range);
@@ -10427,15 +10451,15 @@ SpirvEmitter::processIntrinsicClamp(const CallExpr *callExpr) {
10427
10451
// the operation on each vector of the matrix.
10428
10452
if (isMxNMatrix(argX->getType())) {
10429
10453
const auto actOnEachVec = [this, loc, range, glslOpcode, argMinInstr,
10430
- argMaxInstr, argMinLoc,
10431
- argMaxLoc]( uint32_t index, QualType vecType ,
10432
- SpirvInstruction *curRow) {
10454
+ argMaxInstr, argMinLoc, argMaxLoc](
10455
+ uint32_t index, QualType inType ,
10456
+ QualType outType, SpirvInstruction *curRow) {
10433
10457
auto *minRowInstr = spvBuilder.createCompositeExtract(
10434
- vecType , argMinInstr, {index}, argMinLoc, range);
10458
+ inType , argMinInstr, {index}, argMinLoc, range);
10435
10459
auto *maxRowInstr = spvBuilder.createCompositeExtract(
10436
- vecType , argMaxInstr, {index}, argMaxLoc, range);
10460
+ inType , argMaxInstr, {index}, argMaxLoc, range);
10437
10461
return spvBuilder.createGLSLExtInst(
10438
- vecType , glslOpcode, {curRow, minRowInstr, maxRowInstr}, loc, range);
10462
+ outType , glslOpcode, {curRow, minRowInstr, maxRowInstr}, loc, range);
10439
10463
};
10440
10464
return processEachVectorInMatrix(argX, argXInstr, actOnEachVec, loc, range);
10441
10465
}
@@ -11013,12 +11037,12 @@ SpirvInstruction *SpirvEmitter::processIntrinsicRcp(const CallExpr *callExpr) {
11013
11037
uint32_t numRows = 0, numCols = 0;
11014
11038
if (isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
11015
11039
auto *vecOne = getVecValueOne(elemType, numCols);
11016
- const auto actOnEachVec = [this, vecOne, loc,
11017
- range](uint32_t /*index*/, QualType vecType ,
11018
- SpirvInstruction *curRow) {
11019
- return spvBuilder.createBinaryOp(spv::Op::OpFDiv, vecType , vecOne, curRow ,
11020
- loc, range);
11021
- };
11040
+ const auto actOnEachVec =
11041
+ [this, vecOne, loc, range](uint32_t /*index*/, QualType inType ,
11042
+ QualType outType, SpirvInstruction *curRow) {
11043
+ return spvBuilder.createBinaryOp(spv::Op::OpFDiv, outType , vecOne,
11044
+ curRow, loc, range);
11045
+ };
11022
11046
return processEachVectorInMatrix(arg, argId, actOnEachVec, loc, range);
11023
11047
}
11024
11048
@@ -11335,10 +11359,10 @@ SpirvEmitter::processIntrinsicSaturate(const CallExpr *callExpr) {
11335
11359
if (isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
11336
11360
auto *vecZero = getVecValueZero(elemType, numCols);
11337
11361
auto *vecOne = getVecValueOne(elemType, numCols);
11338
- const auto actOnEachVec = [this, loc, vecZero, vecOne,
11339
- range]( uint32_t /*index*/, QualType vecType ,
11340
- SpirvInstruction *curRow) {
11341
- return spvBuilder.createGLSLExtInst(vecType , GLSLstd450::GLSLstd450FClamp,
11362
+ const auto actOnEachVec = [this, loc, vecZero, vecOne, range](
11363
+ uint32_t /*index*/, QualType inType ,
11364
+ QualType outType, SpirvInstruction *curRow) {
11365
+ return spvBuilder.createGLSLExtInst(outType , GLSLstd450::GLSLstd450FClamp,
11342
11366
{curRow, vecZero, vecOne}, loc,
11343
11367
range);
11344
11368
};
@@ -11364,10 +11388,10 @@ SpirvEmitter::processIntrinsicFloatSign(const CallExpr *callExpr) {
11364
11388
11365
11389
// For matrices, we can perform the instruction on each vector of the matrix.
11366
11390
if (isMxNMatrix(argType)) {
11367
- const auto actOnEachVec = [this, loc, range](uint32_t /*index*/,
11368
- QualType vecType ,
11369
- SpirvInstruction *curRow) {
11370
- return spvBuilder.createGLSLExtInst(vecType , GLSLstd450::GLSLstd450FSign,
11391
+ const auto actOnEachVec = [this, loc, range](
11392
+ uint32_t /*index*/, QualType inType ,
11393
+ QualType outType, SpirvInstruction *curRow) {
11394
+ return spvBuilder.createGLSLExtInst(outType , GLSLstd450::GLSLstd450FSign,
11371
11395
{curRow}, loc, range);
11372
11396
};
11373
11397
floatSign = processEachVectorInMatrix(arg, argId, actOnEachVec, loc, range);
@@ -11496,12 +11520,15 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingSpirvInst(
11496
11520
// If the instruction does not operate on matrices, we can perform the
11497
11521
// instruction on each vector of the matrix.
11498
11522
if (actPerRowForMatrices && isMxNMatrix(arg->getType())) {
11523
+ assert(isMxNMatrix(returnType));
11499
11524
const auto actOnEachVec = [this, opcode, loc,
11500
- range](uint32_t /*index*/, QualType vecType,
11525
+ range](uint32_t /*index*/, QualType inType,
11526
+ QualType outType,
11501
11527
SpirvInstruction *curRow) {
11502
- return spvBuilder.createUnaryOp(opcode, vecType , curRow, loc, range);
11528
+ return spvBuilder.createUnaryOp(opcode, outType , curRow, loc, range);
11503
11529
};
11504
- return processEachVectorInMatrix(arg, argId, actOnEachVec, loc, range);
11530
+ return processEachVectorInMatrix(arg, returnType, argId, actOnEachVec,
11531
+ loc, range);
11505
11532
}
11506
11533
return spvBuilder.createUnaryOp(opcode, returnType, argId, loc, range);
11507
11534
} else if (callExpr->getNumArgs() == 2u) {
@@ -11514,11 +11541,12 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingSpirvInst(
11514
11541
// instruction on each vector of the matrix.
11515
11542
if (actPerRowForMatrices && isMxNMatrix(arg0->getType())) {
11516
11543
const auto actOnEachVec = [this, opcode, arg1Id, loc, range, arg1Loc,
11517
- arg1Range](uint32_t index, QualType vecType,
11544
+ arg1Range](uint32_t index, QualType inType,
11545
+ QualType outType,
11518
11546
SpirvInstruction *arg0Row) {
11519
11547
auto *arg1Row = spvBuilder.createCompositeExtract(
11520
- vecType , arg1Id, {index}, arg1Loc, arg1Range);
11521
- return spvBuilder.createBinaryOp(opcode, vecType , arg0Row, arg1Row, loc,
11548
+ inType , arg1Id, {index}, arg1Loc, arg1Range);
11549
+ return spvBuilder.createBinaryOp(opcode, outType , arg0Row, arg1Row, loc,
11522
11550
range);
11523
11551
};
11524
11552
return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec, loc, range);
@@ -11546,9 +11574,10 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingGLSLInst(
11546
11574
// instruction on each vector of the matrix.
11547
11575
if (actPerRowForMatrices && isMxNMatrix(arg->getType())) {
11548
11576
const auto actOnEachVec = [this, loc, range,
11549
- opcode](uint32_t /*index*/, QualType vecType,
11577
+ opcode](uint32_t /*index*/, QualType inType,
11578
+ QualType outType,
11550
11579
SpirvInstruction *curRowInstr) {
11551
- return spvBuilder.createGLSLExtInst(vecType , opcode, {curRowInstr}, loc,
11580
+ return spvBuilder.createGLSLExtInst(outType , opcode, {curRowInstr}, loc,
11552
11581
range);
11553
11582
};
11554
11583
return processEachVectorInMatrix(arg, argInstr, actOnEachVec, loc, range);
@@ -11565,12 +11594,13 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingGLSLInst(
11565
11594
// instruction on each vector of the matrix.
11566
11595
if (actPerRowForMatrices && isMxNMatrix(arg0->getType())) {
11567
11596
const auto actOnEachVec = [this, loc, range, opcode, arg1Instr, arg1Range,
11568
- arg1Loc](uint32_t index, QualType vecType,
11597
+ arg1Loc](uint32_t index, QualType inType,
11598
+ QualType outType,
11569
11599
SpirvInstruction *arg0RowInstr) {
11570
11600
auto *arg1RowInstr = spvBuilder.createCompositeExtract(
11571
- vecType , arg1Instr, {index}, arg1Loc, arg1Range);
11601
+ inType , arg1Instr, {index}, arg1Loc, arg1Range);
11572
11602
return spvBuilder.createGLSLExtInst(
11573
- vecType , opcode, {arg0RowInstr, arg1RowInstr}, loc, range);
11603
+ outType , opcode, {arg0RowInstr, arg1RowInstr}, loc, range);
11574
11604
};
11575
11605
return processEachVectorInMatrix(arg0, arg0Instr, actOnEachVec, loc,
11576
11606
range);
@@ -11591,14 +11621,15 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingGLSLInst(
11591
11621
if (actPerRowForMatrices && isMxNMatrix(arg0->getType())) {
11592
11622
const auto actOnEachVec = [this, loc, range, opcode, arg1Instr, arg2Instr,
11593
11623
arg1Loc, arg2Loc, arg1Range,
11594
- arg2Range](uint32_t index, QualType vecType,
11624
+ arg2Range](uint32_t index, QualType inType,
11625
+ QualType outType,
11595
11626
SpirvInstruction *arg0RowInstr) {
11596
11627
auto *arg1RowInstr = spvBuilder.createCompositeExtract(
11597
- vecType , arg1Instr, {index}, arg1Loc, arg1Range);
11628
+ inType , arg1Instr, {index}, arg1Loc, arg1Range);
11598
11629
auto *arg2RowInstr = spvBuilder.createCompositeExtract(
11599
- vecType , arg2Instr, {index}, arg2Loc, arg2Range);
11630
+ inType , arg2Instr, {index}, arg2Loc, arg2Range);
11600
11631
return spvBuilder.createGLSLExtInst(
11601
- vecType , opcode, {arg0RowInstr, arg1RowInstr, arg2RowInstr}, loc,
11632
+ outType , opcode, {arg0RowInstr, arg1RowInstr, arg2RowInstr}, loc,
11602
11633
range);
11603
11634
};
11604
11635
return processEachVectorInMatrix(arg0, arg0Instr, actOnEachVec, loc,
0 commit comments