1616
1717package com .palantir .javaformat .java ;
1818
19+ import static java .lang .Math .max ;
20+ import static java .nio .charset .StandardCharsets .UTF_8 ;
21+
1922import com .google .common .base .CharMatcher ;
2023import com .google .common .collect .HashMultimap ;
24+ import com .google .common .collect .ImmutableList ;
25+ import com .google .common .collect .Iterables ;
2126import com .google .common .collect .Multimap ;
2227import com .google .common .collect .Range ;
2328import com .google .common .collect .RangeMap ;
3641import com .sun .source .util .TreePathScanner ;
3742import com .sun .source .util .TreeScanner ;
3843import com .sun .tools .javac .api .JavacTrees ;
44+ import com .sun .tools .javac .file .JavacFileManager ;
45+ import com .sun .tools .javac .parser .JavacParser ;
46+ import com .sun .tools .javac .parser .ParserFactory ;
3947import com .sun .tools .javac .tree .DCTree ;
4048import com .sun .tools .javac .tree .DCTree .DCReference ;
4149import com .sun .tools .javac .tree .JCTree ;
4250import com .sun .tools .javac .tree .JCTree .JCCompilationUnit ;
4351import com .sun .tools .javac .tree .JCTree .JCFieldAccess ;
44- import com .sun .tools .javac .tree .JCTree .JCIdent ;
4552import com .sun .tools .javac .tree .JCTree .JCImport ;
4653import com .sun .tools .javac .util .Context ;
54+ import com .sun .tools .javac .util .Log ;
4755import com .sun .tools .javac .util .Options ;
56+ import java .io .IOError ;
57+ import java .io .IOException ;
4858import java .lang .reflect .Method ;
59+ import java .net .URI ;
4960import java .util .LinkedHashSet ;
5061import java .util .List ;
5162import java .util .Map ;
5263import java .util .Set ;
64+ import javax .tools .Diagnostic ;
65+ import javax .tools .DiagnosticCollector ;
66+ import javax .tools .DiagnosticListener ;
67+ import javax .tools .JavaFileObject ;
68+ import javax .tools .SimpleJavaFileObject ;
69+ import javax .tools .StandardLocation ;
5370
5471/**
5572 * Removes unused imports from a source file. Imports that are only used in javadoc are also removed, and the references
@@ -76,15 +93,12 @@ public class RemoveUnusedImports {
7693 private static final class UnusedImportScanner extends TreePathScanner <Void , Void > {
7794
7895 private final Set <String > usedNames = new LinkedHashSet <>();
79-
8096 private final Multimap <String , Range <Integer >> usedInJavadoc = HashMultimap .create ();
81-
82- final JavacTrees trees ;
83- final DocTreeScanner docTreeSymbolScanner ;
97+ private final DocTreeScanner docTreeSymbolScanner = new DocTreeScanner ();
98+ private final JavacTrees trees ;
8499
85100 private UnusedImportScanner (JavacTrees trees ) {
86101 this .trees = trees ;
87- docTreeSymbolScanner = new DocTreeScanner ();
88102 }
89103
90104 /** Skip the imports themselves when checking for usage. */
@@ -202,21 +216,59 @@ public Void visitIdentifier(IdentifierTree node, Void aVoid) {
202216 }
203217 }
204218
205- public static String removeUnusedImports (final String contents ) throws FormatterException {
219+ public static String removeUnusedImports (final String contents ) {
206220 Context context = new Context ();
207221 JCCompilationUnit unit = parse (context , contents );
208- if (unit == null ) {
209- // error handling is done during formatting
210- return contents ;
211- }
212222 UnusedImportScanner scanner = new UnusedImportScanner (JavacTrees .instance (context ));
213223 scanner .scan (unit , null );
214- return applyReplacements (contents , buildReplacements (contents , unit , scanner .usedNames , scanner .usedInJavadoc ));
224+ String s = applyReplacements (
225+ contents , buildReplacements (contents , unit , scanner .usedNames , scanner .usedInJavadoc ));
226+
227+ // Normalize newlines while preserving important blank lines
228+ String sep = Newlines .guessLineSeparator (contents );
229+
230+ // Ensure exactly one blank line after package declaration
231+ s = s .replaceAll ("(?m)^(package .+)" + sep + "\\ s+" + sep , "$1" + sep + sep );
232+
233+ // Ensure exactly one blank line between last import and class declaration
234+ s = s .replaceAll ("(?m)^(import .+)" + sep + "\\ s+" + sep + "(?=class|interface|enum|record)" , "$1" + sep + sep );
235+
236+ // Remove multiple blank lines elsewhere in imports section
237+ s = s .replaceAll ("(?m)^(import .+)" + sep + "\\ s+" + sep + "(?=import)" , "$1" + sep );
238+
239+ return s ;
215240 }
216241
217- private static JCCompilationUnit parse (Context context , String javaInput ) throws FormatterException {
242+ private static JCCompilationUnit parse (Context context , String javaInput ) {
243+ DiagnosticCollector <JavaFileObject > diagnostics = new DiagnosticCollector <>();
244+ context .put (DiagnosticListener .class , diagnostics );
245+ Options .instance (context ).put ("--enable-preview" , "true" );
218246 Options .instance (context ).put ("allowStringFolding" , "false" );
219- return Formatter .parseJcCompilationUnit (context , javaInput );
247+ JCCompilationUnit unit ;
248+ try (JavacFileManager fileManager = new JavacFileManager (context , true , UTF_8 )) {
249+ fileManager .setLocation (StandardLocation .PLATFORM_CLASS_PATH , ImmutableList .of ());
250+ } catch (IOException e ) {
251+ // impossible
252+ throw new IOError (e );
253+ }
254+ SimpleJavaFileObject source = new SimpleJavaFileObject (URI .create ("source" ), JavaFileObject .Kind .SOURCE ) {
255+ @ Override
256+ public CharSequence getCharContent (boolean ignoreEncodingErrors ) {
257+ return javaInput ;
258+ }
259+ };
260+ Log .instance (context ).useSource (source );
261+ ParserFactory parserFactory = ParserFactory .instance (context );
262+ JavacParser parser = parserFactory .newParser (
263+ javaInput , /* keepDocComments= */ true , /* keepEndPos= */ true , /* keepLineMap= */ true );
264+ unit = parser .parseCompilationUnit ();
265+ unit .sourcefile = source ;
266+ Iterable <Diagnostic <? extends JavaFileObject >> errorDiagnostics =
267+ Iterables .filter (diagnostics .getDiagnostics (), Formatter ::errorDiagnostic );
268+ if (!Iterables .isEmpty (errorDiagnostics )) {
269+ // error handling is done during formatting
270+ }
271+ return unit ;
220272 }
221273
222274 /** Construct replacements to fix unused imports. */
@@ -226,53 +278,66 @@ private static RangeMap<Integer, String> buildReplacements(
226278 Set <String > usedNames ,
227279 Multimap <String , Range <Integer >> usedInJavadoc ) {
228280 RangeMap <Integer , String > replacements = TreeRangeMap .create ();
229- for (JCImport importTree : unit .getImports ()) {
281+ int size = unit .getImports ().size ();
282+ JCTree lastImport = size > 0 ? unit .getImports ().get (size - 1 ) : null ;
283+ for (JCTree importTree : unit .getImports ()) {
230284 String simpleName = getSimpleName (importTree );
231285 if (!isUnused (unit , usedNames , usedInJavadoc , importTree , simpleName )) {
232286 continue ;
233287 }
234288 // delete the import
235289 int endPosition = importTree .getEndPosition (unit .endPositions );
236- endPosition = Math . max (CharMatcher .isNot (' ' ).indexIn (contents , endPosition ), endPosition );
290+ endPosition = max (CharMatcher .isNot (' ' ).indexIn (contents , endPosition ), endPosition );
237291 String sep = Newlines .guessLineSeparator (contents );
292+
293+ // Check if there's an empty line after this import
294+ boolean hasEmptyLineAfter = false ;
295+ if (endPosition + sep .length () * 2 <= contents .length ()) {
296+ String nextTwoLines = contents .substring (endPosition , endPosition + sep .length () * 2 );
297+ hasEmptyLineAfter = nextTwoLines .equals (sep + sep );
298+ }
299+
238300 if (endPosition + sep .length () < contents .length ()
239301 && contents .subSequence (endPosition , endPosition + sep .length ())
240302 .toString ()
241303 .equals (sep )) {
242304 endPosition += sep .length ();
243305 }
306+
307+ // If this isn't the last import and there's an empty line after, preserve it
308+ if ((size == 1 || importTree != lastImport ) && !hasEmptyLineAfter ) {
309+ while (endPosition + sep .length () <= contents .length ()
310+ && contents .regionMatches (endPosition , sep , 0 , sep .length ())) {
311+ endPosition += sep .length ();
312+ }
313+ }
244314 replacements .put (Range .closedOpen (importTree .getStartPosition (), endPosition ), "" );
245315 }
246316 return replacements ;
247317 }
248318
249- private static String getSimpleName (ImportTree importTree ) {
250- return importTree .getQualifiedIdentifier () instanceof JCIdent
251- ? ((JCIdent ) importTree .getQualifiedIdentifier ()).getName ().toString ()
252- : ((JCFieldAccess ) importTree .getQualifiedIdentifier ())
253- .getIdentifier ()
254- .toString ();
319+ private static String getSimpleName (JCTree importTree ) {
320+ return getQualifiedIdentifier (importTree ).getIdentifier ().toString ();
255321 }
256322
257323 private static boolean isUnused (
258324 JCCompilationUnit unit ,
259325 Set <String > usedNames ,
260326 Multimap <String , Range <Integer >> usedInJavadoc ,
261- ImportTree importTree ,
327+ JCTree importTree ,
262328 String simpleName ) {
263- String qualifier = ((JCFieldAccess ) importTree .getQualifiedIdentifier ())
264- .getExpression ()
265- .toString ();
329+ JCFieldAccess qualifiedIdentifier = getQualifiedIdentifier (importTree );
330+ String qualifier = qualifiedIdentifier .getExpression ().toString ();
266331 if (qualifier .equals ("java.lang" )) {
267332 return true ;
268333 }
334+ if (usedNames .contains (simpleName )) {
335+ return false ;
336+ }
269337 if (unit .getPackageName () != null && unit .getPackageName ().toString ().equals (qualifier )) {
270338 return true ;
271339 }
272- if (importTree .getQualifiedIdentifier () instanceof JCFieldAccess
273- && ((JCFieldAccess ) importTree .getQualifiedIdentifier ())
274- .getIdentifier ()
275- .contentEquals ("*" )) {
340+ if (qualifiedIdentifier .getIdentifier ().contentEquals ("*" ) && !((JCImport ) importTree ).isStatic ()) {
276341 return false ;
277342 }
278343
@@ -285,6 +350,16 @@ private static boolean isUnused(
285350 return true ;
286351 }
287352
353+ private static JCFieldAccess getQualifiedIdentifier (JCTree importTree ) {
354+ // Use reflection because the return type is JCTree in some versions and JCFieldAccess in others
355+ try {
356+ return (JCFieldAccess )
357+ JCImport .class .getMethod ("getQualifiedIdentifier" ).invoke (importTree );
358+ } catch (ReflectiveOperationException e ) {
359+ throw new LinkageError (e .getMessage (), e );
360+ }
361+ }
362+
288363 /** Applies the replacements to the given source, and re-format any edited javadoc. */
289364 private static String applyReplacements (String source , RangeMap <Integer , String > replacements ) {
290365 // save non-empty fixed ranges for reformatting after fixes are applied
0 commit comments