Skip to content

Commit 9b04d69

Browse files
authored
[SPIRV] Cast derivative opts to 32-bits. (#7445)
The SPIR-V derivative operations require 32-bit floats. Smaller float types can be cast to 32 bits to perform the operation. The front end (FE) already emits a warning for 64-bit operands. Fixes #7431
1 parent 8b406b5 commit 9b04d69

File tree

4 files changed

+137
-6
lines changed

4 files changed

+137
-6
lines changed

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9484,12 +9484,17 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
94849484
retVal = processIntrinsicPointerCast(callExpr, true);
94859485
break;
94869486
}
9487-
INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
9488-
INTRINSIC_SPIRV_OP_CASE(ddx_coarse, DPdxCoarse, false);
9489-
INTRINSIC_SPIRV_OP_CASE(ddx_fine, DPdxFine, false);
9490-
INTRINSIC_SPIRV_OP_CASE(ddy, DPdy, true);
9491-
INTRINSIC_SPIRV_OP_CASE(ddy_coarse, DPdyCoarse, false);
9492-
INTRINSIC_SPIRV_OP_CASE(ddy_fine, DPdyFine, false);
9487+
case hlsl::IntrinsicOp::IOP_ddx:
9488+
case hlsl::IntrinsicOp::IOP_ddx_coarse:
9489+
case hlsl::IntrinsicOp::IOP_ddx_fine:
9490+
case hlsl::IntrinsicOp::IOP_ddy:
9491+
case hlsl::IntrinsicOp::IOP_ddy_coarse:
9492+
case hlsl::IntrinsicOp::IOP_ddy_fine: {
9493+
retVal = processDerivativeIntrinsic(hlslOpcode, callExpr->getArg(0),
9494+
callExpr->getExprLoc(),
9495+
callExpr->getSourceRange());
9496+
break;
9497+
}
94939498
INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false);
94949499
INTRINSIC_SPIRV_OP_CASE(fmod, FRem, true);
94959500
INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true);
@@ -9572,6 +9577,77 @@ SpirvEmitter::processIntrinsicFirstbit(const CallExpr *callExpr,
95729577
srcRange);
95739578
}
95749579

