Skip to content

Conversation

@spall
Copy link
Contributor

@spall spall commented Oct 22, 2025

Add support to handle these casts in the constant expression evaluator.

  • HLSLAggregateSplatCast
  • HLSLElementwiseCast
  • HLSLArrayRValue

Add tests
Closes #125766
Closes #125321

Constant expression evaluator. Add tests. Fix/Add support for
other minor necessary things.
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" HLSL HLSL Language Support labels Oct 22, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2025

@llvm/pr-subscribers-hlsl

Author: Sarah Spall (spall)

Changes

Add support to handle these casts in the constant expression evaluator.

  • HLSLAggregateSplatCast
  • HLSLElementwiseCast
  • HLSLArrayRValue

Add tests


Patch is 28.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164700.diff

4 Files Affected:

  • (modified) clang/lib/AST/ExprConstant.cpp (+586-1)
  • (added) clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl (+89)
  • (modified) clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl (+21)
  • (added) clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl (+76)
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 00aaaab957591..5dfb2b3e3491f 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -3828,6 +3828,333 @@ static bool CheckArraySize(EvalInfo &Info, const ConstantArrayType *CAT,
       /*Diag=*/true);
 }
 
+static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
+                             QualType SourceTy, QualType DestTy,
+                             APValue const &Original, APValue &Result) {
+  // boolean must be checked before integer
+  // since IsIntegerType() is true for bool
+  if (SourceTy->isBooleanType()) {
+    if (DestTy->isBooleanType()) {
+      Result = Original;
+      return true;
+    }
+    if (DestTy->isIntegerType() || DestTy->isRealFloatingType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      // TODO destty is wrong here if destty is float....
+      // can we use sourcety here?
+    }
+    if (DestTy->isFloatingType()) {
+      APValue Result2 = APValue(APFloat(0.0));
+      if (!HandleIntToFloatCast(Info, E, FPO,
+                                Info.Ctx.getIntTypeForBitwidth(64, true),
+                                Result.getInt(), DestTy, Result2.getFloat()))
+        return false;
+      Result = Result2;
+    }
+    return true;
+  }
+  if (SourceTy->isIntegerType()) {
+    if (DestTy->isRealFloatingType()) {
+      Result = APValue(APFloat(0.0));
+      return HandleIntToFloatCast(Info, E, FPO, SourceTy, Original.getInt(),
+                                  DestTy, Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(
+          HandleIntToIntCast(Info, E, DestTy, SourceTy, Original.getInt()));
+      return true;
+    }
+  } else if (SourceTy->isRealFloatingType()) {
+    if (DestTy->isRealFloatingType()) {
+      Result = Original;
+      return HandleFloatToFloatCast(Info, E, SourceTy, DestTy,
+                                    Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(APSInt());
+      return HandleFloatToIntCast(Info, E, SourceTy, Original.getFloat(),
+                                  DestTy, Result.getInt());
+    }
+  }
+
+  // Info.FFDiag(E, diag::err_convertvector_constexpr_unsupported_vector_cast)
+  //   << SourceTy << DestTy;
+  return false;
+}
+
+// do the heavy lifting for casting to aggregate types
+// because we have to deal with bitfields specially
+static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
+                               const Expr *E, APValue &Result,
+                               QualType ResultType,
+                               SmallVectorImpl<APValue> &Elements,
+                               SmallVectorImpl<QualType> &ElTypes) {
+
+  SmallVector<std::tuple<APValue *, QualType, unsigned>> WorkList = {
+      {&Result, ResultType, 0}};
+
+  unsigned ElI = 0;
+  while (!WorkList.empty() && ElI < Elements.size()) {
+    auto [Res, Type, BitWidth] = WorkList.pop_back_val();
+
+    if (Type->isRealFloatingType() || Type->isBooleanType()) {
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      ElI++;
+      continue;
+    }
+    if (Type->isIntegerType()) {
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      if (BitWidth > 0) {
+        if (!Res->isInt())
+          return false;
+        APSInt &Int = Res->getInt();
+        unsigned OldBitWidth = Int.getBitWidth();
+        unsigned NewBitWidth = BitWidth;
+        if (NewBitWidth < OldBitWidth)
+          Int = Int.trunc(NewBitWidth).extend(OldBitWidth);
+      }
+      ElI++;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+      SmallVector<APValue> Vals(NumEl);
+      for (unsigned I = 0; I < NumEl; ++I) {
+        if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], ElTy, Elements[ElI],
+                              Vals[I]))
+          return false;
+        ElI++;
+      }
+      *Res = APValue(Vals.data(), NumEl);
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+                          ->getElementType();
+      uint64_t Size =
+          cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+      *Res = APValue(APValue::UninitArray(), Size, Size);
+      for (int64_t I = Size - 1; I > -1; --I) {
+        WorkList.emplace_back(&Res->getArrayInitializedElt(I), ElTy, 0u);
+      }
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      unsigned NumBases = 0;
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD))
+        NumBases = CXXRD->getNumBases();
+
+      *Res = APValue(APValue::UninitStruct(), NumBases,
+                     std::distance(RD->field_begin(), RD->field_end()));
+
+      SmallVector<std::tuple<APValue *, QualType, unsigned>> ReverseList;
+      // we need to traverse backwards
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        // todo assert there is only 1 base at most
+        for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+          ReverseList.emplace_back(&Res->getStructBase(I), BS.getType(), 0u);
+        }
+      }
+
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        unsigned FDBW = 0;
+        if (FD->isUnnamedBitField())
+          continue;
+        if (FD->isBitField()) {
+          FDBW = FD->getBitWidthValue();
+        }
+
+        ReverseList.emplace_back(&Res->getStructField(FD->getFieldIndex()),
+                                 FD->getType(), FDBW);
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
+static bool handleElementwiseCast(EvalInfo &Info, const Expr *E,
+                                  const FPOptions FPO,
+                                  SmallVectorImpl<APValue> &Elements,
+                                  SmallVectorImpl<QualType> &SrcTypes,
+                                  SmallVectorImpl<QualType> &DestTypes,
+                                  SmallVectorImpl<APValue> &Results) {
+
+  assert((Elements.size() == SrcTypes.size()) &&
+         (Elements.size() == DestTypes.size()));
+
+  for (unsigned I = 0, ESz = Elements.size(); I < ESz; ++I) {
+    APValue Original = Elements[I];
+    QualType SourceTy = SrcTypes[I];
+    QualType DestTy = DestTypes[I];
+
+    if (!handleScalarCast(Info, FPO, E, SourceTy, DestTy, Original, Results[I]))
+      return false;
+  }
+  return true;
+}
+
+static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
+
+  SmallVector<QualType> WorkList = {BaseTy};
+
+  unsigned Size = 0;
+  while (!WorkList.empty()) {
+    QualType Type = WorkList.pop_back_val();
+    if (Type->isRealFloatingType() || Type->isIntegerType() ||
+        Type->isBooleanType()) {
+      ++Size;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+      Size += NumEl;
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+                          ->getElementType();
+      uint64_t Size =
+          cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+      for (uint64_t I = 0; I < Size; ++I) {
+        WorkList.push_back(ElTy);
+      }
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+      // const ASTRecordLayout &Layout = Info.Ctx.getASTRecordLayout(RD);
+
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        // todo assert there is only 1 base at most
+        for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+          WorkList.push_back(BS.getType());
+        }
+      }
+
+      // visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        WorkList.push_back(FD->getType());
+      }
+      continue;
+    }
+  }
+  return Size;
+}
+
+static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
+                           QualType BaseTy, SmallVectorImpl<APValue> &Elements,
+                           SmallVectorImpl<QualType> &Types, unsigned Size) {
+
+  SmallVector<std::pair<APValue, QualType>> WorkList = {{Value, BaseTy}};
+  unsigned Populated = 0;
+  while (!WorkList.empty() && Populated < Size) {
+    auto [Work, Type] = WorkList.pop_back_val();
+
+    if (Work.isFloat() || Work.isInt()) { // todo what does this do with bool
+      Elements.push_back(Work);
+      Types.push_back(Type);
+      Populated++;
+      continue;
+    }
+    if (Work.isVector()) {
+      assert(Type->isVectorType() && "Type mismatch.");
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      for (unsigned I = 0; I < Work.getVectorLength() && Populated < Size;
+           I++) {
+        Elements.push_back(Work.getVectorElt(I));
+        Types.push_back(ElTy);
+        Populated++;
+      }
+      continue;
+    }
+    if (Work.isArray()) {
+      assert(Type->isConstantArrayType() && "Type mismatch.");
+      QualType ElTy =
+          cast<ConstantArrayType>(Ctx.getAsArrayType(Type))->getElementType();
+      for (int64_t I = Work.getArraySize() - 1; I > -1; --I) {
+        WorkList.emplace_back(Work.getArrayInitializedElt(I), ElTy);
+      }
+      continue;
+    }
+
+    if (Work.isStruct()) {
+      assert(Type->isRecordType() && "Type mismatch.");
+
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      SmallVector<std::pair<APValue, QualType>> ReverseList;
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        // if (FD->isBitField()) {
+        ReverseList.emplace_back(Work.getStructField(FD->getFieldIndex()),
+                                 FD->getType());
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        if (CXXRD->getNumBases() > 0) {
+          assert(CXXRD->getNumBases() == 1);
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+          const APValue &Base = Work.getStructBase(0);
+
+          // Can happen in error cases.
+          if (!Base.isStruct())
+            return false;
+
+          WorkList.emplace_back(Base, BS.getType());
+        }
+      }
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
 namespace {
 /// A handle to a complete object (an object that is not a subobject of
 /// another object).
@@ -8666,6 +8993,25 @@ class ExprEvaluatorBase
     case CK_UserDefinedConversion:
       return StmtVisitorTy::Visit(E->getSubExpr());
 
+    case CK_HLSLArrayRValue: {
+      const Expr *SubExpr = E->getSubExpr();
+      if (!SubExpr->isGLValue()) {
+        APValue Val;
+        if (!Evaluate(Val, Info, SubExpr))
+          return false;
+        return DerivedSuccess(Val, E);
+      }
+
+      LValue LVal;
+      if (!EvaluateLValue(SubExpr, LVal, Info))
+        return false;
+      APValue RVal;
+      // Note, we use the subexpression's type in order to retain cv-qualifiers.
+      if (!handleLValueToRValueConversion(Info, E, SubExpr->getType(), LVal,
+                                          RVal))
+        return false;
+      return DerivedSuccess(RVal, E);
+    }
     case CK_LValueToRValue: {
       LValue LVal;
       if (!EvaluateLValue(E->getSubExpr(), LVal, Info))
@@ -10850,6 +11196,67 @@ bool RecordExprEvaluator::VisitCastExpr(const CastExpr *E) {
     Result = *Value;
     return true;
   }
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+    const Expr *SE = E->getSubExpr();
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    unsigned NEls = elementwiseSize(Info, E->getType());
+    // flatten the source
+    SmallVector<APValue, 1> SrcEls;
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+      return Error(E);
+
+    // check there is only one and splat it
+    assert(SrcEls.size() == 1);
+    SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+    SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+    APValue Tmp;
+    handleDefaultInitValue(E->getType(), Tmp);
+
+    // cast the elements and construct our struct result
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+                            SplatType))
+      return Error(E);
+
+    return true;
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+    const Expr *SE = E->getSubExpr();
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    // flatten the source
+    SmallVector<APValue> SrcEls;
+    SmallVector<QualType> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+                        UINT_MAX))
+      return Error(E);
+
+    // cast the elements and construct our struct result
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+                            SrcTypes))
+      return Error(E);
+
+    return true;
+  }
   }
 }
 
