Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions clang/include/clang/AST/OperationKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,9 @@ CAST_OPERATION(IntToOCLSampler)
// Truncate a vector type by dropping elements from the end (HLSL only).
CAST_OPERATION(HLSLVectorTruncation)

// Truncate a matrix type by dropping elements from the end (HLSL only).
CAST_OPERATION(HLSLMatrixTruncation)

// Non-decaying array RValue cast (HLSL only).
CAST_OPERATION(HLSLArrayRValue)

Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/DiagnosticGroups.td
Original file line number Diff line number Diff line change
Expand Up @@ -1061,6 +1061,7 @@ def SuperSubClassMismatch : DiagGroup<"super-class-method-mismatch">;
def OverridingMethodMismatch : DiagGroup<"overriding-method-mismatch">;
def VariadicMacros : DiagGroup<"variadic-macros">;
def VectorConversion : DiagGroup<"vector-conversion">; // clang specific
def MatrixConversion : DiagGroup<"matrix-conversion">; // clang specific
def VexingParse : DiagGroup<"vexing-parse">;
def VLAUseStaticAssert : DiagGroup<"vla-extension-static-assert">;
def VLACxxExtension : DiagGroup<"vla-cxx-extension", [VLAUseStaticAssert]>;
Expand Down Expand Up @@ -1335,6 +1336,8 @@ def : DiagGroup<"int-conversions",
[IntConversion]>; // -Wint-conversions = -Wint-conversion
def : DiagGroup<"vector-conversions",
[VectorConversion]>; // -Wvector-conversions = -Wvector-conversion
def : DiagGroup<"matrix-conversions",
[MatrixConversion]>; // -Wmatrix-conversions = -Wmatrix-conversion
def : DiagGroup<"unused-local-typedefs", [UnusedLocalTypedef]>;
// -Wunused-local-typedefs = -Wunused-local-typedef

Expand Down
8 changes: 7 additions & 1 deletion clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -4356,6 +4356,9 @@ def warn_param_typestate_mismatch : Warning<
def warn_unknown_sanitizer_ignored : Warning<
"unknown sanitizer '%0' ignored">, InGroup<UnknownSanitizers>;

def warn_impcast_matrix_scalar : Warning<
"implicit conversion turns matrix to scalar: %0 to %1">,
InGroup<MatrixConversion>;
def warn_impcast_vector_scalar : Warning<
"implicit conversion turns vector to scalar: %0 to %1">,
InGroup<Conversion>, DefaultIgnore;
Expand Down Expand Up @@ -13272,7 +13275,10 @@ def err_hlsl_builtin_scalar_vector_mismatch
"vector type with matching scalar element type%diff{: $ vs $|}2,3">;

def warn_hlsl_impcast_vector_truncation : Warning<
"implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>;
"implicit conversion truncates vector: %0 to %1">, InGroup<VectorConversion>;

def warn_hlsl_impcast_matrix_truncation : Warning<
"implicit conversion truncates matrix: %0 to %1">, InGroup<MatrixConversion>;

def warn_hlsl_availability : Warning<
"%0 is only available %select{|in %4 environment }3on %1 %2 or newer">,
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Sema/Overload.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ class Sema;
/// HLSL vector truncation.
ICK_HLSL_Vector_Truncation,

/// HLSL Matrix truncation.
ICK_HLSL_Matrix_Truncation,

