Skip to content

Commit e593216

Browse files
committed
Properly compute iterating expansion stmt size
1 parent 6b08737 commit e593216

File tree

1 file changed

+167
-24
lines changed

1 file changed

+167
-24
lines changed

clang/lib/Sema/SemaExpand.cpp

Lines changed: 167 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include "clang/Sema/Sema.h"
2222
#include "clang/Sema/Template.h"
2323

24+
#include <llvm/ADT/ScopeExit.h>
25+
2426
using namespace clang;
2527
using namespace sema;
2628

@@ -273,6 +275,13 @@ StmtResult Sema::BuildNonEnumeratingCXXExpansionStmtPattern(
273275
ArrayRef<MaterializeTemporaryExpr *> LifetimeExtendTemps) {
274276
VarDecl *ExpansionVar = cast<VarDecl>(ExpansionVarStmt->getSingleDecl());
275277

278+
// Reject lambdas early.
279+
if (auto *RD = ExpansionInitializer->getType()->getAsCXXRecordDecl();
280+
RD && RD->isLambda()) {
281+
Diag(ExpansionInitializer->getBeginLoc(), diag::err_expansion_stmt_lambda);
282+
return StmtError();
283+
}
284+
276285
if (ExpansionInitializer->isTypeDependent()) {
277286
ActOnDependentForRangeInitializer(ExpansionVar, BFRK_Build);
278287
return new (Context) CXXDependentExpansionStmtPattern(
@@ -291,13 +300,6 @@ StmtResult Sema::BuildNonEnumeratingCXXExpansionStmtPattern(
291300
return StmtError();
292301
}
293302

294-
// Reject lambdas early.
295-
if (auto *RD = ExpansionInitializer->getType()->getAsCXXRecordDecl();
296-
RD && RD->isLambda()) {
297-
Diag(ExpansionInitializer->getBeginLoc(), diag::err_expansion_stmt_lambda);
298-
return StmtError();
299-
}
300-
301303
// Otherwise, if it can be an iterating expansion statement, it is one.
302304
DeclRefExpr *Index = BuildIndexDRE(*this, ESD);
303305
IterableExpansionStmtData Data = TryBuildIterableExpansionStmtInitializer(
@@ -333,6 +335,18 @@ StmtResult Sema::FinishCXXExpansionStmt(Stmt *Exp, Stmt *Body) {
333335
if (Expansion->hasDependentSize())
334336
return Expansion;
335337

338+
// Now that we're expanding this, exit the context of the expansion stmt
339+
// so that we no longer treat this as dependent.
340+
ContextRAII CtxGuard(*this, CurContext->getParent(),
341+
/*NewThis=*/false);
342+
343+
// Even if the size isn't technically dependent, delay expansion until
344+
// we're no longer in a template if this is an iterating expansion statement
345+
// since evaluating a lambda declared in a template doesn't work too well.
346+
if (CurContext->isDependentContext() &&
347+
isa<CXXIteratingExpansionStmtPattern>(Expansion))
348+
return Expansion;
349+
336350
// This can fail if this is an iterating expansion statement.
337351
std::optional<uint64_t> NumInstantiations = ComputeExpansionSize(Expansion);
338352
if (!NumInstantiations)
@@ -371,11 +385,6 @@ StmtResult Sema::FinishCXXExpansionStmt(Stmt *Exp, Stmt *Body) {
371385
SmallVector<Stmt *, 4> Instantiations;
372386
CXXExpansionStmtDecl *ESD = Expansion->getDecl();
373387
for (uint64_t I = 0; I < *NumInstantiations; ++I) {
374-
// Now that we're expanding this, exit the context of the expansion stmt
375-
// so that we no longer treat this as dependent.
376-
ContextRAII CtxGuard(*this, CurContext->getParent(),
377-
/*NewThis=*/false);
378-
379388
TemplateArgument Arg{Context, llvm::APSInt::get(I),
380389
Context.getPointerDiffType()};
381390
MultiLevelTemplateArgumentList MTArgList(ESD, Arg, true);
@@ -437,42 +446,176 @@ Sema::ComputeExpansionSize(CXXExpansionStmtPattern *Expansion) {
437446
// }()
438447
// TODO: CWG 3131 changes this lambda a bit.
439448
if (auto *Iterating = dyn_cast<CXXIteratingExpansionStmtPattern>(Expansion)) {
449+
SourceLocation Loc = Expansion->getColonLoc();
440450
EnterExpressionEvaluationContext ExprEvalCtx(
441451
*this, ExpressionEvaluationContext::ConstantEvaluated);
442452

443-
// FIXME: Actually do that; unfortunately, conjuring a lambda out of thin
444-
// air in Sema is a massive pain, so for now just cheat by computing
445-
// 'end - begin'.
446-
SourceLocation Loc = Iterating->getColonLoc();
453+
// This is mostly copied from ParseLambdaExpressionAfterIntroducer().
454+
ParseScope LambdaScope(*this, Scope::LambdaScope | Scope::DeclScope |
455+
Scope::FunctionDeclarationScope |
456+
Scope::FunctionPrototypeScope);
457+
AttributeFactory AttrFactory;
458+
LambdaIntroducer Intro;
459+
Intro.Range = SourceRange(Loc, Loc);
460+
Intro.Default = LCD_ByRef; // CWG 3131
461+
Intro.DefaultLoc = Loc;
462+
DeclSpec DS(AttrFactory);
463+
Declarator D(DS, ParsedAttributesView::none(),
464+
DeclaratorContext::LambdaExpr);
465+
PushLambdaScope();
466+
ActOnLambdaExpressionAfterIntroducer(Intro, getCurScope());
467+
468+
// Make the lambda 'consteval'.
469+
{
470+
ParseScope Prototype(*this, Scope::FunctionPrototypeScope |
471+
Scope::FunctionDeclarationScope |
472+
Scope::DeclScope);
473+
const char* PrevSpec = nullptr;
474+
unsigned DiagId = 0;
475+
DS.SetConstexprSpec(ConstexprSpecKind::Consteval, Loc, PrevSpec, DiagId);
476+
assert(DiagId == 0 && PrevSpec == nullptr);
477+
ActOnLambdaClosureParameters(getCurScope(), /*ParamInfo=*/{});
478+
ActOnLambdaClosureQualifiers(Intro, /*MutableLoc=*/SourceLocation());
479+
}
480+
481+
ParseScope BodyScope(*this, Scope::BlockScope | Scope::FnScope |
482+
Scope::DeclScope |
483+
Scope::CompoundStmtScope);
484+
485+
ActOnStartOfLambdaDefinition(Intro, D, DS);
486+
487+
// Enter the compound statement that is the lambda body.
488+
ActOnStartOfCompoundStmt(/*IsStmtExpr=*/false);
489+
ActOnAfterCompoundStatementLeadingPragmas();
490+
auto PopScopesOnReturn = llvm::make_scope_exit([&] {
491+
ActOnFinishOfCompoundStmt();
492+
ActOnLambdaError(Loc, getCurScope());
493+
});
494+
495+
// std::ptrdiff_t result = 0;
496+
QualType PtrDiffT = Context.getPointerDiffType();
497+
VarDecl *ResultVar = VarDecl::Create(
498+
Context, CurContext, Loc, Loc, &PP.getIdentifierTable().get("__result"),
499+
PtrDiffT, Context.getTrivialTypeSourceInfo(PtrDiffT, Loc), SC_None);
500+
Expr *Zero = ActOnIntegerConstant(Loc, 0).get();
501+
AddInitializerToDecl(ResultVar, Zero, false);
502+
StmtResult ResultVarStmt =
503+
ActOnDeclStmt(ConvertDeclToDeclGroup(ResultVar), Loc, Loc);
504+
if (ResultVarStmt.isInvalid() || ResultVar->isInvalidDecl())
505+
return std::nullopt;
506+
507+
// Start the for loop.
508+
ParseScope ForScope(*this, Scope::DeclScope | Scope::ControlScope);
509+
510+
// auto i = begin;
511+
VarDecl *IterationVar = VarDecl::Create(
512+
Context, CurContext, Loc, Loc, &PP.getIdentifierTable().get("__i"),
513+
Context.getAutoDeductType(),
514+
Context.getTrivialTypeSourceInfo(Context.getAutoDeductType(), Loc),
515+
SC_None);
447516
DeclRefExpr *Begin = BuildDeclRefExpr(
448517
Iterating->getBeginVar(),
449518
Iterating->getBeginVar()->getType().getNonReferenceType(), VK_LValue,
450519
Loc);
520+
AddInitializerToDecl(IterationVar, Begin, false);
521+
StmtResult IterationVarStmt =
522+
ActOnDeclStmt(ConvertDeclToDeclGroup(IterationVar), Loc, Loc);
523+
if (IterationVarStmt.isInvalid() || IterationVar->isInvalidDecl())
524+
return std::nullopt;
451525

526+
// i != end
527+
DeclRefExpr *IterationVarDeclRef = BuildDeclRefExpr(
528+
IterationVar, IterationVar->getType().getNonReferenceType(), VK_LValue,
529+
Loc);
452530
DeclRefExpr *End = BuildDeclRefExpr(
453531
Iterating->getEndVar(),
454532
Iterating->getEndVar()->getType().getNonReferenceType(), VK_LValue,
455533
Loc);
534+
ExprResult NotEqual = ActOnBinOp(getCurScope(), Loc, tok::exclaimequal,
535+
IterationVarDeclRef, End);
536+
if (NotEqual.isInvalid())
537+
return std::nullopt;
538+
ConditionResult Condition = ActOnCondition(
539+
getCurScope(), Loc, NotEqual.get(), ConditionKind::Boolean,
540+
/*MissingOk=*/false);
541+
if (Condition.isInvalid())
542+
return std::nullopt;
543+
544+
// ++i
545+
IterationVarDeclRef = BuildDeclRefExpr(
546+
IterationVar, IterationVar->getType().getNonReferenceType(), VK_LValue,
547+
Loc);
548+
ExprResult Increment =
549+
ActOnUnaryOp(getCurScope(), Loc, tok::plusplus, IterationVarDeclRef);
550+
if (Increment.isInvalid())
551+
return std::nullopt;
552+
FullExprArg ThirdPart = MakeFullDiscardedValueExpr(Increment.get());
553+
554+
// Enter the body of the for loop.
555+
ParseScope InnerScope(*this, Scope::DeclScope);
556+
getCurScope()->decrementMSManglingNumber();
557+
558+
// ++result;
559+
DeclRefExpr *ResultDeclRef = BuildDeclRefExpr(
560+
ResultVar, ResultVar->getType().getNonReferenceType(), VK_LValue, Loc);
561+
ExprResult IncrementResult =
562+
ActOnUnaryOp(getCurScope(), Loc, tok::plusplus, ResultDeclRef);
563+
if (IncrementResult.isInvalid())
564+
return std::nullopt;
565+
StmtResult IncrementStmt = ActOnExprStmt(IncrementResult.get());
566+
if (IncrementStmt.isInvalid())
567+
return std::nullopt;
456568

457-
ExprResult N = ActOnBinOp(getCurScope(), Loc, tok::minus, End, Begin);
458-
if (N.isInvalid())
569+
// Exit the for loop.
570+
InnerScope.Exit();
571+
ForScope.Exit();
572+
StmtResult ForLoop =
573+
ActOnForStmt(Loc, Loc, IterationVarStmt.get(), Condition, ThirdPart,
574+
Loc, IncrementStmt.get());
575+
if (ForLoop.isInvalid())
576+
return std::nullopt;
577+
578+
// return result;
579+
ResultDeclRef = BuildDeclRefExpr(
580+
ResultVar, ResultVar->getType().getNonReferenceType(), VK_LValue, Loc);
581+
StmtResult Return = ActOnReturnStmt(Loc, ResultDeclRef, getCurScope());
582+
if (Return.isInvalid())
583+
return std::nullopt;
584+
585+
// Finally, we can build the compound statement that is the lambda body.
586+
StmtResult LambdaBody = ActOnCompoundStmt(
587+
Loc, Loc, {ResultVarStmt.get(), ForLoop.get(), Return.get()},
588+
/*isStmtExpr=*/false);
589+
if (LambdaBody.isInvalid())
590+
return std::nullopt;
591+
592+
ActOnFinishOfCompoundStmt();
593+
BodyScope.Exit();
594+
LambdaScope.Exit();
595+
PopScopesOnReturn.release();
596+
ExprResult Lambda = ActOnLambdaExpr(Loc, LambdaBody.get());
597+
if (Lambda.isInvalid())
598+
return std::nullopt;
599+
600+
// Invoke the lambda.
601+
ExprResult Call =
602+
ActOnCallExpr(getCurScope(), Lambda.get(), Loc, /*ArgExprs=*/{}, Loc);
603+
if (Call.isInvalid())
459604
return std::nullopt;
460605

461606
Expr::EvalResult ER;
462607
SmallVector<PartialDiagnosticAt, 4> Notes;
463608
ER.Diag = &Notes;
464-
if (!N.get()->EvaluateAsInt(ER, Context)) {
609+
if (!Call.get()->EvaluateAsInt(ER, Context)) {
465610
Diag(Loc, diag::err_expansion_size_expr_not_ice);
466611
for (const auto &[Location, PDiag] : Notes)
467612
Diag(Location, PDiag);
468613
return std::nullopt;
469614
}
470615

471-
if (ER.Val.getInt().isNegative()) {
472-
Diag(Loc, diag::err_expansion_size_negative) << ER.Val.getInt();
473-
return std::nullopt;
474-
}
475-
616+
// It shouldn't be possible for this to be negative since we compute this
617+
// via the built-in '++' on a ptrdiff_t.
618+
assert(ER.Val.getInt().isNonNegative());
476619
return ER.Val.getInt().getZExtValue();
477620
}
478621

0 commit comments

Comments
 (0)