@@ -11345,6 +11752,58 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) {
       Elements.push_back(Val.getVectorElt(I));
     return Success(Elements, E);
   }
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // this cast doesn't handle splatting from scalars when result is a vector
+    SmallVector<APValue, 1> Elements;
+    SmallVector<QualType, 1> DestTypes = {VTy->getElementType()};
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+      return Error(E);
+
+    // check there is only one element and cast and splat it
+    assert(Elements.size() == 1 &&
+           "HLSLAggregateSplatCast RHS must contain one element");
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    SmallVector<APValue, 1> ResultEls(1);
+    if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+                               ResultEls))
+      return Error(E);
+
+    SmallVector<APValue, 4> SplatEls(NElts, ResultEls[0]);
+    return Success(SplatEls, E);
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    SmallVector<APValue, 4> Elements;
+    SmallVector<QualType, 4> DestTypes(NElts, VTy->getElementType());
+    SmallVector<QualType, 4> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+      return Error(E);
+    // cast elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    SmallVector<APValue, 4> ResultEls(NElts);
+    if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+                               ResultEls))
+      return Error(E);
+    return Success(ResultEls, E);
+  }
   default:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
   }
@@ -13029,6 +13488,7 @@ namespace {
     bool VisitCallExpr(const CallExpr *E) {
       return handleCallExpr(E, Result, &This);
     }
+    bool VisitCastExpr(const CastExpr *E);
     bool VisitInitListExpr(const InitListExpr *E,
                            QualType AllocType = QualType());
     bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *E);
