Skip to content

Commit 39861dd

Browse files
Remove early optimization for cascading cases and update test cases accordingly
1 parent b032de7 commit 39861dd

File tree

3 files changed

+180
-60
lines changed

3 files changed

+180
-60
lines changed

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ struct MissingFeatures {
9797
// Unary operator handling
9898
static bool opUnaryPromotionType() { return false; }
9999

100+
// SwitchOp handling
101+
static bool foldRangeCase() { return false; }
102+
100103
// Clang early optimizations or things defered to LLVM lowering.
101104
static bool mayHaveIntegerOverflow() { return false; }
102105
static bool shouldReverseUnaryCondOnBoolExpr() { return false; }

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 28 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -429,52 +429,6 @@ mlir::LogicalResult CIRGenFunction::emitBreakStmt(const clang::BreakStmt &s) {
429429
return mlir::success();
430430
}
431431

432-
const CaseStmt *CIRGenFunction::foldCaseStmt(const clang::CaseStmt &s,
433-
mlir::Type condType,
434-
mlir::ArrayAttr &value,
435-
cir::CaseOpKind &kind) {
436-
const CaseStmt *caseStmt = &s;
437-
const CaseStmt *lastCase = &s;
438-
SmallVector<mlir::Attribute, 4> caseEltValueListAttr;
439-
440-
// Fold cascading cases whenever possible to simplify codegen a bit.
441-
while (caseStmt) {
442-
lastCase = caseStmt;
443-
444-
auto intVal = caseStmt->getLHS()->EvaluateKnownConstInt(getContext());
445-
446-
if (auto *rhs = caseStmt->getRHS()) {
447-
auto endVal = rhs->EvaluateKnownConstInt(getContext());
448-
SmallVector<mlir::Attribute, 4> rangeCaseAttr = {
449-
cir::IntAttr::get(condType, intVal),
450-
cir::IntAttr::get(condType, endVal)};
451-
value = builder.getArrayAttr(rangeCaseAttr);
452-
kind = cir::CaseOpKind::Range;
453-
454-
// We may not be able to fold rangaes. Due to we can't present range case
455-
// with other trivial cases now.
456-
return caseStmt;
457-
}
458-
459-
caseEltValueListAttr.push_back(cir::IntAttr::get(condType, intVal));
460-
461-
caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt());
462-
463-
// Break early if we found ranges. We can't fold ranges due to the same
464-
// reason above.
465-
if (caseStmt && caseStmt->getRHS())
466-
break;
467-
}
468-
469-
if (!caseEltValueListAttr.empty()) {
470-
value = builder.getArrayAttr(caseEltValueListAttr);
471-
kind = caseEltValueListAttr.size() > 1 ? cir::CaseOpKind::Anyof
472-
: cir::CaseOpKind::Equal;
473-
}
474-
475-
return lastCase;
476-
}
477-
478432
template <typename T>
479433
mlir::LogicalResult
480434
CIRGenFunction::emitCaseDefaultCascade(const T *stmt, mlir::Type condType,
@@ -502,7 +456,8 @@ CIRGenFunction::emitCaseDefaultCascade(const T *stmt, mlir::Type condType,
502456
if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
503457
subStmtKind = SubStmtKind::Default;
504458
builder.createYield(loc);
505-
} else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
459+
} else if ((isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) ||
460+
(isa<CaseStmt>(sub) && isa<CaseStmt>(stmt))) {
506461
subStmtKind = SubStmtKind::Case;
507462
builder.createYield(loc);
508463
} else {
@@ -564,8 +519,32 @@ mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
564519
bool buildingTopLevelCase) {
565520
cir::CaseOpKind kind;
566521
mlir::ArrayAttr value;
567-
const CaseStmt *caseStmt = foldCaseStmt(s, condType, value, kind);
568-
return emitCaseDefaultCascade(caseStmt, condType, value, kind,
522+
523+
SmallVector<mlir::Attribute, 1> caseEltValueListAttr;
524+
llvm::APSInt intVal = s.getLHS()->EvaluateKnownConstInt(getContext());
525+
526+
// If the case statement has an RHS value, it is representing a GNU
527+
// case range statement, where LHS is the beginning of the range
528+
// and RHS is the end of the range.
529+
if (const Expr *rhs = s.getRHS()) {
530+
531+
llvm::APSInt endVal = rhs->EvaluateKnownConstInt(getContext());
532+
SmallVector<mlir::Attribute, 4> rangeCaseAttr = {
533+
cir::IntAttr::get(condType, intVal),
534+
cir::IntAttr::get(condType, endVal)};
535+
value = builder.getArrayAttr(rangeCaseAttr);
536+
kind = cir::CaseOpKind::Range;
537+
538+
// We don't currently fold case range statements with other case statements.
539+
// TODO(cir): Add this capability.
540+
assert(!cir::MissingFeatures::foldRangeCase());
541+
} else {
542+
caseEltValueListAttr.push_back(cir::IntAttr::get(condType, intVal));
543+
value = builder.getArrayAttr(caseEltValueListAttr);
544+
kind = cir::CaseOpKind::Equal;
545+
}
546+
547+
return emitCaseDefaultCascade(&s, condType, value, kind,
569548
buildingTopLevelCase);
570549
}
571550

clang/test/CIR/CodeGen/switch.cpp

Lines changed: 149 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -207,12 +207,25 @@ void sw6(int a) {
207207

208208
// CIR: cir.func @_Z3sw6i
209209
// CIR: cir.switch (%1 : !s32i) {
210-
// CIR-NEXT: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i]) {
211-
// CIR-NEXT: cir.break
210+
// CIR-NEXT: cir.case(equal, [#cir.int<0> : !s32i]) {
211+
// CIR-NEXT: cir.yield
212212
// CIR-NEXT: }
213-
// CIR-NEXT: cir.case(anyof, [#cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
214-
// CIR-NEXT: cir.break
213+
// CIR-NEXT: cir.case(equal, [#cir.int<1> : !s32i]) {
214+
// CIR-NEXT: cir.yield
215+
// CIR-NEXT: }
216+
// CIR-NEXT: cir.case(equal, [#cir.int<2> : !s32i]) {
217+
// CIR-NEXT: cir.break
218+
// CIR-NEXT: }
219+
// CIR-NEXT: cir.case(equal, [#cir.int<3> : !s32i]) {
220+
// CIR-NEXT: cir.yield
221+
// CIR-NEXT: }
222+
// CIR-NEXT: cir.case(equal, [#cir.int<4> : !s32i]) {
223+
// CIR-NEXT: cir.yield
215224
// CIR-NEXT: }
225+
// CIR-NEXT: cir.case(equal, [#cir.int<5> : !s32i]) {
226+
// CIR-NEXT: cir.break
227+
// CIR-NEXT: }
228+
216229

217230
// OGCG: define dso_local void @_Z3sw6i
218231
// OGCG: entry:
@@ -248,13 +261,24 @@ void sw7(int a) {
248261
}
249262

250263
// CIR: cir.func @_Z3sw7i
251-
// CIR: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i]) {
252-
// CIR-NEXT: cir.yield
264+
// CIR: cir.case(equal, [#cir.int<0> : !s32i]) {
265+
// CIR-NEXT: cir.yield
253266
// CIR-NEXT: }
254-
// CIR-NEXT: cir.case(anyof, [#cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
255-
// CIR-NEXT: cir.break
267+
// CIR-NEXT: cir.case(equal, [#cir.int<1> : !s32i]) {
268+
// CIR-NEXT: cir.yield
269+
// CIR-NEXT: }
270+
// CIR-NEXT: cir.case(equal, [#cir.int<2> : !s32i]) {
271+
// CIR-NEXT: cir.yield
272+
// CIR-NEXT: }
273+
// CIR-NEXT: cir.case(equal, [#cir.int<3> : !s32i]) {
274+
// CIR-NEXT: cir.yield
275+
// CIR-NEXT: }
276+
// CIR-NEXT: cir.case(equal, [#cir.int<4> : !s32i]) {
277+
// CIR-NEXT: cir.yield
278+
// CIR-NEXT: }
279+
// CIR-NEXT: cir.case(equal, [#cir.int<5> : !s32i]) {
280+
// CIR-NEXT: cir.break
256281
// CIR-NEXT: }
257-
258282

259283
// OGCG: define dso_local void @_Z3sw7i
260284
// OGCG: entry:
@@ -419,13 +443,19 @@ void sw11(int a) {
419443
//CIR: cir.case(equal, [#cir.int<3> : !s32i]) {
420444
//CIR-NEXT: cir.break
421445
//CIR-NEXT: }
422-
//CIR-NEXT: cir.case(anyof, [#cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
446+
//CIR-NEXT: cir.case(equal, [#cir.int<4> : !s32i]) {
447+
//CIR-NEXT: cir.yield
448+
//CIR-NEXT: }
449+
//CIR-NEXT: cir.case(equal, [#cir.int<5> : !s32i]) {
423450
//CIR-NEXT: cir.yield
424451
//CIR-NEXT: }
425452
//CIR-NEXT: cir.case(default, []) {
426453
//CIR-NEXT: cir.yield
427454
//CIR-NEXT: }
428-
//CIR-NEXT: cir.case(anyof, [#cir.int<6> : !s32i, #cir.int<7> : !s32i]) {
455+
//CIR-NEXT: cir.case(equal, [#cir.int<6> : !s32i]) {
456+
//CIR-NEXT: cir.yield
457+
//CIR-NEXT: }
458+
//CIR-NEXT: cir.case(equal, [#cir.int<7> : !s32i]) {
429459
//CIR-NEXT: cir.break
430460
//CIR-NEXT: }
431461

@@ -527,6 +557,114 @@ void sw13(int a, int b) {
527557
// OGCG: [[EPILOG2]]:
528558
// OGCG: ret void
529559

560+
void sw14(int x) {
561+
switch (x) {
562+
case 1:
563+
case 2:
564+
case 3 ... 6:
565+
case 7:
566+
break;
567+
default:
568+
break;
569+
}
570+
}
571+
572+
// CIR: cir.func @_Z4sw14i
573+
// CIR: cir.switch
574+
// CIR-NEXT: cir.case(equal, [#cir.int<1> : !s32i]) {
575+
// CIR-NEXT: cir.yield
576+
// CIR-NEXT: }
577+
// CIR-NEXT: cir.case(equal, [#cir.int<2> : !s32i]) {
578+
// CIR-NEXT: cir.yield
579+
// CIR-NEXT: }
580+
// CIR-NEXT: cir.case(range, [#cir.int<3> : !s32i, #cir.int<6> : !s32i]) {
581+
// CIR-NEXT: cir.yield
582+
// CIR-NEXT: }
583+
// CIR-NEXT: cir.case(equal, [#cir.int<7> : !s32i]) {
584+
// CIR-NEXT: cir.break
585+
// CIR-NEXT: }
586+
// CIR-NEXT: cir.case(default, []) {
587+
// CIR-NEXT: cir.break
588+
// CIR-NEXT: }
589+
590+
// OGCG: define dso_local void @_Z4sw14i
591+
// OGCG: entry:
592+
// OGCG: %[[X_ADDR:.*]] = alloca i32, align 4
593+
// OGCG: store i32 %x, ptr %[[X_ADDR]], align 4
594+
// OGCG: %[[X_VAL:.*]] = load i32, ptr %[[X_ADDR]], align 4
595+
596+
// OGCG: switch i32 %[[X_VAL]], label %[[DEFAULT:.*]] [
597+
// OGCG-DAG: i32 1, label %[[BB1:.*]]
598+
// OGCG-DAG: i32 2, label %[[BB1]]
599+
// OGCG-DAG: i32 3, label %[[BB2:.*]]
600+
// OGCG-DAG: i32 4, label %[[BB2]]
601+
// OGCG-DAG: i32 5, label %[[BB2]]
602+
// OGCG-DAG: i32 6, label %[[BB2]]
603+
// OGCG-DAG: i32 7, label %[[BB3:.*]]
604+
// OGCG: ]
605+
// OGCG: [[BB1]]:
606+
// OGCG: br label %[[BB2]]
607+
// OGCG: [[BB2]]:
608+
// OGCG: br label %[[BB3]]
609+
// OGCG: [[BB3]]:
610+
// OGCG: br label %[[EPILOG:.*]]
611+
// OGCG: [[DEFAULT]]:
612+
// OGCG: br label %[[EPILOG]]
613+
// OGCG: [[EPILOG]]:
614+
// OGCG: ret void
615+
616+
void sw15(int x) {
617+
int y;
618+
switch (x) {
619+
case 1:
620+
case 2:
621+
y = 0;
622+
case 3:
623+
break;
624+
default:
625+
break;
626+
}
627+
}
628+
629+
// CIR: cir.func @_Z4sw15i
630+
// CIR: %[[Y:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["y"]
631+
// CIR: cir.switch
632+
// CIR-NEXT: cir.case(equal, [#cir.int<1> : !s32i]) {
633+
// CIR-NEXT: cir.yield
634+
// CIR-NEXT: }
635+
// CIR-NEXT: cir.case(equal, [#cir.int<2> : !s32i]) {
636+
// CIR-NEXT: %[[ZERO:.*]] = cir.const #cir.int<0> : !s32i
637+
// CIR-NEXT: cir.store %[[ZERO]], %[[Y]] : !s32i, !cir.ptr<!s32i>
638+
// CIR-NEXT: cir.yield
639+
// CIR-NEXT: }
640+
// CIR-NEXT: cir.case(equal, [#cir.int<3> : !s32i]) {
641+
// CIR-NEXT: cir.break
642+
// CIR-NEXT: }
643+
// CIR-NEXT: cir.case(default, []) {
644+
// CIR-NEXT: cir.break
645+
// CIR-NEXT: }
646+
647+
// OGCG: define dso_local void @_Z4sw15i
648+
// OGCG: entry:
649+
// OGCG: %[[X_ADDR:.*]] = alloca i32, align 4
650+
// OGCG: %[[Y:.*]] = alloca i32, align 4
651+
// OGCG: store i32 %x, ptr %[[X_ADDR]], align 4
652+
// OGCG: %[[X_VAL:.*]] = load i32, ptr %[[X_ADDR]], align 4
653+
// OGCG: switch i32 %[[X_VAL]], label %[[DEFAULT:.*]] [
654+
// OGCG-DAG: i32 1, label %[[BB0:.*]]
655+
// OGCG-DAG: i32 2, label %[[BB0]]
656+
// OGCG-DAG: i32 3, label %[[BB1:.*]]
657+
// OGCG: ]
658+
// OGCG: [[BB0]]:
659+
// OGCG: store i32 0, ptr %[[Y]], align 4
660+
// OGCG: br label %[[BB1]]
661+
// OGCG: [[BB1]]:
662+
// OGCG: br label %[[EPILOG:.*]]
663+
// OGCG: [[DEFAULT]]:
664+
// OGCG: br label %[[EPILOG]]
665+
// OGCG: [[EPILOG]]:
666+
// OGCG: ret void
667+
530668
int nested_switch(int a) {
531669
switch (int b = 1; a) {
532670
case 0:

0 commit comments

Comments
 (0)