@@ -1359,6 +1359,15 @@ SpirvInstruction *SpirvEmitter::castToType(SpirvInstruction *value,
1359
1359
QualType fromType, QualType toType,
1360
1360
SourceLocation srcLoc,
1361
1361
SourceRange range) {
1362
+ uint32_t fromSize = 0;
1363
+ uint32_t toSize = 0;
1364
+ assert(isVectorType(fromType, nullptr, &fromSize) ==
1365
+ isVectorType(toType, nullptr, &toSize) &&
1366
+ fromSize == toSize);
1367
+ // Avoid unused variable warning in release builds
1368
+ (void)(fromSize);
1369
+ (void)(toSize);
1370
+
1362
1371
if (isFloatOrVecMatOfFloatType(toType))
1363
1372
return castToFloat(value, fromType, toType, srcLoc, range);
1364
1373
@@ -2929,8 +2938,8 @@ SpirvInstruction *SpirvEmitter::getBaseOfMemberFunction(
2929
2938
SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
2930
2939
const FunctionDecl *callee = getCalleeDefinition(callExpr);
2931
2940
2932
- // Note that we always want the defintion because Stmts/Exprs in the
2933
- // function body references the parameters in the definition.
2941
+ // Note that we always want the definition because Stmts/Exprs in the
2942
+ // function body reference the parameters in the definition.
2934
2943
if (!callee) {
2935
2944
emitError("found undefined function", callExpr->getExprLoc());
2936
2945
return nullptr;
@@ -3031,7 +3040,7 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
3031
3040
const uint32_t argIndex = i + isOperatorOverloading;
3032
3041
3033
3042
// We want the argument variable here so that we can write back to it
3034
- // later. We will do the OpLoad of this argument manually. So ingore
3043
+ // later. We will do the OpLoad of this argument manually. So ignore
3035
3044
// the LValueToRValue implicit cast here.
3036
3045
auto *arg = callExpr->getArg(argIndex)->IgnoreParenLValueCasts();
3037
3046
const auto *param = callee->getParamDecl(i);
@@ -3112,9 +3121,16 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
3112
3121
// has returned.
3113
3122
if (canActAsOutParmVar(param) &&
3114
3123
!paramTypeMatchesArgType(paramType, arg->getType())) {
3115
- if (const auto *refType = paramType->getAs<ReferenceType>())
3116
- rhsVal = castToType(rhsVal, arg->getType(), refType->getPointeeType(),
3117
- arg->getLocStart(), rhsRange);
3124
+ if (const auto *refType = paramType->getAs<ReferenceType>()) {
3125
+ QualType toType = refType->getPointeeType();
3126
+ if (isScalarType(rhsVal->getAstResultType())) {
3127
+ rhsVal =
3128
+ splatScalarToGenerate(toType, rhsVal, SpirvLayoutRule::Void);
3129
+ } else {
3130
+ rhsVal = castToType(rhsVal, rhsVal->getAstResultType(), toType,
3131
+ arg->getLocStart(), rhsRange);
3132
+ }
3133
+ }
3118
3134
}
3119
3135
3120
3136
// Initialize the temporary variables using the contents of the arguments
@@ -3164,9 +3180,18 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
3164
3180
// mismatch, we need to first cast 'value' to the type of 'arg' because
3165
3181
// the AST will not include a cast node.
3166
3182
if (!paramTypeMatchesArgType(paramType, arg->getType())) {
3167
- if (const auto *refType = paramType->getAs<ReferenceType>())
3168
- value = castToType(value, refType->getPointeeType(), arg->getType(),
3169
- arg->getLocStart());
3183
+ if (const auto *refType = paramType->getAs<ReferenceType>()) {
3184
+ QualType elementType;
3185
+ QualType fromType = refType->getPointeeType();
3186
+ if (isVectorType(fromType, &elementType) &&
3187
+ isScalarType(arg->getType())) {
3188
+ value = spvBuilder.createCompositeExtract(
3189
+ elementType, value, {0}, value->getSourceLocation());
3190
+ fromType = elementType;
3191
+ }
3192
+ value =
3193
+ castToType(value, fromType, arg->getType(), arg->getLocStart());
3194
+ }
3170
3195
}
3171
3196
3172
3197
processAssignment(arg, value, false, args[index]);
@@ -14930,7 +14955,7 @@ SpirvEmitter::splatScalarToGenerate(QualType type, SpirvInstruction *scalar,
14930
14955
SourceLocation sourceLocation = scalar->getSourceLocation();
14931
14956
14932
14957
if (isScalarType(type)) {
14933
- // If the type if bool with a non-void layout rule, then it should be
14958
+ // If the type is bool with a non-void layout rule, then it should be
14934
14959
// treated as a uint.
14935
14960
assert(layoutRule == SpirvLayoutRule::Void &&
14936
14961
"If the layout type is not void, then we should cast to an int when "
0 commit comments