@@ -13099,6 +13559,70 @@ static bool MaybeElementDependentArrayFiller(const Expr *FillerExpr) {
   return true;
 }
 
+bool ArrayExprEvaluator::VisitCastExpr(const CastExpr *E) {
+  const Expr *SE = E->getSubExpr();
+
+  switch (E->getCastKind()) {
+  default:
+    return ExprEvaluatorBaseTy::VisitCastExpr(E);
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    unsigned NEls = elementwiseSize(Info, E->getType());
+    // flatten the source
+    SmallVector<APValue, 1> SrcEls;
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+      return Error(E);
+
+    // check there is only one and splat it
+    assert(SrcEls.size() == 1);
+    SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+    SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+    // cast the elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+                            SplatType))
+      return Error(E);
+
+    return true;
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    // flatten the source
+    SmallVector<APValue> SrcEls;
+    SmallVector<QualType> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+                        UINT_MAX))
+      return Error(E);
+
+    // cast the elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+                            SrcTypes))
+      return Error(E);
+
+    return true;
+  }
+  }
+}
+
 bool ArrayExprEvaluator::VisitInitListExpr(const InitListExpr *E,
                                            QualType AllocType) {
   const ConstantArrayType *CAT = Info.Ctx.getAsConstantArrayType(
@@ -16801,7 +17325,6 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_NoOp:
   case CK_LValueToRValueBitCast:
   case CK_HLSLArrayRValue:
-  case CK_HLSLElementwiseCast:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
 
   case CK_MemberPointerToBoolean:
@@ -16948,6 +17471,35 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
       return Error(E);
     return Success(Val.getVectorElt(0), E);
   }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SubExpr))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(),
