Skip to content

Commit 74ba845

Browse files
authored
Splat function argument (microsoft#6747)
When a scalar variable is passed as the argument to an inout vector parameter, the scalar is supposed to be splatted. After returning from the function, we need to extract the first element from the parameter to store back into the scalar. Fixes microsoft#6568
1 parent 9c154fb commit 74ba845

File tree

2 files changed

+52
-11
lines changed

2 files changed

+52
-11
lines changed

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,6 +1359,15 @@ SpirvInstruction *SpirvEmitter::castToType(SpirvInstruction *value,
13591359
QualType fromType, QualType toType,
13601360
SourceLocation srcLoc,
13611361
SourceRange range) {
1362+
uint32_t fromSize = 0;
1363+
uint32_t toSize = 0;
1364+
assert(isVectorType(fromType, nullptr, &fromSize) ==
1365+
isVectorType(toType, nullptr, &toSize) &&
1366+
fromSize == toSize);
1367+
// Avoid unused variable warning in release builds
1368+
(void)(fromSize);
1369+
(void)(toSize);
1370+
13621371
if (isFloatOrVecMatOfFloatType(toType))
13631372
return castToFloat(value, fromType, toType, srcLoc, range);
13641373

@@ -2929,8 +2938,8 @@ SpirvInstruction *SpirvEmitter::getBaseOfMemberFunction(
29292938
SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
29302939
const FunctionDecl *callee = getCalleeDefinition(callExpr);
29312940

2932-
// Note that we always want the defintion because Stmts/Exprs in the
2933-
// function body references the parameters in the definition.
2941+
// Note that we always want the definition because Stmts/Exprs in the
2942+
// function body reference the parameters in the definition.
29342943
if (!callee) {
29352944
emitError("found undefined function", callExpr->getExprLoc());
29362945
return nullptr;
@@ -3031,7 +3040,7 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
30313040
const uint32_t argIndex = i + isOperatorOverloading;
30323041

30333042
// We want the argument variable here so that we can write back to it
3034-
// later. We will do the OpLoad of this argument manually. So ingore
3043+
// later. We will do the OpLoad of this argument manually. So ignore
30353044
// the LValueToRValue implicit cast here.
30363045
auto *arg = callExpr->getArg(argIndex)->IgnoreParenLValueCasts();
30373046
const auto *param = callee->getParamDecl(i);
@@ -3112,9 +3121,16 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
31123121
// has returned.
31133122
if (canActAsOutParmVar(param) &&
31143123
!paramTypeMatchesArgType(paramType, arg->getType())) {
3115-
if (const auto *refType = paramType->getAs<ReferenceType>())
3116-
rhsVal = castToType(rhsVal, arg->getType(), refType->getPointeeType(),
3117-
arg->getLocStart(), rhsRange);
3124+
if (const auto *refType = paramType->getAs<ReferenceType>()) {
3125+
QualType toType = refType->getPointeeType();
3126+
if (isScalarType(rhsVal->getAstResultType())) {
3127+
rhsVal =
3128+
splatScalarToGenerate(toType, rhsVal, SpirvLayoutRule::Void);
3129+
} else {
3130+
rhsVal = castToType(rhsVal, rhsVal->getAstResultType(), toType,
3131+
arg->getLocStart(), rhsRange);
3132+
}
3133+
}
31183134
}
31193135

31203136
// Initialize the temporary variables using the contents of the arguments
@@ -3164,9 +3180,18 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
31643180
// mismatch, we need to first cast 'value' to the type of 'arg' because
31653181
// the AST will not include a cast node.
31663182
if (!paramTypeMatchesArgType(paramType, arg->getType())) {
3167-
if (const auto *refType = paramType->getAs<ReferenceType>())
3168-
value = castToType(value, refType->getPointeeType(), arg->getType(),
3169-
arg->getLocStart());
3183+
if (const auto *refType = paramType->getAs<ReferenceType>()) {
3184+
QualType elementType;
3185+
QualType fromType = refType->getPointeeType();
3186+
if (isVectorType(fromType, &elementType) &&
3187+
isScalarType(arg->getType())) {
3188+
value = spvBuilder.createCompositeExtract(
3189+
elementType, value, {0}, value->getSourceLocation());
3190+
fromType = elementType;
3191+
}
3192+
value =
3193+
castToType(value, fromType, arg->getType(), arg->getLocStart());
3194+
}
31703195
}
31713196

31723197
processAssignment(arg, value, false, args[index]);
@@ -14930,7 +14955,7 @@ SpirvEmitter::splatScalarToGenerate(QualType type, SpirvInstruction *scalar,
1493014955
SourceLocation sourceLocation = scalar->getSourceLocation();
1493114956

1493214957
if (isScalarType(type)) {
14933-
// If the type if bool with a non-void layout rule, then it should be
14958+
// If the type is bool with a non-void layout rule, then it should be
1493414959
// treated as a uint.
1493514960
assert(layoutRule == SpirvLayoutRule::Void &&
1493614961
"If the layout type is not void, then we should cast to an int when "

tools/clang/test/CodeGenSPIRV/fn.param.inout.type-mismatch.hlsl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@ void foo(const half3 input, out half3 output) {
33
output = input;
44
}
55

6+
void bar( inout float3 p)
7+
{
8+
p += float3(1,1,1);
9+
}
10+
11+
612
float4 main() : SV_Target0 {
713
float3 output;
814
// CHECK: %param_var_input = OpVariable %_ptr_Function_v3half Function
@@ -17,7 +23,17 @@ float4 main() : SV_Target0 {
1723
// CHECK-NEXT: [[outputFloat3_0:%[0-9]+]] = OpFConvert %v3float [[outputHalf3_0]]
1824
// CHECK-NEXT: OpStore %output [[outputFloat3_0]]
1925

20-
// CHECK-NEXT: [[outputFloat3_1:%[0-9]+]] = OpLoad %v3float %output
26+
// CHECK: [[f:%[0-9]+]] = OpLoad %float %f
27+
// CHECK-NEXT: [[splat:%[0-9]+]] = OpCompositeConstruct %v3float [[f]] [[f]] [[f]]
28+
// CHECK-NEXT: OpStore %param_var_p [[splat]]
29+
// CHECK-NEXT: OpFunctionCall %void %bar %param_var_p
30+
// CHECK-NEXT: [[ret:%[0-9]+]] = OpLoad %v3float %param_var_p
31+
// CHECK-NEXT: [[ext:%[0-9]+]] = OpCompositeExtract %float [[ret]] 0
32+
// CHECK-NEXT: OpStore %f [[ext]]
33+
float f = 0;
34+
bar(f);
35+
36+
// CHECK: [[outputFloat3_1:%[0-9]+]] = OpLoad %v3float %output
2137
// CHECK-NEXT: OpCompositeExtract %float [[outputFloat3_2:%[0-9]+]] 0
2238
// CHECK-NEXT: OpCompositeExtract %float [[outputFloat3_3:%[0-9]+]] 1
2339
// CHECK-NEXT: OpCompositeExtract %float [[outputFloat3_4:%[0-9]+]] 2

0 commit comments

Comments
 (0)