9580+
// Handles a derivative intrinsic (hlslOpcode) whose argument is a matrix.
// The argument is evaluated with doExpr and passed to
// processEachVectorInMatrix together with a callback that applies
// processDerivativeIntrinsic to every vector it is handed.
// Returns whatever processEachVectorInMatrix returns.
SpirvInstruction *SpirvEmitter::processMatrixDerivativeIntrinsic(
9581+
hlsl::IntrinsicOp hlslOpcode, const Expr *arg, SourceLocation loc,
9582+
SourceRange range) {
9583+
// Per-vector callback. Only the instruction argument (curRow) is used; the
// index, inType and outType parameters are ignored here.
const auto actOnEachVec = [this, hlslOpcode, loc, range](
9584+
uint32_t /*index*/, QualType inType,
9585+
QualType outType, SpirvInstruction *curRow) {
9586+
return processDerivativeIntrinsic(hlslOpcode, curRow, loc, range);
9587+
};
9588+
9589+
return processEachVectorInMatrix(arg, arg->getType(), doExpr(arg),
9590+
actOnEachVec, loc, range);
9591+
}
9592+
9593+
// AST-level entry point for the derivative intrinsics. If the argument type
// is an MxN matrix (isMxNMatrix), the work is delegated to
// processMatrixDerivativeIntrinsic; otherwise the argument is evaluated with
// doExpr and forwarded to the SpirvInstruction* overload.
SpirvInstruction *
9594+
SpirvEmitter::processDerivativeIntrinsic(hlsl::IntrinsicOp hlslOpcode,
9595+
const Expr *arg, SourceLocation loc,
9596+
SourceRange range) {
9597+
if (isMxNMatrix(arg->getType())) {
9598+
return processMatrixDerivativeIntrinsic(hlslOpcode, arg, loc, range);
9599+
}
9600+
return processDerivativeIntrinsic(hlslOpcode, doExpr(arg), loc, range);
9601+
}
9602+
9603+
// Emits the SPIR-V derivative instruction for an already-evaluated operand.
// The operand's AST result type must be a float or a vector of floats
// (asserted). The operand is cast to a 32-bit float type of the same shape,
// the derivative op is created on that type, and the result is cast back to
// the operand's original type, which is also the type of the returned value.
SpirvInstruction *SpirvEmitter::processDerivativeIntrinsic(
9604+
hlsl::IntrinsicOp hlslOpcode, SpirvInstruction *arg, SourceLocation loc,
9605+
SourceRange range) {
9606+
QualType returnType = arg->getAstResultType();
9607+
assert(isFloatOrVecOfFloatType(returnType));
9608+
9609+
// When spvContext.isPS() is false, the derivative-group execution mode is
// added before the instruction is emitted.
if (!spvContext.isPS())
9610+
addDerivativeGroupExecutionMode();
9611+
needsLegalization = true;
9612+
9613+
// B32Type is `float` for a scalar operand, or an ext-vector of `float` with
// the operand's element count when the operand is a vector. elementType is
// filled in by isVectorType but not read again in this function.
QualType B32Type = astContext.FloatTy;
9614+
uint32_t vectorSize = 0;
9615+
QualType elementType = returnType;
9616+
if (isVectorType(returnType, &elementType, &vectorSize)) {
9617+
B32Type = astContext.getExtVectorType(B32Type, vectorSize);
9618+
}
9619+
9620+
// Derivative operations work on 32-bit floats only. Cast to 32-bit if needed.
SpirvInstruction *operand = castToType(arg, returnType, B32Type, loc, range);
9622+
9623+
// Map the HLSL intrinsic to its SPIR-V derivative opcode.
// NOTE(review): there is no default case, so any other hlslOpcode would leave
// `opcode` as OpNop. The dispatch visible in processIntrinsicCallExpr only
// routes the six opcodes below here; confirm no other caller exists.
spv::Op opcode = spv::Op::OpNop;
9624+
switch (hlslOpcode) {
9625+
case hlsl::IntrinsicOp::IOP_ddx:
9626+
opcode = spv::Op::OpDPdx;
9627+
break;
9628+
case hlsl::IntrinsicOp::IOP_ddx_coarse:
9629+
opcode = spv::Op::OpDPdxCoarse;
9630+
break;
9631+
case hlsl::IntrinsicOp::IOP_ddx_fine:
9632+
opcode = spv::Op::OpDPdxFine;
9633+
break;
9634+
case hlsl::IntrinsicOp::IOP_ddy:
9635+
opcode = spv::Op::OpDPdy;
9636+
break;
9637+
case hlsl::IntrinsicOp::IOP_ddy_coarse:
9638+
opcode = spv::Op::OpDPdyCoarse;
9639+
break;
9640+
case hlsl::IntrinsicOp::IOP_ddy_fine:
9641+
opcode = spv::Op::OpDPdyFine;
9642+
break;
9643+
};
9644+
9645+
// Perform the derivative on the 32-bit type, then convert the result back to
// the operand's original type.
SpirvInstruction *result =
9646+
spvBuilder.createUnaryOp(opcode, B32Type, operand, loc, range);
9647+
result = castToType(result, B32Type, returnType, loc, range);
9648+
return result;
9649+
}
9650+
95759651
// Returns true if the given expression can be used as an output parameter.
95769652
//
95779653
// Warning: this function could return false negatives.

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,21 @@ class SpirvEmitter : public ASTConsumer {
789789
SpirvInstruction *processIntrinsicFirstbit(const CallExpr *,
790790
GLSLstd450 glslOpcode);
791791

792+
/// Processes a derivative intrinsic (ddx/ddy and their coarse/fine variants)
/// whose argument is a matrix by applying the derivative to each vector of
/// the matrix.
SpirvInstruction *
793+
processMatrixDerivativeIntrinsic(hlsl::IntrinsicOp hlslOpcode,
794+
const Expr *arg, SourceLocation loc,
795+
SourceRange range);
796+
797+
/// Processes a derivative intrinsic (ddx/ddy and their coarse/fine variants)
/// applied to the AST expression arg. Matrix arguments are handled per
/// vector; other arguments are evaluated and passed to the SpirvInstruction
/// overload.
SpirvInstruction *processDerivativeIntrinsic(hlsl::IntrinsicOp hlslOpcode,
798+
const Expr *arg,
799+
SourceLocation loc,
800+
SourceRange range);
801+
802+
/// Processes a derivative intrinsic (ddx/ddy and their coarse/fine variants)
/// applied to an already-evaluated float or vector-of-float operand. The
/// operand is cast to a 32-bit float type for the operation and the result
/// is cast back to the operand's original type.
SpirvInstruction *processDerivativeIntrinsic(hlsl::IntrinsicOp hlslOpcode,
803+
SpirvInstruction *arg,
804+
SourceLocation loc,
805+
SourceRange range);
806+
792807
private:
793808
/// Returns the <result-id> for constant value 0 of the given type.
794809
SpirvConstant *getValueZero(QualType type);
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: %dxc -T ps_6_2 -E main -fcgl %s -spirv 2>&1 | FileCheck %s
2+
3+
// CHECK: :14:22: warning: conversion from larger type 'double' to smaller type 'float', possible loss of data [-Wconversion]
4+
// CHECK: :20:22: warning: conversion from larger type 'double2' to smaller type 'vector<float, 2>', possible loss of data [-Wconversion]
5+
6+
void main() {
7+
double a;
8+
double2 b;
9+
10+
// CHECK: [[a:%[0-9]+]] = OpLoad %double %a
11+
// CHECK-NEXT: [[c:%[0-9]+]] = OpFConvert %float [[a]]
12+
// CHECK-NEXT: [[r:%[0-9]+]] = OpDPdx %float [[c]]
13+
// CHECK-NEXT: OpFConvert %double [[r]]
14+
double da = ddx(a);
15+
16+
// CHECK: [[b:%[0-9]+]] = OpLoad %v2double %b
17+
// CHECK-NEXT: [[c:%[0-9]+]] = OpFConvert %v2float [[b]]
18+
// CHECK-NEXT: [[r:%[0-9]+]] = OpDPdx %v2float [[c]]
19+
// CHECK-NEXT: OpFConvert %v2double [[r]]
20+
double2 db = ddx(b);
21+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: %dxc -T ps_6_2 -E main -enable-16bit-types -fcgl %s -spirv | FileCheck %s
2+
3+
void main() {
4+
5+
half a;
6+
half2 b;
7+
8+
// CHECK: [[a:%[0-9]+]] = OpLoad %half %a
9+
// CHECK-NEXT: [[c:%[0-9]+]] = OpFConvert %float [[a]]
10+
// CHECK-NEXT: [[r:%[0-9]+]] = OpDPdx %float [[c]]
11+
// CHECK-NEXT: OpFConvert %half [[r]]
12+
half da = ddx(a);
13+
14+
// CHECK: [[b:%[0-9]+]] = OpLoad %v2half %b
15+
// CHECK-NEXT: [[c:%[0-9]+]] = OpFConvert %v2float [[b]]
16+
// CHECK-NEXT: [[r:%[0-9]+]] = OpDPdx %v2float [[c]]
17+
// CHECK-NEXT: OpFConvert %v2half [[r]]
18+
half2 db = ddx(b);
19+
}

0 commit comments

Comments
 (0)