+                                          LVal, Val...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2025

@llvm/pr-subscribers-clang

Author: Sarah Spall (spall)

Changes

Add support to handle these casts in the constant expression evaluator.

  • HLSLAggregateSplatCast
  • HLSLElementwiseCast
  • HLSLArrayRValue

Add tests


Patch is 28.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164700.diff

4 Files Affected:

  • (modified) clang/lib/AST/ExprConstant.cpp (+586-1)
  • (added) clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl (+89)
  • (modified) clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl (+21)
  • (added) clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl (+76)
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 00aaaab957591..5dfb2b3e3491f 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -3828,6 +3828,333 @@ static bool CheckArraySize(EvalInfo &Info, const ConstantArrayType *CAT,
       /*Diag=*/true);
 }
 
+static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
+                             QualType SourceTy, QualType DestTy,
+                             APValue const &Original, APValue &Result) {
+  // boolean must be checked before integer
+  // since IsIntegerType() is true for bool
+  if (SourceTy->isBooleanType()) {
+    if (DestTy->isBooleanType()) {
+      Result = Original;
+      return true;
+    }
+    if (DestTy->isIntegerType() || DestTy->isRealFloatingType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      // TODO destty is wrong here if destty is float....
+      // can we use sourcety here?
+    }
+    if (DestTy->isFloatingType()) {
+      APValue Result2 = APValue(APFloat(0.0));
+      if (!HandleIntToFloatCast(Info, E, FPO,
+                                Info.Ctx.getIntTypeForBitwidth(64, true),
+                                Result.getInt(), DestTy, Result2.getFloat()))
+        return false;
+      Result = Result2;
+    }
+    return true;
+  }
+  if (SourceTy->isIntegerType()) {
+    if (DestTy->isRealFloatingType()) {
+      Result = APValue(APFloat(0.0));
+      return HandleIntToFloatCast(Info, E, FPO, SourceTy, Original.getInt(),
+                                  DestTy, Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(
+          HandleIntToIntCast(Info, E, DestTy, SourceTy, Original.getInt()));
+      return true;
+    }
+  } else if (SourceTy->isRealFloatingType()) {
+    if (DestTy->isRealFloatingType()) {
+      Result = Original;
+      return HandleFloatToFloatCast(Info, E, SourceTy, DestTy,
+                                    Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(APSInt());
+      return HandleFloatToIntCast(Info, E, SourceTy, Original.getFloat(),
+                                  DestTy, Result.getInt());
+    }
+  }
+
+  // Info.FFDiag(E, diag::err_convertvector_constexpr_unsupported_vector_cast)
+  //   << SourceTy << DestTy;
+  return false;
+}
+
+// do the heavy lifting for casting to aggregate types
+// because we have to deal with bitfields specially
+static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
+                               const Expr *E, APValue &Result,
+                               QualType ResultType,
+                               SmallVectorImpl<APValue> &Elements,
+                               SmallVectorImpl<QualType> &ElTypes) {
+
+  SmallVector<std::tuple<APValue *, QualType, unsigned>> WorkList = {
+      {&Result, ResultType, 0}};
+
+  unsigned ElI = 0;
+  while (!WorkList.empty() && ElI < Elements.size()) {
+    auto [Res, Type, BitWidth] = WorkList.pop_back_val();
+
+    if (Type->isRealFloatingType() || Type->isBooleanType()) {
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      ElI++;
+      continue;
+    }
+    if (Type->isIntegerType()) {
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      if (BitWidth > 0) {
+        if (!Res->isInt())
+          return false;
+        APSInt &Int = Res->getInt();
+        unsigned OldBitWidth = Int.getBitWidth();
+        unsigned NewBitWidth = BitWidth;
+        if (NewBitWidth < OldBitWidth)
+          Int = Int.trunc(NewBitWidth).extend(OldBitWidth);
+      }
+      ElI++;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+      SmallVector<APValue> Vals(NumEl);
+      for (unsigned I = 0; I < NumEl; ++I) {
+        if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], ElTy, Elements[ElI],
+                              Vals[I]))
+          return false;
+        ElI++;
+      }
+      *Res = APValue(Vals.data(), NumEl);
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+                          ->getElementType();
+      uint64_t Size =
+          cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+      *Res = APValue(APValue::UninitArray(), Size, Size);
+      for (int64_t I = Size - 1; I > -1; --I) {
+        WorkList.emplace_back(&Res->getArrayInitializedElt(I), ElTy, 0u);
+      }
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      unsigned NumBases = 0;
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD))
+        NumBases = CXXRD->getNumBases();
+
+      *Res = APValue(APValue::UninitStruct(), NumBases,
+                     std::distance(RD->field_begin(), RD->field_end()));
+
+      SmallVector<std::tuple<APValue *, QualType, unsigned>> ReverseList;
+      // we need to traverse backwards
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        // todo assert there is only 1 base at most
+        for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+          ReverseList.emplace_back(&Res->getStructBase(I), BS.getType(), 0u);
+        }
+      }
+
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        unsigned FDBW = 0;
+        if (FD->isUnnamedBitField())
+          continue;
+        if (FD->isBitField()) {
+          FDBW = FD->getBitWidthValue();
+        }
+
+        ReverseList.emplace_back(&Res->getStructField(FD->getFieldIndex()),
+                                 FD->getType(), FDBW);
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
+static bool handleElementwiseCast(EvalInfo &Info, const Expr *E,
+                                  const FPOptions FPO,
+                                  SmallVectorImpl<APValue> &Elements,
+                                  SmallVectorImpl<QualType> &SrcTypes,
+                                  SmallVectorImpl<QualType> &DestTypes,
+                                  SmallVectorImpl<APValue> &Results) {
+
+  assert((Elements.size() == SrcTypes.size()) &&
+         (Elements.size() == DestTypes.size()));
+
+  for (unsigned I = 0, ESz = Elements.size(); I < ESz; ++I) {
+    APValue Original = Elements[I];
+    QualType SourceTy = SrcTypes[I];
+    QualType DestTy = DestTypes[I];
+
+    if (!handleScalarCast(Info, FPO, E, SourceTy, DestTy, Original, Results[I]))
+      return false;
+  }
+  return true;
+}
+
+static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
+
+  SmallVector<QualType> WorkList = {BaseTy};
+
+  unsigned Size = 0;
+  while (!WorkList.empty()) {
+    QualType Type = WorkList.pop_back_val();
+    if (Type->isRealFloatingType() || Type->isIntegerType() ||
+        Type->isBooleanType()) {
+      ++Size;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+      Size += NumEl;
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+                          ->getElementType();
+      uint64_t Size =
+          cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+      for (uint64_t I = 0; I < Size; ++I) {
+        WorkList.push_back(ElTy);
+      }
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+      // const ASTRecordLayout &Layout = Info.Ctx.getASTRecordLayout(RD);
+
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        // todo assert there is only 1 base at most
+        for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+          WorkList.push_back(BS.getType());
+        }
+      }
+
+      // visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        WorkList.push_back(FD->getType());
+      }
+      continue;
+    }
+  }
+  return Size;
+}
+
+static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
+                           QualType BaseTy, SmallVectorImpl<APValue> &Elements,
+                           SmallVectorImpl<QualType> &Types, unsigned Size) {
+
+  SmallVector<std::pair<APValue, QualType>> WorkList = {{Value, BaseTy}};
+  unsigned Populated = 0;
+  while (!WorkList.empty() && Populated < Size) {
+    auto [Work, Type] = WorkList.pop_back_val();
+
+    if (Work.isFloat() || Work.isInt()) { // todo what does this do with bool
+      Elements.push_back(Work);
+      Types.push_back(Type);
+      Populated++;
+      continue;
+    }
+    if (Work.isVector()) {
+      assert(Type->isVectorType() && "Type mismatch.");
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      for (unsigned I = 0; I < Work.getVectorLength() && Populated < Size;
+           I++) {
+        Elements.push_back(Work.getVectorElt(I));
+        Types.push_back(ElTy);
+        Populated++;
+      }
+      continue;
+    }
+    if (Work.isArray()) {
+      assert(Type->isConstantArrayType() && "Type mismatch.");
+      QualType ElTy =
+          cast<ConstantArrayType>(Ctx.getAsArrayType(Type))->getElementType();
+      for (int64_t I = Work.getArraySize() - 1; I > -1; --I) {
+        WorkList.emplace_back(Work.getArrayInitializedElt(I), ElTy);
+      }
+      continue;
+    }
+
+    if (Work.isStruct()) {
+      assert(Type->isRecordType() && "Type mismatch.");
+
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      SmallVector<std::pair<APValue, QualType>> ReverseList;
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        // if (FD->isBitField()) {
+        ReverseList.emplace_back(Work.getStructField(FD->getFieldIndex()),
+                                 FD->getType());
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        if (CXXRD->getNumBases() > 0) {
+          assert(CXXRD->getNumBases() == 1);
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+          const APValue &Base = Work.getStructBase(0);
+
+          // Can happen in error cases.
+          if (!Base.isStruct())
+            return false;
+
+          WorkList.emplace_back(Base, BS.getType());
+        }
+      }
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
 namespace {
 /// A handle to a complete object (an object that is not a subobject of
 /// another object).
@@ -8666,6 +8993,25 @@ class ExprEvaluatorBase
     case CK_UserDefinedConversion:
       return StmtVisitorTy::Visit(E->getSubExpr());
 
+    case CK_HLSLArrayRValue: {
+      const Expr *SubExpr = E->getSubExpr();
+      if (!SubExpr->isGLValue()) {
+        APValue Val;
+        if (!Evaluate(Val, Info, SubExpr))
+          return false;
+        return DerivedSuccess(Val, E);
+      }
+
+      LValue LVal;
+      if (!EvaluateLValue(SubExpr, LVal, Info))
+        return false;
+      APValue RVal;
+      // Note, we use the subexpression's type in order to retain cv-qualifiers.
+      if (!handleLValueToRValueConversion(Info, E, SubExpr->getType(), LVal,
+                                          RVal))
+        return false;
+      return DerivedSuccess(RVal, E);
+    }
     case CK_LValueToRValue: {
       LValue LVal;
       if (!EvaluateLValue(E->getSubExpr(), LVal, Info))
@@ -10850,6 +11196,67 @@ bool RecordExprEvaluator::VisitCastExpr(const CastExpr *E) {
     Result = *Value;
     return true;
   }
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+    const Expr *SE = E->getSubExpr();
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    unsigned NEls = elementwiseSize(Info, E->getType());
+    // flatten the source
+    SmallVector<APValue, 1> SrcEls;
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+      return Error(E);
+
+    // check there is only one and splat it
+    assert(SrcEls.size() == 1);
+    SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+    SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+    APValue Tmp;
+    handleDefaultInitValue(E->getType(), Tmp);
+
+    // cast the elements and construct our struct result
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+                            SplatType))
+      return Error(E);
+
+    return true;
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+    const Expr *SE = E->getSubExpr();
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    // flatten the source
+    SmallVector<APValue> SrcEls;
+    SmallVector<QualType> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+                        UINT_MAX))
+      return Error(E);
+
+    // cast the elements and construct our struct result
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+                            SrcTypes))
+      return Error(E);
+
+    return true;
+  }
   }
 }
 
