1717
1818import lombok .AllArgsConstructor ;
1919import lombok .EqualsAndHashCode ;
20+ import lombok .RequiredArgsConstructor ;
21+ import lombok .Value ;
2022import org .openrewrite .ExecutionContext ;
2123import org .openrewrite .Option ;
22- import org .openrewrite .Recipe ;
24+ import org .openrewrite .ScanningRecipe ;
2325import org .openrewrite .TreeVisitor ;
2426import org .openrewrite .internal .ListUtils ;
2527import org .openrewrite .internal .lang .Nullable ;
2628import org .openrewrite .java .ChangeMethodAccessLevelVisitor ;
2729import org .openrewrite .java .JavaIsoVisitor ;
2830import org .openrewrite .java .MethodMatcher ;
29- import org .openrewrite .java .tree .*;
31+ import org .openrewrite .java .tree .Comment ;
32+ import org .openrewrite .java .tree .Flag ;
33+ import org .openrewrite .java .tree .J ;
34+ import org .openrewrite .java .tree .TypeUtils ;
3035
31- import java .util .ArrayList ;
32- import java .util .Collections ;
33- import java .util .List ;
34- import java .util .Set ;
36+ import java .util .*;
3537
3638@ AllArgsConstructor
3739@ EqualsAndHashCode (callSuper = false )
38- public class TestsShouldNotBePublic extends Recipe {
40+ public class TestsShouldNotBePublic extends ScanningRecipe < TestsShouldNotBePublic . Accumulator > {
3941
4042 @ Option (displayName = "Remove protected modifiers" ,
4143 description = "Also remove protected modifiers from test methods" ,
@@ -60,25 +62,46 @@ public Set<String> getTags() {
6062 }
6163
6264 @ Override
63- public TreeVisitor <?, ExecutionContext > getVisitor ( ) {
64- return new TestsNotPublicVisitor ( Boolean . TRUE . equals ( removeProtectedModifiers ) );
65+ public Accumulator getInitialValue ( ExecutionContext ctx ) {
66+ return new Accumulator ( );
6567 }
6668
69+ @ Override
70+ public TreeVisitor <?, ExecutionContext > getScanner (Accumulator acc ) {
71+ return new JavaIsoVisitor <ExecutionContext >() {
72+ @ Override
73+ public J .ClassDeclaration visitClassDeclaration (J .ClassDeclaration classDeclaration , ExecutionContext ctx ) {
74+ J .ClassDeclaration cd = super .visitClassDeclaration (classDeclaration , ctx );
75+ if (cd .getExtends () != null ) {
76+ acc .extendedClasses .add (String .valueOf (cd .getExtends ().getType ()));
77+ }
78+ return cd ;
79+ }
80+ };
81+ }
82+
83+ @ Override
84+ public TreeVisitor <?, ExecutionContext > getVisitor (Accumulator acc ) {
85+ return new TestsNotPublicVisitor (Boolean .TRUE .equals (removeProtectedModifiers ), acc );
86+ }
87+
88+ public static class Accumulator {
89+ Set <String > extendedClasses = new HashSet <>();
90+ }
91+
92+ @ RequiredArgsConstructor
6793 private static final class TestsNotPublicVisitor extends JavaIsoVisitor <ExecutionContext > {
6894 private final Boolean orProtected ;
69-
70- private TestsNotPublicVisitor (Boolean orProtected ) {
71- this .orProtected = orProtected ;
72- }
95+ private final Accumulator acc ;
7396
7497 @ Override
7598 public J .ClassDeclaration visitClassDeclaration (J .ClassDeclaration classDecl , ExecutionContext ctx ) {
7699 J .ClassDeclaration c = super .visitClassDeclaration (classDecl , ctx );
77100
78101 if (c .getKind () != J .ClassDeclaration .Kind .Type .Interface
79102 && c .getModifiers ().stream ().anyMatch (mod -> mod .getType () == J .Modifier .Type .Public )
80- && c .getModifiers ().stream ().noneMatch (mod -> mod .getType () == J .Modifier .Type .Abstract )) {
81-
103+ && c .getModifiers ().stream ().noneMatch (mod -> mod .getType () == J .Modifier .Type .Abstract )
104+ && ! acc . extendedClasses . contains ( String . valueOf ( c . getType ()))) {
82105 boolean hasTestMethods = c .getBody ().getStatements ().stream ()
83106 .filter (org .openrewrite .java .tree .J .MethodDeclaration .class ::isInstance )
84107 .map (J .MethodDeclaration .class ::cast )
0 commit comments