Skip to content

Commit a0d033e

Browse files
author
Peter Hamfelt
committed
Add dataframe and series checkers
1 parent 7d3315a commit a0d033e

File tree

8 files changed

+114
-22
lines changed

8 files changed

+114
-22
lines changed

pylint_ml/checkers/pandas/pandas_dataframe_bool.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from pylint.checkers.utils import only_required_for_messages
1212
from pylint.interfaces import HIGH
1313

14+
# TODO: add the pandas version in which DataFrame.bool() was deprecated
15+
1416

1517
class PandasDataFrameBoolChecker(BaseChecker):
1618
name = "pandas-dataframe-bool"
@@ -26,13 +28,14 @@ class PandasDataFrameBoolChecker(BaseChecker):
2628
def visit_call(self, node: nodes.Call) -> None:
2729
if isinstance(node.func, nodes.Attribute):
2830
method_name = getattr(node.func, "attrname", None)
29-
module_name = getattr(node.func.expr, "name", None)
3031

31-
if method_name == "bool" and module_name == "pd":
32-
self.add_message("pandas-dataframe-bool", node=node, confidence=HIGH)
32+
if method_name == "bool":
33+
# Check if the object calling .bool() has a name starting with 'df_'
34+
object_name = getattr(node.func.expr, "name", None)
35+
if object_name and self._is_valid_dataframe_name(object_name):
36+
self.add_message("pandas-dataframe-bool", node=node, confidence=HIGH)
3337

34-
def _check_method_usage(self, node):
35-
method_name = getattr(node.func, "attrname", None)
36-
module_name = getattr(node.func.expr, "name", None)
37-
if method_name == "bool" and module_name == "pd":
38-
self.add_message("pandas-dataframe-bool", node=node, confidence=HIGH)
38+
@staticmethod
39+
def _is_valid_dataframe_name(name: str) -> bool:
40+
"""Check if the DataFrame name starts with 'df_'."""
41+
return name.startswith("df_")

pylint_ml/checkers/pandas/pandas_dataframe_naming.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,8 @@ def visit_assign(self, node: nodes.Assign) -> None:
2929
module_name = getattr(node.value.func.expr, "name", None)
3030

3131
if func_name == "DataFrame" and module_name == "pd":
32-
3332
for target in node.targets:
3433
if isinstance(target, nodes.AssignName):
3534
var_name = target.name
3635
if not var_name.startswith("df_") or len(var_name) <= 3:
3736
self.add_message("pandas-dataframe-naming", node=node, confidence=HIGH)
38-
39-
def _check_variable_name(self, var_name, node):
40-
if not var_name.startswith("df_") or len(var_name) <= 3:
41-
self.add_message("pandas-dataframe-naming", node=node, confidence=HIGH)

pylint_ml/checkers/pandas/pandas_series_bool.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from pylint.checkers.utils import only_required_for_messages
1212
from pylint.interfaces import HIGH
1313

14+
# TODO: add the pandas version in which Series.bool() was deprecated
15+
1416

1517
class PandasSeriesBoolChecker(BaseChecker):
1618
name = "pandas-series-bool"
@@ -26,13 +28,14 @@ class PandasSeriesBoolChecker(BaseChecker):
2628
def visit_call(self, node: nodes.Call) -> None:
2729
if isinstance(node.func, nodes.Attribute):
2830
method_name = getattr(node.func, "attrname", None)
29-
module_name = getattr(node.func.expr, "name", None)
3031

31-
if method_name == "bool" and module_name == "pd":
32-
self.add_message("pandas-series-bool", node=node, confidence=HIGH)
32+
if method_name == "bool":
33+
# Check if the object calling .bool() has a name starting with 'ser_'
34+
object_name = getattr(node.func.expr, "name", None)
35+
if object_name and self._is_valid_series_name(object_name):
36+
self.add_message("pandas-series-bool", node=node, confidence=HIGH)
3337