@@ -11345,6 +11752,58 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) {
       Elements.push_back(Val.getVectorElt(I));
     return Success(Elements, E);
   }
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // this cast doesn't handle splatting from scalars when result is a vector
+    SmallVector<APValue, 1> Elements;
+    SmallVector<QualType, 1> DestTypes = {VTy->getElementType()};
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+      return Error(E);
+
+    // check there is only one element and cast and splat it
+    assert(Elements.size() == 1 &&
+           "HLSLAggregateSplatCast RHS must contain one element");
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    SmallVector<APValue, 1> ResultEls(1);
+    if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+                               ResultEls))
+      return Error(E);
+
+    SmallVector<APValue, 4> SplatEls(NElts, ResultEls[0]);
+    return Success(SplatEls, E);
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    SmallVector<APValue, 4> Elements;
+    SmallVector<QualType, 4> DestTypes(NElts, VTy->getElementType());
+    SmallVector<QualType, 4> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+      return Error(E);
+    // cast elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    SmallVector<APValue, 4> ResultEls(NElts);
+    if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+                               ResultEls))
+      return Error(E);
+    return Success(ResultEls, E);
+  }
   default:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
   }
@@ -13029,6 +13488,7 @@ namespace {
     bool VisitCallExpr(const CallExpr *E) {
       return handleCallExpr(E, Result, &This);
     }
+    bool VisitCastExpr(const CastExpr *E);
     bool VisitInitListExpr(const InitListExpr *E,
                            QualType AllocType = QualType());
     bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *E);
