diff --git a/src/main/java/org/openrewrite/java/migrate/guava/NoGuavaSetsNewHashSet.java b/src/main/java/org/openrewrite/java/migrate/guava/NoGuavaSetsNewHashSet.java index 45e3528ca0..7938c1b3eb 100644 --- a/src/main/java/org/openrewrite/java/migrate/guava/NoGuavaSetsNewHashSet.java +++ b/src/main/java/org/openrewrite/java/migrate/guava/NoGuavaSetsNewHashSet.java @@ -55,22 +55,33 @@ public TreeVisitor getVisitor() { @Override public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { if (NEW_HASH_SET.matches(method)) { - maybeRemoveImport("com.google.common.collect.Sets"); - maybeAddImport("java.util.HashSet"); if (method.getArguments().isEmpty() || method.getArguments().get(0) instanceof J.Empty) { + maybeRemoveImport("com.google.common.collect.Sets"); + maybeAddImport("java.util.HashSet"); return JavaTemplate.builder("new HashSet<>()") .contextSensitive() .imports("java.util.HashSet") .build() .apply(getCursor(), method.getCoordinates().replace()); } - if (method.getArguments().size() == 1 && TypeUtils.isAssignableTo("java.util.Collection", method.getArguments().get(0).getType())) { - return JavaTemplate.builder("new HashSet<>(#{any(java.util.Collection)})") - .contextSensitive() - .imports("java.util.HashSet") - .build() - .apply(getCursor(), method.getCoordinates().replace(), method.getArguments().get(0)); + if (method.getArguments().size() == 1) { + // Only handle if it's a Collection (not just any Iterable) + if (TypeUtils.isAssignableTo("java.util.Collection", method.getArguments().get(0).getType())) { + maybeRemoveImport("com.google.common.collect.Sets"); + maybeAddImport("java.util.HashSet"); + return JavaTemplate.builder("new HashSet<>(#{any(java.util.Collection)})") + .contextSensitive() + .imports("java.util.HashSet") + .build() + .apply(getCursor(), method.getCoordinates().replace(), method.getArguments().get(0)); + } + // Skip Iterable-only cases to avoid generating broken code + if (TypeUtils.isAssignableTo("java.lang.Iterable", method.getArguments().get(0).getType())) { + return method; + } } + maybeRemoveImport("com.google.common.collect.Sets"); + maybeAddImport("java.util.HashSet"); maybeAddImport("java.util.Arrays"); JavaTemplate newHashSetVarargs = JavaTemplate.builder("new HashSet<>(Arrays.asList(" + method.getArguments().stream().map(a -> "#{any()}").collect(joining(",")) + "))") .contextSensitive() diff --git a/src/test/java/org/openrewrite/java/migrate/guava/NoGuavaSetsNewHashSetTest.java b/src/test/java/org/openrewrite/java/migrate/guava/NoGuavaSetsNewHashSetTest.java index aed1a27630..99f8b21294 100644 --- a/src/test/java/org/openrewrite/java/migrate/guava/NoGuavaSetsNewHashSetTest.java +++ b/src/test/java/org/openrewrite/java/migrate/guava/NoGuavaSetsNewHashSetTest.java @@ -118,4 +118,78 @@ class Test { ) ); } + + @Test + void setsNewHashSetWithIterablesFilter() { + //language=java + rewriteRun( + java( + """ + import java.util.ArrayList; + import java.util.List; + + import com.google.common.collect.Iterables; + import com.google.common.collect.Sets; + + class Test { + void test() { + final List result = new ArrayList(); + List myExceptions = new ArrayList(); + result.addAll(Sets.newHashSet(Iterables.filter(myExceptions, ClassCastException.class))); + } + } + """ + ) + ); + } + + @Test + void setsNewHashSetWithCustomIterable() { + //language=java + rewriteRun( + java( + """ + import com.google.common.collect.Sets; + + class Test { + void test(Iterable myIterable) { + var result = Sets.newHashSet(myIterable); + } + } + """ + ) + ); + } + + @Test + void setsNewHashSetWithCollectionStillWorks() { + //language=java + rewriteRun( + java( + """ + import com.google.common.collect.Sets; + + import java.util.List; + import java.util.Set; + + class Test { + public static void test(List myList) { + Set result = Sets.newHashSet(myList); + } + } + """, + """ + import java.util.HashSet; + import java.util.List; + import java.util.Set; + + class Test { + public static void test(List myList) { + Set result = new HashSet<>(myList); + } + } + """ + ) + ); + } }