diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 136ea47451fed..99c62808c323d 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -368,20 +368,12 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, "Scalar dot product is only supported on ints and floats."); } // For vectors, validate types and emit the appropriate intrinsic - - // A VectorSplat should have happened - assert(T0->isVectorTy() && T1->isVectorTy() && - "Dot product of vector and scalar is not supported."); + assert(CGM.getContext().hasSameUnqualifiedType(E->getArg(0)->getType(), + E->getArg(1)->getType()) && + "Dot product operands must have the same type."); auto *VecTy0 = E->getArg(0)->getType()->castAs(); - [[maybe_unused]] auto *VecTy1 = - E->getArg(1)->getType()->castAs(); - - assert(VecTy0->getElementType() == VecTy1->getElementType() && - "Dot product of vectors need the same element types."); - - assert(VecTy0->getNumElements() == VecTy1->getNumElements() && - "Dot product requires vectors to be of the same size."); + assert(VecTy0 && "Dot product argument must be a vector."); return Builder.CreateIntrinsic( /*ReturnType=*/T0->getScalarType(), diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 07d03e2c58b9a..fe600386e6fa9 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2015,7 +2015,8 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) { } if (VecTyA && VecTyB) { bool retValue = false; - if (VecTyA->getElementType() != VecTyB->getElementType()) { + if (!S->Context.hasSameUnqualifiedType(VecTyA->getElementType(), + VecTyB->getElementType())) { // Note: type promotion is intended to be handeled via the intrinsics // and not the builtin itself. S->Diag(TheCall->getBeginLoc(), diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl index 7064066780da6..c1fdb0740adc3 100644 --- a/clang/test/CodeGenHLSL/builtins/dot.hlsl +++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl @@ -158,3 +158,11 @@ float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); } // CHECK: %hlsl.dot = fmul reassoc nnan ninf nsz arcp afn double // CHECK: ret double %hlsl.dot double test_dot_double(double p0, double p1) { return dot(p0, p1); } + +// CHECK-LABEL: test_dot_literal +// CHECK: [[X:%.*]] = shufflevector <1 x i32> {{.*}}, <1 x i32> poison, <4 x i32> zeroinitializer +// CHECK-NEXT: %hlsl.dot = call i32 @llvm.[[ICF]].udot.v4i32(<4 x i32> {{.*}}, <4 x i32> [[X]]) +// CHECK-NEXT: [[S1:%.*]] = insertelement <4 x i32> poison, i32 %hlsl.dot, i64 0 +// CHECK-NEXT: [[S2:%.*]] = shufflevector <4 x i32> [[S1]], <4 x i32> poison, <4 x i32> zeroinitializer +// CHECK-NEXT: ret <4 x i32> [[S2]] +uint4 test_dot_literal( uint4 p0) { return dot(p0, 1u.xxxx); }