diff --git a/tests/rewrite_fn_ptr_eq/main.c b/tests/rewrite_fn_ptr_eq/main.c index f8fe55e222..9428abc7fd 100644 --- a/tests/rewrite_fn_ptr_eq/main.c +++ b/tests/rewrite_fn_ptr_eq/main.c @@ -27,7 +27,9 @@ Test(rewrite_fn_ptr_eq, main) { int res; int *y = &res; void *x = NULL; + // REWRITER: bin_op fn = IA2_FN(add); bin_op fn = add; + // REWRITER: bin_op fn2 = { NULL }; bin_op fn2 = NULL; // Check that pointers for types other than functions are not rewritten @@ -78,4 +80,37 @@ Test(rewrite_fn_ptr_eq, main) { if (y || !fn) { } // REWRITER: if (x && IA2_ADDR(fn) && y) { } if (x && fn && y) { } + + // REWRITER: fn = (typeof(fn)) { NULL }; + fn = NULL; + + // the following tests don't use NULL so the rewriter output shouldn't rely on it either +#undef NULL + // REWRITER: bin_op fn3 = { 0 }; + bin_op fn3 = 0; + + // REWRITER: bin_op fn4 = (typeof(fn)) { 0 }; + bin_op fn4 = (typeof(fn)) 0; + + // REWRITER: fn = (typeof(fn)) { 0 }; + fn = 0; + + // REWRITER: fn = (typeof(fn)) { 0 }; + fn = (typeof(fn)) 0; + + // check that literal zeroes aren't rewritten if not cast to function pointers + // REWRITER: res = 0; + res = 0; + + // REWRITER: if (IA2_ADDR(fn) == 0) { } + if (fn == 0) { } + + // REWRITER: if (IA2_ADDR(mod.fn) == 0) { } + if (mod.fn == 0) { } + + // REWRITER: if (IA2_ADDR(fn) == 0) { } + //if (fn == (typeof(fn)) 0) { } + + // REWRITER: if (IA2_ADDR(mod.fn) == 0) { } + //if (mod.fn == (typeof(fn)) 0) { } } diff --git a/tools/rewriter/SourceRewriter.cpp b/tools/rewriter/SourceRewriter.cpp index 4820d10045..8551da420c 100644 --- a/tools/rewriter/SourceRewriter.cpp +++ b/tools/rewriter/SourceRewriter.cpp @@ -404,8 +404,29 @@ class FnPtrNull : public RefactoringCallback { auto fn_ptr_typedef = hasType(typedefNameDecl(hasType(fn_ptr))); - auto null_expr = implicitCastExpr(ignoringParenCasts(nullPointerConstant())) - .bind("nullExpr"); + auto zero_literal = ignoringParenCasts(integerLiteral(equals(0)).bind("literalZero")); + auto null_macro = ignoringParenCasts(nullPointerConstant()); + auto null_value = anyOf(zero_literal, null_macro); + + auto null_expr = expr(anyOf( + implicitCastExpr(null_value), + cStyleCastExpr(zero_literal) + )).bind("nullExpr"); + //auto literal_zero = implicitCastExpr(ignoringParenCasts(integerLiteral(equals(0)))); + //auto cast_zero = cStyleCastExpr(ignoringParenCasts(integerLiteral(equals(0)))); + //auto null_macro = implicitCastExpr(cStyleCastExpr(integerLiteral(equals(0)))); + //auto null_expr = expr(anyOf(literal_zero, cast_zero, null_macro)).bind("nullExpr"); + //auto literal_zero = integerLiteral(equals(0)).bind("literalZero"); + //auto null_macro = nullPointerConstant(); + + //auto null_expr = expr(anyOf( + // implicitCastExpr(ignoringParenCasts(null_macro)), + // implicitCastExpr(zero_literal), + //implicitCastExpr(anyOf(zero_literal, null_macro)), + //cStyleCastExpr(ignoringParenCasts(anyOf(zero_literal, null_macro))))).bind("nullExpr"); + + //auto null_expr = implicitCastExpr(ignoringParenCasts(anyOf(null_macro, zero_literal))) + // .bind("nullExpr"); auto null_fn_ptr = varDecl(hasInitializer(null_expr), anyOf(fn_ptr_typedef, hasType(fn_ptr))); @@ -418,6 +439,7 @@ class FnPtrNull : public RefactoringCallback { refactorer.addMatcher(assign_null, this); } virtual void run(const MatchFinder::MatchResult &result) { + bool spelled_as_zero = result.Nodes.getNodeAs("literalZero"); // The two matchers both have a nullExpr node so this getNodeAs can't fail auto *null_fn_ptr = result.Nodes.getNodeAs("nullExpr"); assert(null_fn_ptr != nullptr); @@ -436,6 +458,9 @@ class FnPtrNull : public RefactoringCallback { Filename filename = get_expansion_filename(loc, sm); std::string new_expr = "{ NULL }"; + if (spelled_as_zero) { + new_expr = "{ 0 }"; + } // If the matcher found an assignment add the type of the LHS variable to // new_expr if (auto *lhs_ptr = result.Nodes.getNodeAs("ptrLHS")) { @@ -443,7 +468,7 @@ class FnPtrNull : public RefactoringCallback { clang::CharSourceRange::getTokenRange(lhs_ptr->getSourceRange()); auto lhs_binding = clang::Lexer::getSourceText(char_range, sm, ctxt.getLangOpts()); - new_expr = "(typeof("s + lhs_binding.str() + ")) { NULL }"; + new_expr = "(typeof("s + lhs_binding.str() + ")) " + new_expr; } clang::CharSourceRange expansion_range = sm.getExpansionRange(loc);