34-
def _check_method_usage(self, node):
35-
method_name = getattr(node.func, "attrname", None)
36-
module_name = getattr(node.func.expr, "name", None)
37-
if method_name == "bool" and module_name == "pd":
38-
self.add_message("pandas-series-bool", node=node, confidence=HIGH)
38+
@staticmethod
39+
def _is_valid_series_name(name: str) -> bool:
40+
"""Check if the Series name starts with 'ser_'."""
41+
return name.startswith("ser_")
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for consistent naming of pandas Series variables."""
6+
7+
from __future__ import annotations
8+
9+
from astroid import nodes
10+
from pylint.checkers import BaseChecker
11+
from pylint.checkers.utils import only_required_for_messages
12+
from pylint.interfaces import HIGH
13+
14+
15+
class PandasSeriesNamingChecker(BaseChecker):
    """Flag ``pd.Series(...)`` assignments whose target breaks the ``ser_`` naming rule.

    A target name is accepted only when it starts with ``ser_`` and has at
    least one more character after that prefix.
    """

    name = "pandas-series-naming"
    msgs = {
        "W8103": (
            "Pandas Series variable names should start with 'ser_' followed by descriptive text",
            "pandas-series-naming",
            "Ensure that pandas Series variables follow the naming convention.",
        ),
    }

    @only_required_for_messages("pandas-series-naming")
    def visit_assign(self, node: nodes.Assign) -> None:
        """Report ``pd.Series(...)`` assignments to badly named variables.

        Only calls written as an attribute access on a name literally spelled
        ``pd`` are recognised; other import aliases are not resolved.

        Args:
            node: The assignment statement being visited.
        """
        call = node.value
        if not isinstance(call, nodes.Call):
            return

        func = call.func
        # A plain call such as ``foo()`` has a ``nodes.Name`` func, which has no
        # ``expr`` attribute; bail out instead of raising AttributeError.
        if not isinstance(func, nodes.Attribute):
            return
        if func.attrname != "Series" or getattr(func.expr, "name", None) != "pd":
            return

        for target in node.targets:
            if not isinstance(target, nodes.AssignName):
                continue
            var_name = target.name
            # ``len <= 4`` rejects the bare prefix ``ser_`` with nothing after it.
            if not var_name.startswith("ser_") or len(var_name) <= 4:
                self.add_message("pandas-series-naming", node=node, confidence=HIGH)

tests/checkers/test_pandas/test_pandas_dataframe_bool.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ def test_dataframe_bool_usage(self):
2121
msg_id="pandas-dataframe-bool",
2222
confidence=HIGH,
2323
node=node,
24-
)
24+
),
25+
ignore_position=True,
2526
):
2627
self.checker.visit_call(node)
2728

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import astroid
2+
import pylint.testutils
3+
from pylint.interfaces import HIGH
4+
5+
from pylint_ml.checkers.pandas.pandas_series_naming import PandasSeriesNamingChecker
6+
7+
8+
class TestPandasSeriesNamingChecker(pylint.testutils.CheckerTestCase):
    """Exercise PandasSeriesNamingChecker on ``pd.Series`` assignments."""

    CHECKER_CLASS = PandasSeriesNamingChecker

    def _visit_expecting_no_message(self, assign_node):
        """Visit *assign_node* and require that the checker stays silent."""
        with self.assertNoMessages():
            self.checker.visit_assign(assign_node)

    def _visit_expecting_naming_message(self, assign_node):
        """Visit *assign_node* and require a pandas-series-naming message on it."""
        expected = pylint.testutils.MessageTest(
            msg_id="pandas-series-naming",
            confidence=HIGH,
            node=assign_node,
        )
        with self.assertAddsMessages(expected, ignore_position=True):
            self.checker.visit_assign(assign_node)

    def test_series_correct_naming(self):
        # A ``ser_`` prefix followed by descriptive text is accepted.
        assign_node = astroid.extract_node(
            """
        import pandas as pd
        ser_sales = pd.Series([100, 200, 300])
        """
        )
        self._visit_expecting_no_message(assign_node)

    def test_series_incorrect_naming(self):
        # A Series bound to a name with the wrong prefix is reported.
        assign_node = astroid.extract_node(
            """
        import pandas as pd
        df_sales = pd.Series([100, 200, 300])
        """
        )
        self._visit_expecting_naming_message(assign_node)

    def test_series_invalid_length_naming(self):
        # The bare prefix with nothing after it is reported as well.
        assign_node = astroid.extract_node(
            """
        import pandas as pd
        ser_ = pd.Series([True])
        """
        )
        self._visit_expecting_naming_message(assign_node)

0 commit comments

Comments
 (0)