/// HLSL non-decaying array rvalue cast.
ICK_HLSL_Array_RValue,

Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1934,6 +1934,7 @@ bool CastExpr::CastConsistency() const {
case CK_FixedPointToBoolean:
case CK_HLSLArrayRValue:
case CK_HLSLVectorTruncation:
case CK_HLSLMatrixTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
CheckNoBasePath:
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11773,6 +11773,10 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) {
Elements.push_back(Val.getVectorElt(I));
return Success(Elements, E);
}
case CK_HLSLMatrixTruncation: {
// TODO: See #168935. Add matrix truncation support to expr constant.
return Error(E);
}
case CK_HLSLAggregateSplatCast: {
APValue Val;
QualType ValTy;
Expand Down Expand Up @@ -18163,6 +18167,10 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Error(E);
return Success(Val.getVectorElt(0), E);
}
case CK_HLSLMatrixTruncation: {
// TODO: See #168935. Add matrix truncation support to expr constant.
return Error(E);
}
case CK_HLSLElementwiseCast: {
SmallVector<APValue> SrcVals;
SmallVector<QualType> SrcTypes;
Expand Down Expand Up @@ -18756,6 +18764,10 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Error(E);
return Success(Val.getVectorElt(0), E);
}
case CK_HLSLMatrixTruncation: {
// TODO: See #168935. Add matrix truncation support to expr constant.
return Error(E);
}
case CK_HLSLElementwiseCast: {
SmallVector<APValue> SrcVals;
SmallVector<QualType> SrcTypes;
Expand Down Expand Up @@ -18913,6 +18925,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLMatrixTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
llvm_unreachable("invalid cast kind for complex value");
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ Address CIRGenFunction::emitPointerWithAlignment(const Expr *expr,
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLVectorTruncation:
case CK_HLSLMatrixTruncation:
case CK_IntToOCLSampler:
case CK_IntegralCast:
case CK_IntegralComplexCast:
Expand Down Expand Up @@ -1290,6 +1291,7 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) {
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLMatrixTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ mlir::Value ComplexExprEmitter::emitCast(CastKind ck, Expr *op,
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLMatrixTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,7 @@ class ConstExprEmitter
case CK_MatrixCast:
case CK_HLSLArrayRValue:
case CK_HLSLVectorTruncation:
case CK_HLSLMatrixTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
return {};
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5734,6 +5734,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLMatrixTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/CodeGen/CGExprAgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1036,7 +1036,7 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
case CK_ZeroToOCLOpaqueType:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:

case CK_HLSLMatrixTruncation:
case CK_IntToOCLSampler:
case CK_FloatingToFixedPoint:
case CK_FixedPointToFloating:
Expand Down Expand Up @@ -1550,6 +1550,7 @@ static bool castPreservesZero(const CastExpr *CE) {
case CK_NonAtomicToAtomic:
case CK_AtomicToNonAtomic:
case CK_HLSLVectorTruncation:
case CK_HLSLMatrixTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
return true;
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExprComplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLMatrixTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1333,6 +1333,7 @@ class ConstExprEmitter
case CK_ZeroToOCLOpaqueType:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLMatrixTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
Expand Down
62 changes: 58 additions & 4 deletions clang/lib/CodeGen/CGExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2422,9 +2422,31 @@ static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue SrcVal,
}
return V;
}
if (auto *MatTy = DestTy->getAs<ConstantMatrixType>()) {
assert(LoadList.size() >= MatTy->getNumElementsFlattened() &&
"Flattened type on RHS must have the same number or more elements "
"than vector on LHS.");

llvm::Value *V =
CGF.Builder.CreateLoad(CGF.CreateIRTemp(DestTy, "flatcast.tmp"));
// write to V.
for (unsigned I = 0, E = MatTy->getNumElementsFlattened(); I < E; I++) {
unsigned ColMajorIndex =
(I % MatTy->getNumRows()) * MatTy->getNumColumns() +
(I / MatTy->getNumRows());
RValue RVal = CGF.EmitLoadOfLValue(LoadList[ColMajorIndex], Loc);
assert(RVal.isScalar() &&
"All flattened source values should be scalars.");
llvm::Value *Cast = CGF.EmitScalarConversion(
RVal.getScalarVal(), LoadList[ColMajorIndex].getType(),
MatTy->getElementType(), Loc);
V = CGF.Builder.CreateInsertElement(V, Cast, I);
}
return V;
}
// if its a builtin just do an extract element or load.
assert(DestTy->isBuiltinType() &&
"Destination type must be a vector or builtin type.");
"Destination type must be a vector, matrix, or builtin type.");
RValue RVal = CGF.EmitLoadOfLValue(LoadList[0], Loc);
assert(RVal.isScalar() && "All flattened source values should be scalars.");
return CGF.EmitScalarConversion(RVal.getScalarVal(), LoadList[0].getType(),
Expand Down Expand Up @@ -2954,15 +2976,47 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
}
case CK_HLSLMatrixTruncation: {
assert((DestTy->isMatrixType() || DestTy->isBuiltinType()) &&
"Destination type must be a matrix or builtin type.");
Value *Mat = Visit(E);
if (auto *MatTy = DestTy->getAs<ConstantMatrixType>()) {
SmallVector<int> Mask;
unsigned NumCols = MatTy->getNumColumns();
unsigned NumRows = MatTy->getNumRows();
unsigned ColOffset = NumCols;
if (auto *SrcMatTy = E->getType()->getAs<ConstantMatrixType>())
ColOffset = SrcMatTy->getNumColumns();
for (unsigned R = 0; R < NumRows; R++) {
for (unsigned C = 0; C < NumCols; C++) {
unsigned I = R * ColOffset + C;
Mask.push_back(I);
}
}

return Builder.CreateShuffleVector(Mat, Mask, "trunc");
}
llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
return Builder.CreateExtractElement(Mat, Zero, "cast.mtrunc");
}
case CK_HLSLElementwiseCast: {
RValue RV = CGF.EmitAnyExpr(E);
SourceLocation Loc = CE->getExprLoc();

assert(RV.isAggregate() && "Not a valid HLSL Elementwise Cast.");
// RHS is an aggregate
LValue SrcVal = CGF.MakeAddrLValue(RV.getAggregateAddress(), E->getType());
Address SrcAddr = Address::invalid();

if (RV.isAggregate()) {
SrcAddr = RV.getAggregateAddress();
} else {
SrcAddr = CGF.CreateMemTemp(E->getType(), "hlsl.ewcast.src");
LValue TmpLV = CGF.MakeAddrLValue(SrcAddr, E->getType());
CGF.EmitStoreThroughLValue(RV, TmpLV);
}

LValue SrcVal = CGF.MakeAddrLValue(SrcAddr, E->getType());
return EmitHLSLElementwiseCast(CGF, SrcVal, DestTy, Loc);
}

} // end of switch