@@ -13099,6 +13559,70 @@ static bool MaybeElementDependentArrayFiller(const Expr *FillerExpr) {
   return true;
 }
 
+bool ArrayExprEvaluator::VisitCastExpr(const CastExpr *E) {
+  const Expr *SE = E->getSubExpr();
+
+  switch (E->getCastKind()) {
+  default:
+    return ExprEvaluatorBaseTy::VisitCastExpr(E);
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    unsigned NEls = elementwiseSize(Info, E->getType());
+    // flatten the source
+    SmallVector<APValue, 1> SrcEls;
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+      return Error(E);
+
+    // check there is only one and splat it
+    assert(SrcEls.size() == 1);
+    SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+    SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+    // cast the elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+                            SplatType))
+      return Error(E);
+
+    return true;
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    // flatten the source
+    SmallVector<APValue> SrcEls;
+    SmallVector<QualType> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+                        UINT_MAX))
+      return Error(E);
+
+    // cast the elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+                            SrcTypes))
+      return Error(E);
+
+    return true;
+  }
+  }
+}
+
 bool ArrayExprEvaluator::VisitInitListExpr(const InitListExpr *E,
                                            QualType AllocType) {
   const ConstantArrayType *CAT = Info.Ctx.getAsConstantArrayType(
@@ -16801,7 +17325,6 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_NoOp:
   case CK_LValueToRValueBitCast:
   case CK_HLSLArrayRValue:
-  case CK_HLSLElementwiseCast:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
 
   case CK_MemberPointerToBoolean:
@@ -16948,6 +17471,35 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
       return Error(E);
     return Success(Val.getVectorElt(0), E);
   }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SubExpr))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(),
