diff --git a/README.md b/README.md index c78d140..aad928e 100644 --- a/README.md +++ b/README.md @@ -41,11 +41,15 @@ Error codes | T201 | print found | | T203 | pprint found | | T204 | pprint declared | +| T205 | traceback print | Changes ------- +##### 5.1.0 - 2023-07-18 +* Add support for traceback.print_* functions + ##### 5.0.0 - 2022-04-30 * Move namespace from T0* to T2* to avoid collision with other library using same error code. diff --git a/flake8_print.py b/flake8_print.py index bb35b72..9a3a0b6 100644 --- a/flake8_print.py +++ b/flake8_print.py @@ -7,34 +7,66 @@ except ImportError: from flake8 import utils as stdin_utils -__version__ = "5.0.0" +__version__ = "5.1.0" PRINT_FUNCTION_NAME = "print" PPRINT_FUNCTION_NAME = "pprint" +TRACEBACK_FUNCTION_NAMES = ["print_tb", "print_exception", "print_exc", "print_last", "print_stack"] PRINT_FUNCTION_NAMES = [PRINT_FUNCTION_NAME, PPRINT_FUNCTION_NAME] VIOLATIONS = { - "found": {"print": "T201 print found.", "pprint": "T203 pprint found."}, + "found": { + "print": "T201 print found.", + "pprint": "T203 pprint found.", + }, "declared": {"print": "T202 Python 2.x reserved word print used.", "pprint": "T204 pprint declared"}, } +for func_name in TRACEBACK_FUNCTION_NAMES: + VIOLATIONS["found"][func_name] = "T205 traceback print found." + class PrintFinder(ast.NodeVisitor): def __init__(self, *args, **kwargs): super(PrintFinder, self).__init__(*args, **kwargs) self.prints_used = {} self.prints_redefined = {} + self.traceback_imports = [] + self.traceback_func_imports = {} + + def visit_Import(self, node: ast.Import): + for import_ in node.names: + if import_.name in ["traceback"] and hasattr(import_, "asname"): + self.traceback_imports.append(import_.asname) + + def visit_ImportFrom(self, node: ast.ImportFrom): + if node.module in ["traceback"]: + for import_ in node.names: + if import_.asname: + self.traceback_func_imports[import_.asname] = import_.name + else: + self.traceback_func_imports[import_.asname] = import_.name def visit_Call(self, node): - is_print_function = getattr(node.func, "id", None) in PRINT_FUNCTION_NAMES + is_print_function = ( + getattr(node.func, "id", None) in PRINT_FUNCTION_NAMES + or getattr(node.func, "id", None) in self.traceback_func_imports + ) is_print_function_attribute = ( - getattr(getattr(node.func, "value", None), "id", None) in PRINT_FUNCTION_NAMES - and getattr(node.func, "attr", None) in PRINT_FUNCTION_NAMES + getattr(getattr(node.func, "value", None), "id", None) + in PRINT_FUNCTION_NAMES + ["traceback"] + self.traceback_imports + and getattr(node.func, "attr", None) in PRINT_FUNCTION_NAMES + TRACEBACK_FUNCTION_NAMES ) if is_print_function: - self.prints_used[(node.lineno, node.col_offset)] = VIOLATIONS["found"][node.func.id] + func_name = node.func.id elif is_print_function_attribute: - self.prints_used[(node.lineno, node.col_offset)] = VIOLATIONS["found"][node.func.attr] + func_name = node.func.attr + else: + func_name = None + if func_name in self.traceback_func_imports: + func_name = self.traceback_func_imports[func_name] + if func_name is not None: + self.prints_used[(node.lineno, node.col_offset)] = VIOLATIONS["found"][func_name] self.generic_visit(node) def visit_FunctionDef(self, node): diff --git a/pyproject.toml b/pyproject.toml index b0eb423..dde7ddc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "flake8-print" -version = "5.0.0" +version = "5.1.0" description = "print statement checker plugin for flake8" readme = "README.md" diff --git a/test_linter.py b/test_linter.py index d1a07bc..515f772 100644 --- a/test_linter.py +++ b/test_linter.py @@ -56,6 +56,7 @@ def check_code_for_print_statements(code): T203 = "T203 pprint found." T202 = "T202 Python 2.x reserved word print used." T204 = "T204 pprint declared." +T205 = "T205 traceback print found." class TestGenericCases(object): @@ -91,6 +92,40 @@ def test_catches_print_invocation_in_lambda(self): assert result == [{"col": 14, "line": 1, "message": T201}] +class TestTracebackPrintCases(object): + prohibited_funcs = ["print_tb", "print_exception", "print_exc", "print_last", "print_stack"] + + def test_print_funcs(self): + for func in self.prohibited_funcs: + result = check_code_for_print_statements(f"import traceback; traceback.{func}()") + assert result == [{"col": 18, "line": 1, "message": T205}] + + def test_print_funcs_imported_from(self): + for func in self.prohibited_funcs: + result = check_code_for_print_statements(f"from traceback import {func}; {func}()") + assert result == [{"col": 24 + len(func), "line": 1, "message": T205}] + + def test_print_funcs_imported_from_as(self): + for func in self.prohibited_funcs: + result = check_code_for_print_statements(f"from traceback import {func} as pre; pre()") + assert result == [{"col": 31 + len(func), "line": 1, "message": T205}] + + def test_func_import_as(self): + for func in self.prohibited_funcs: + result = check_code_for_print_statements(f"import traceback as tb; tb.{func}()") + assert result == [{"col": 24, "line": 1, "message": T205}] + + +class TestTracebackPrintFalsePositiveCases(object): + def test_valid_import(self): + result = check_code_for_print_statements("import traceback") + assert result == [] + + def test_non_prohibited_func(self): + result = check_code_for_print_statements("import traceback; traceback.format_exception()") + assert result == [] + + class TestComments(object): def test_print_in_inline_comment_is_not_a_false_positive(self): result = check_code_for_print_statements("# what should I print ?")