33import com .contrastsecurity .sarif .Result ;
44import com .github .javaparser .Range ;
55import com .github .javaparser .ast .CompilationUnit ;
6+ import com .github .javaparser .ast .ImportDeclaration ;
67import com .github .javaparser .ast .Node ;
78import com .github .javaparser .ast .NodeList ;
89import com .github .javaparser .ast .body .FieldDeclaration ;
910import com .github .javaparser .ast .body .Parameter ;
1011import com .github .javaparser .ast .body .VariableDeclarator ;
1112import com .github .javaparser .ast .expr .AnnotationExpr ;
1213import com .github .javaparser .ast .expr .BinaryExpr ;
14+ import com .github .javaparser .ast .expr .ConditionalExpr ;
1315import com .github .javaparser .ast .expr .Expression ;
16+ import com .github .javaparser .ast .expr .FieldAccessExpr ;
17+ import com .github .javaparser .ast .expr .LiteralExpr ;
1418import com .github .javaparser .ast .expr .MethodCallExpr ;
1519import com .github .javaparser .ast .expr .Name ;
1620import com .github .javaparser .ast .expr .NameExpr ;
1721import com .github .javaparser .ast .expr .NullLiteralExpr ;
1822import com .github .javaparser .ast .expr .SimpleName ;
23+ import com .github .javaparser .ast .expr .StringLiteralExpr ;
1924import com .github .javaparser .ast .nodeTypes .NodeWithAnnotations ;
2025import com .github .javaparser .ast .nodeTypes .NodeWithSimpleName ;
2126import com .github .javaparser .resolution .UnsolvedSymbolException ;
@@ -68,16 +73,15 @@ public ChangesResult onResultFound(
6873 * This codemod will not be executed if:
6974 *
7075 * <ol>
71- * <li>Variable was previously initialized to a not null value
7276 * <li>Variable has a previous not null assertion
7377 * <li>Variable has a {@link @NotNull} or {@link @Nonnull} annotation
78+ * <li>Variable was previously initialized to a not null value
7479 * </ol>
7580 */
7681 if (simpleNameOptional .isPresent ()
77- && (isSimpleNameANotNullInitializedVariableDeclarator (
78- variableDeclarators , simpleNameOptional .get ())
79- || hasSimpleNameNotNullAnnotation (cu , simpleNameOptional .get (), variableDeclarators )
80- || hasSimpleNamePreviousNullAssertion (cu , simpleNameOptional .get ()))) {
82+ && (hasSimpleNameNotNullAnnotation (cu , simpleNameOptional .get (), variableDeclarators )
83+ || hasSimpleNamePreviousNullAssertion (cu , simpleNameOptional .get ())
84+ || isSimpleNameANotNullInitializedVariableDeclarator (cu , simpleNameOptional .get ()))) {
8185 return ChangesResult .noChanges ;
8286 }
8387
@@ -238,21 +242,20 @@ private boolean isNotNullOrNonnullAnnotation(final AnnotationExpr annotation) {
238242
239243 /**
240244 * Checks if the provided {@link SimpleName} variable corresponds to a {@link VariableDeclarator}
241- * that was previously initialized to a non-null value .
245+ * that was previously initialized to a non-null expression .
242246 */
243247 private boolean isSimpleNameANotNullInitializedVariableDeclarator (
244- final List <VariableDeclarator > variableDeclarators , final SimpleName targetName ) {
245-
246- return targetName != null
247- && variableDeclarators .stream ()
248- .filter (declarator -> declarator .getName ().equals (targetName ))
249- .filter (declarator -> isPreviousNodeBefore (targetName , declarator .getName ()))
250- .anyMatch (
251- declarator ->
252- declarator
253- .getInitializer ()
254- .map (expr -> !(expr instanceof NullLiteralExpr ))
255- .orElse (false ));
248+ final CompilationUnit cu , final SimpleName targetName ) {
249+
250+ final Optional <VariableDeclarator > variableDeclaratorOptional =
251+ getDeclaredVariable (cu , targetName );
252+
253+ if (variableDeclaratorOptional .isEmpty ()
254+ || variableDeclaratorOptional .get ().getInitializer ().isEmpty ()) {
255+ return false ;
256+ }
257+
258+ return isNullSafeExpression (cu , variableDeclaratorOptional .get ().getInitializer ().get ());
256259 }
257260
258261 /**
@@ -269,6 +272,163 @@ private Optional<SimpleName> getSimpleNameFromMethodCallExpr(
269272 return simpleNames .isEmpty () ? Optional .empty () : Optional .of (simpleNames .get (0 ));
270273 }
271274
275+ private boolean isNullSafeExpression (final CompilationUnit cu , final Expression expression ) {
276+ if (expression instanceof NullLiteralExpr ) {
277+ return false ;
278+ }
279+
280+ if (expression instanceof MethodCallExpr methodCallExpr ) {
281+ return isNullSafeMethodExpr (cu , methodCallExpr );
282+ }
283+
284+ if (expression instanceof ConditionalExpr conditionalExpr ) {
285+ return isNullSafeExpression (cu , conditionalExpr .getThenExpr ())
286+ && isNullSafeExpression (cu , conditionalExpr .getElseExpr ());
287+ }
288+
289+ if (expression instanceof NameExpr nameExpr ) {
290+ return isSimpleNameANotNullInitializedVariableDeclarator (cu , nameExpr .getName ());
291+ }
292+
293+ return expression instanceof LiteralExpr ;
294+ }
295+
296+ private boolean isNullSafeMethodExpr (
297+ final CompilationUnit cu , final MethodCallExpr methodCallExpr ) {
298+ final Optional <Expression > optionalScope = methodCallExpr .getScope ();
299+
300+ final String method = methodCallExpr .getName ().getIdentifier ();
301+
302+ // Static import case for example: import static
303+ // org.apache.commons.lang3.StringUtils.defaultString
304+ if (optionalScope .isEmpty ()) {
305+ return isNullSafeImportLibrary (cu , methodCallExpr .getName ().getIdentifier (), method );
306+ }
307+
308+ final Expression scope = optionalScope .get ();
309+
310+ // Using java.lang.String's method
311+ if (scope instanceof StringLiteralExpr ) {
312+ return commonMethodsThatCantReturnNull .contains ("java.lang.String#" .concat (method ));
313+ }
314+
315+ // Using full import name as scope of method, for example
316+ // String str = org.apache.commons.lang3.StringUtils.defaultString("")
317+ if (scope instanceof FieldAccessExpr fieldAccessExpr ) {
318+ final String fullImportName = fieldAccessExpr .toString ();
319+ return commonMethodsThatCantReturnNull .contains (fullImportName .concat ("#" ).concat (method ));
320+ }
321+
322+ if (scope instanceof NameExpr scopeName ) {
323+
324+ if (!isVariable (cu , scopeName )) {
325+ // check if scope is non-static import like: import org.apache.commons.lang3.StringUtils
326+ return isNullSafeImportLibrary (cu , scopeName .getName ().getIdentifier (), method );
327+ }
328+
329+ final Optional <VariableDeclarator > variableDeclaratorOptional =
330+ getDeclaredVariable (cu , scopeName .getName ());
331+
332+ if (variableDeclaratorOptional .isEmpty ()) {
333+ return false ;
334+ }
335+
336+ final String type = variableDeclaratorOptional .get ().getTypeAsString ();
337+
338+ // when scope is an object variable, check class type to determine if it is an implicit or
339+ // explicit import
340+ return isClassObjectMethodNullSafe (cu , type , method );
341+ }
342+
343+ return false ;
344+ }
345+
346+ /** Some basic java lang type classes */
347+ private boolean isClassObjectMethodNullSafe (
348+ final CompilationUnit cu , final String type , final String method ) {
349+ switch (type ) {
350+ case "String" -> {
351+ return commonMethodsThatCantReturnNull .contains ("java.lang.String#" .concat (method ));
352+ }
353+ case "Integer" -> {
354+ return commonMethodsThatCantReturnNull .contains ("java.lang.Integer#" .concat (method ));
355+ }
356+ case "Double" -> {
357+ return commonMethodsThatCantReturnNull .contains ("java.lang.Double#" .concat (method ));
358+ }
359+ case "Character" -> {
360+ return commonMethodsThatCantReturnNull .contains ("java.lang.Character#" .concat (method ));
361+ }
362+ case "Long" -> {
363+ return commonMethodsThatCantReturnNull .contains ("java.lang.Long#" .concat (method ));
364+ }
365+ default -> {
366+ return isNullSafeImportLibrary (cu , type , method );
367+ }
368+ }
369+ }
370+
371+ private boolean isVariable (final CompilationUnit cu , final NameExpr nameExpr ) {
372+ final SimpleName simpleName = nameExpr .getName ();
373+ final Optional <VariableDeclarator > variableDeclaratorOptional =
374+ getDeclaredVariable (cu , simpleName );
375+ return variableDeclaratorOptional .isPresent ();
376+ }
377+
378+ private boolean isNullSafeImportLibrary (
379+ final CompilationUnit cu , final String identifier , final String method ) {
380+ final Optional <ImportDeclaration > optionalImport =
381+ cu .getImports ().stream ()
382+ .filter (importName -> importName .getName ().getIdentifier ().equals (identifier ))
383+ .findFirst ();
384+
385+ if (optionalImport .isEmpty ()) {
386+ return false ;
387+ }
388+
389+ if (optionalImport .get ().isStatic ()
390+ && optionalImport .get ().getName ().getQualifier ().isEmpty ()) {
391+ return false ;
392+ }
393+
394+ final Name importDeclaration =
395+ optionalImport .get ().isStatic ()
396+ ? optionalImport .get ().getName ().getQualifier ().get ()
397+ : optionalImport .get ().getName ();
398+
399+ return commonMethodsThatCantReturnNull .contains (
400+ importDeclaration .asString ().concat ("#" ).concat (method ));
401+ }
402+
403+ private Optional <VariableDeclarator > getDeclaredVariable (
404+ final CompilationUnit cu , final SimpleName simpleName ) {
405+ final List <VariableDeclarator > variableDeclarators = cu .findAll (VariableDeclarator .class );
406+ return variableDeclarators .stream ()
407+ .filter (declarator -> declarator .getName ().equals (simpleName ))
408+ .filter (declarator -> isPreviousNodeBefore (simpleName , declarator .getName ()))
409+ .findFirst ();
410+ }
411+
272412 private static final Set <String > flippableComparisonMethods =
273413 Set .of ("equals" , "equalsIgnoreCase" );
414+
415+ private static final List <String > commonMethodsThatCantReturnNull =
416+ List .of (
417+ "org.apache.commons.lang3.StringUtils#defaultString" ,
418+ "java.lang.String#concat" ,
419+ "java.lang.String#replace" ,
420+ "java.lang.String#replaceAll" ,
421+ "java.lang.String#replaceFirst" ,
422+ "java.lang.String#join" ,
423+ "java.lang.String#substring" ,
424+ "java.lang.String#substring" ,
425+ "java.lang.String#toLowerCase" ,
426+ "java.lang.String#toUpperCase" ,
427+ "java.lang.String#trim" ,
428+ "java.lang.String#strip" ,
429+ "java.lang.String#stripLeading" ,
430+ "java.lang.String#stripTrailing" ,
431+ "java.lang.String#toString" ,
432+ "java.lang.String#valueOf" ,
433+ "java.lang.String#formatted" );
274434}
0 commit comments