+                                          LVal, Val...
[truncated]

@damyanp damyanp requested review from bob80905 and inbelic October 29, 2025 17:12
unsigned OldBitWidth = Int.getBitWidth();
unsigned NewBitWidth = BitWidth;
if (NewBitWidth < OldBitWidth)
Int = Int.trunc(NewBitWidth).extend(OldBitWidth);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why extend?

Copy link
Contributor

@bob80905 bob80905 Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get the sense my comment here:
Can we really guarantee that this memory won't be overwritten and these results will be consistent?
is relevant to this line of code. By extending, we guarantee the memory is preserved?
Is there an agreed upon design that shows this is the desired behavior?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took this line from 'truncateBitfieldValue'; I get the impression the extension is to make the integer to be the right "size" to store in the APInt value. I'll double check this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any findings on this?

@github-actions
Copy link

github-actions bot commented Oct 30, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@inbelic inbelic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logically stepping through the test-cases and following the code, it makes sense to me.

Just some comments to see if it is possible to simplify the code a little bit

Comment on lines 62 to 63
constexpr double D = 100.6789;
constexpr R SR = (R)D;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
constexpr double D = 100.6789;
constexpr R SR = (R)D;
constexpr double D = 97.6789;
constexpr R SR = (R)(D + 3.0);

This would check that the RHS is being evaluated as we expect