llvm_unreachable("unknown scalar cast");
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Edit/RewriteObjCFoundationAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
llvm_unreachable("OpenCL-specific cast in Objective-C?");

case CK_HLSLVectorTruncation:
case CK_HLSLMatrixTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
llvm_unreachable("HLSL-specific cast in Objective-C?");
Expand Down
14 changes: 14 additions & 0 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "clang/AST/TemplateBase.h"
#include "clang/AST/TemplateName.h"
#include "clang/AST/Type.h"
#include "clang/AST/TypeBase.h"
#include "clang/AST/TypeLoc.h"
#include "clang/AST/UnresolvedSet.h"
#include "clang/Basic/AddressSpaces.h"
Expand Down Expand Up @@ -12589,6 +12590,19 @@ void Sema::CheckImplicitConversion(Expr *E, QualType T, SourceLocation CC,
if (auto VecTy = dyn_cast<VectorType>(Target))
Target = VecTy->getElementType().getTypePtr();

if (isa<ConstantMatrixType>(Source)) {
if (!isa<ConstantMatrixType>(Target)) {
return DiagnoseImpCast(*this, E, T, CC, diag::warn_impcast_matrix_scalar);
} else if (getLangOpts().HLSL &&
Target->castAs<ConstantMatrixType>()->getNumElementsFlattened() <
Source->castAs<ConstantMatrixType>()
->getNumElementsFlattened()) {
// Diagnose Matrix truncation but don't return. We may also want to
// diagnose an element conversion.
DiagnoseImpCast(*this, E, T, CC,
diag::warn_hlsl_impcast_matrix_truncation);
}
}
// Strip complex types.
if (isa<ComplexType>(Source)) {
if (!isa<ComplexType>(Target)) {
Expand Down
22 changes: 16 additions & 6 deletions clang/lib/Sema/SemaExprCXX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5196,19 +5196,18 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
case ICK_Incompatible_Pointer_Conversion:
case ICK_HLSL_Array_RValue:
case ICK_HLSL_Vector_Truncation:
case ICK_HLSL_Matrix_Truncation:
case ICK_HLSL_Vector_Splat:
llvm_unreachable("Improper second standard conversion");
}

if (SCS.Dimension != ICK_Identity) {
// If SCS.Element is not ICK_Identity the To and From types must be HLSL
// vectors or matrices.

// TODO: Support HLSL matrices.
assert((!From->getType()->isMatrixType() && !ToType->isMatrixType()) &&
"Dimension conversion for matrix types is not implemented yet.");
assert((ToType->isVectorType() || ToType->isBuiltinType()) &&
"Dimension conversion output must be vector or scalar type.");
assert(
(ToType->isVectorType() || ToType->isConstantMatrixType() ||
ToType->isBuiltinType()) &&
"Dimension conversion output must be vector, matrix, or scalar type.");
switch (SCS.Dimension) {
case ICK_HLSL_Vector_Splat: {
// Vector splat from any arithmetic type to a vector.
Expand All @@ -5234,6 +5233,17 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,

break;
}
case ICK_HLSL_Matrix_Truncation: {
auto *FromMat = From->getType()->castAs<ConstantMatrixType>();
QualType TruncTy = FromMat->getElementType();
if (auto *ToMat = ToType->getAs<ConstantMatrixType>())
TruncTy = Context.getConstantMatrixType(TruncTy, ToMat->getNumRows(),
ToMat->getNumColumns());
From = ImpCastExprToType(From, TruncTy, CK_HLSLMatrixTruncation,
From->getValueKind())
.get();
break;
}
case ICK_Identity:
default:
llvm_unreachable("Improper element standard conversion");
Expand Down
5 changes: 4 additions & 1 deletion clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3721,7 +3721,6 @@ bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
}

// Can we perform an HLSL Elementwise cast?
// TODO: update this code when matrices are added; see issue #88060
bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {

// Don't handle casts where LHS and RHS are any combination of scalar/vector
Expand All @@ -3734,6 +3733,10 @@ bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
(DestTy->isScalarType() || DestTy->isVectorType()))
return false;

if (SrcTy->isConstantMatrixType() &&
(DestTy->isScalarType() || DestTy->isConstantMatrixType()))
return false;

llvm::SmallVector<QualType> DestTypes;
BuildFlattenedTypeList(DestTy, DestTypes);
llvm::SmallVector<QualType> SrcTypes;
Expand Down
Loading
Loading