uint64_t Size =
cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
*Res = APValue(APValue::UninitArray(), Size, Size);
for (int64_t I = Size - 1; I > -1; --I) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (int64_t I = Size - 1; I > -1; --I) {
for (int64_t I = Size - 1; I > -1; --I)

nit: style guide


// result type float
// truncate from array
constexpr B1 Arr[2] = {4.0, 3.0, 2.0, 1.0};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
constexpr B1 Arr[2] = {4.0, 3.0, 2.0, 1.0};
constexpr B1 Arr[2] = {(2.5 + 1.5), 3.0, 2.0, 1.0};

Is something like this expected to work? What about:

constexpr B1 Arr[2] = {4.0, 3.0, 2.0, 1.0} + {4.0, 3.0, 2.0, 1.0}?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The suggested change yes, but I don't think the 2nd suggestion is valid HLSL.

}

// Visit the fields.
for (FieldDecl *FD : RD->fields()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (FieldDecl *FD : RD->fields()) {
for (FieldDecl *FD : llvm::reverse(RD->fields())) {

Would doing this and then moving the above base class logic below this for loop let us directly append to the worklist rather than creating the intermediate smallvector?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure you can apply reverse to that but I'll check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't appear to be valid.

return Success(Elements, E);
}
case CK_HLSLAggregateSplatCast: {
APValue Val;
Copy link
Contributor

@inbelic inbelic Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code in many of these switch statements seem very similar, it seems like we could abstract out a common function that:

  • evaluates the sub-expression
  • flattens out the value
  • invokes the correct kind of cast based on the number of elements (maybe the function is templated for the type of required cast)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created helper functions for both HLSLAggregateSplatCast and HLSLElementwiseCast, but only handled the parts common to each type and let the case statements till handle constructing the proper return type. I didn't want to template the code further for the required cast.

Copy link
Contributor

@inbelic inbelic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks

@spall spall merged commit 4d67e15 into llvm:main Nov 6, 2025
10 checks passed
vinay-deshmukh pushed a commit to vinay-deshmukh/llvm-project that referenced this pull request Nov 8, 2025
… to constant expression evaluator (llvm#164700)

Add support to handle these casts in the constant expression evaluator. 

- HLSLAggregateSplatCast
- HLSLElementwiseCast
- HLSLArrayRValue

Add tests 
Closes llvm#125766
Closes llvm#125321
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[HLSL] Implement constant expression evaluator for HLSL splat cast [HLSL] Implement Constant expression evaluator for flat casts

4 participants