Skip to content

Commit 20af7b2

Browse files
committed
in progress: actually check the type hints
1 parent fe2a973 commit 20af7b2

File tree

6 files changed

+127
-62
lines changed

6 files changed

+127
-62
lines changed

.github/workflows/type-check.yaml

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
name: type checking
2+
3+
on:
4+
pull_request:
5+
paths:
6+
- inheritance_explorer/**/*.py
7+
- pyproject.toml
8+
- requirements/typecheck.txt
9+
- .github/workflows/type-checking.yaml
10+
workflow_dispatch:
11+
12+
jobs:
13+
build:
14+
runs-on: ubuntu-latest
15+
name: type check
16+
timeout-minutes: 60
17+
18+
steps:
19+
- name: Checkout repo
20+
uses: actions/checkout@v4
21+
22+
- name: Set up Python
23+
uses: actions/setup-python@v5
24+
with:
25+
# run with oldest supported python version
26+
# so that we always get compatible versions of
27+
# core dependencies at type-check time
28+
python-version: '3.10'
29+
30+
- name: Build
31+
run: |
32+
python3 -m pip install --upgrade pip
33+
python3 -m pip install -r requirements/typecheck.txt
34+
35+
- name: list installed deps
36+
run: python -m pip list
37+
38+
- name: Run mypy
39+
run: mypy inheritance_explorer

inheritance_explorer/inheritance_explorer.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import collections
22
import inspect
33
import textwrap
4-
from typing import Any, List, Optional, Tuple, Union
4+
from typing import Any, Mapping, Optional, Tuple, Union
55

66
import matplotlib.pyplot as plt
77
import networkx as nx
88
import numpy as np
9+
import numpy.typing as npt
910
import pydot
1011
from matplotlib.axes import Axes
1112
from matplotlib.colors import rgb2hex
@@ -41,7 +42,7 @@ def child_id(self) -> str:
4142
return str(self._child_id)
4243

4344
@property
44-
def parent_id(self) -> str:
45+
def parent_id(self) -> str | None:
4546
if self._parent_id:
4647
return str(self._parent_id)
4748
return
@@ -76,11 +77,11 @@ def __init__(
7677
self,
7778
baseclass: Any,
7879
funcname: Optional[str] = None,
79-
default_color: Optional[str] = "#000000",
80-
func_override_color: Optional[str] = "#ff0000",
81-
similarity_cutoff: Optional[float] = 0.75,
82-
max_recursion_level: Optional[int] = 500,
83-
classes_to_exclude: Optional[List[str]] = None,
80+
default_color: str = "#000000",
81+
func_override_color: str = "#ff0000",
82+
similarity_cutoff: float = 0.75,
83+
max_recursion_level: int = 500,
84+
classes_to_exclude: Optional[list[str]] = None,
8485
):
8586

8687
self.baseclass = baseclass
@@ -90,20 +91,22 @@ def __init__(
9091
self._nodenum: int = 0
9192
self._node_list = [] # a list of unique ChildNodes
9293
self._node_map = {} # map of global node index to node name
93-
self._override_src = collections.OrderedDict()
94+
self._override_src: Mapping[int, str] = collections.OrderedDict()
9495
self._override_src_files = {}
9596
self._current_node = 1 # the current global node, must start at 1
9697
self._default_color = default_color
9798
self._override_color = func_override_color
9899
self._graphviz_args_kwargs = {}
99100
self.similarity_container = None
100-
self.similarity_results = None
101+
self.similarity_results: dict[str, npt.NDArray]
101102
self.similarity_cutoff = similarity_cutoff
102103
if classes_to_exclude is None:
103104
classes_to_exclude = []
104105
self.classes_to_exclude = classes_to_exclude
105106
self._build()
106-
self._node_map_r = {v: k for k, v in self._node_map.items()} # name to index
107+
self._node_map_r: Mapping[str, int] = {
108+
v: k for k, v in self._node_map.items()
109+
} # name to index
107110

108111
def _get_source_info(self, obj) -> Optional[str]:
109112
f = getattr(obj, self.funcname)
@@ -346,9 +349,9 @@ def plot_similarity(
346349
def build_interactive_graph(
347350
self,
348351
include_similarity: bool = True,
349-
node_style: dict = None,
350-
edge_style: dict = None,
351-
similarity_edge_style: dict = None,
352+
node_style: dict[str, Any] | None = None,
353+
edge_style: dict[str, Any] | None = None,
354+
similarity_edge_style: dict[str, Any] | None = None,
352355
override_node_color: Union[str, tuple] = None,
353356
**kwargs,
354357
) -> Network:
@@ -454,22 +457,22 @@ def build_interactive_graph(
454457
network_wrapper.from_nx(grph)
455458
return network_wrapper
456459

457-
def get_source_code(self, node: Union[str, int]) -> str:
460+
def get_source_code(self, node: int | str) -> str:
458461
"""
459462
retrieve the source code of the comparison function for a
460463
specified node
461464
462465
Parameters
463466
----------
464-
node: Union[str, int]
467+
node: int
465468
the node to fetch the source code for
466469
467470
Returns
468471
-------
469472
str
470473
a string containing the source code for the node.
471474
"""
472-
if node in self._override_src:
475+
if isinstance(node, int) and node in self._override_src:
473476
return self._override_src[node]
474477
elif isinstance(node, str) and node in self._node_map_r:
475478
node_id = self._node_map_r[node]
@@ -481,7 +484,9 @@ def get_source_code(self, node: Union[str, int]) -> str:
481484
)
482485
raise KeyError(f"Could not find node for {node}")
483486

484-
def get_multiple_source_code(self, node_1: Union[str, int], *args) -> dict:
487+
def get_multiple_source_code(
488+
self, node_1: str | int, *args
489+
) -> Mapping[str | int, str]:
485490
"""
486491
Retrieve the source code for multiple nodes
487492
@@ -515,11 +520,11 @@ def display_code_comparison(self):
515520
display_code_compare(self)
516521

517522

518-
def _validate_color(clr, default_rgb_tuple: tuple) -> str:
523+
def _validate_color(clr, default_rgb_tuple: tuple[float, float, float]) -> str:
519524
if clr is None:
520-
return rgb2hex(default_rgb_tuple)
525+
return str(rgb2hex(default_rgb_tuple))
521526
elif isinstance(clr, tuple):
522-
return rgb2hex(clr)
527+
return str(rgb2hex(clr))
523528
elif isinstance(clr, str):
524529
return clr
525530
msg = f"clr has unexpected type: {type(clr)}"

inheritance_explorer/similarity.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,39 @@
11
import abc
22
import collections
3-
from typing import Any, Dict, List, Optional, OrderedDict, Tuple
3+
from typing import Any, Optional, OrderedDict
44

55
import numpy as np
6+
import numpy.typing as npt
67
import pycode_similar
78

89

910
class ResultsContainer:
10-
def __init__(self, results_dict: dict):
11-
for ky, val in results_dict.items():
12-
setattr(self, ky, val)
11+
def __init__(
12+
self,
13+
count: int,
14+
total: int,
15+
similarity_fraction: float,
16+
base_class: Any,
17+
this_class: Any,
18+
):
19+
20+
self.count = count
21+
self.total = total
22+
self.similarity_fraction = similarity_fraction
23+
self.base_class = base_class
24+
self.this_class = this_class
25+
26+
27+
_sim_results_tuple = tuple[
28+
dict[int | str, OrderedDict[int | str, ResultsContainer]],
29+
npt.NDArray,
30+
tuple[int | str, ...],
31+
]
1332

1433

1534
class SimilarityContainer(abc.ABC):
1635

17-
_valid_methods: List[str] = ["permute", "reference"]
36+
_valid_methods: list[str] = ["permute", "reference"]
1837

1938
def __init__(self, method: str = "reference"):
2039
if method not in self._valid_methods:
@@ -24,7 +43,9 @@ def __init__(self, method: str = "reference"):
2443
self.method = method
2544
self.results = None # for storing results of similarity tests
2645

27-
def run(self, source_dict: OrderedDict[Any, str], reference: Optional[Any] = None):
46+
def run(
47+
self, source_dict: OrderedDict[int | str, str], reference: Optional[Any] = None
48+
):
2849
"""
2950
source_dict : dict
3051
dictionary mapping a node identifier to a source code string
@@ -44,7 +65,7 @@ def run(self, source_dict: OrderedDict[Any, str], reference: Optional[Any] = Non
4465
@abc.abstractmethod
4566
def _permute_and_run(
4667
self, source_dict: OrderedDict[Any, str]
47-
) -> Tuple[Dict, np.ndarray, tuple]:
68+
) -> _sim_results_tuple:
4869
pass
4970

5071
@abc.abstractmethod
@@ -56,8 +77,8 @@ def _compare_single_set(
5677

5778
class PycodeSimilarity(SimilarityContainer):
5879
def _compare_single_set(
59-
self, source_dict: OrderedDict[Any, str], reference: Any
60-
) -> OrderedDict[Any, ResultsContainer]:
80+
self, source_dict: OrderedDict[int | str, str], reference: Any
81+
) -> OrderedDict[str | int, ResultsContainer]:
6182

6283
src = source_dict[reference] # extract the reference
6384
# this will result in a self-comparison, but that is OK and makes some
@@ -68,24 +89,22 @@ def _compare_single_set(
6889
similarity = pycode_similar.detect(src_list)
6990

7091
results = collections.OrderedDict()
71-
for id, sim in zip(source_dict.keys(), similarity):
72-
results[id] = ResultsContainer(
73-
{
74-
"count": sim[1][0].plagiarism_count,
75-
"total": sim[1][0].total_count,
76-
"similarity_fraction": sim[1][0].plagiarism_percent,
77-
"base_class": reference,
78-
"this_class": id,
79-
}
92+
for class_id, sim in zip(source_dict.keys(), similarity):
93+
results[class_id] = ResultsContainer(
94+
count=sim[1][0].plagiarism_count,
95+
total=sim[1][0].total_count,
96+
similarity_fraction=sim[1][0].plagiarism_percent,
97+
base_class=reference,
98+
this_class=class_id,
8099
)
81100
return results
82101

83102
def _permute_and_run(
84-
self, source_dict: OrderedDict[Any, str]
85-
) -> Tuple[Dict, np.ndarray, tuple]:
103+
self, source_dict: OrderedDict[int | str, str]
104+
) -> _sim_results_tuple:
86105
N = len(source_dict)
87106
similarity_matrix = np.ones((N, N))
88-
results_by_ref = {}
107+
results_by_ref: dict[int | str, OrderedDict[int | str, ResultsContainer]] = {}
89108
sim_axis = tuple([i for i in source_dict.keys()])
90109
for iref, ref in enumerate(source_dict.keys()):
91110
results = self._compare_single_set(source_dict.copy(), ref)

inheritance_explorer/tests/test_similarity.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,7 @@
33
import numpy as np
44
import pytest
55

6-
from inheritance_explorer.similarity import PycodeSimilarity, ResultsContainer
7-
8-
9-
def test_results_container():
10-
11-
results = {"a": 1, "b": "c", "c": 3}
12-
r = ResultsContainer(results)
13-
for key, val in results.items():
14-
assert getattr(r, key) == val
6+
from inheritance_explorer.similarity import PycodeSimilarity
157

168

179
@pytest.fixture
@@ -41,25 +33,22 @@ def sample_source_dict():
4133
return source_dict, s_matrix.astype(bool)
4234

4335

44-
def check_result(results, source_dict, ref, expected_match):
45-
for key, _ in source_dict.items():
46-
if key != ref:
47-
assert key in results
48-
f = results[key]
49-
if expected_match[key]:
50-
assert f.similarity_fraction == 1.0
51-
else:
52-
assert f.similarity_fraction < 1.0
53-
54-
5536
def test_pycode_similarity_single_ref(sample_source_dict):
5637

5738
s_dict, s_matrix = sample_source_dict
5839
ref = "a"
5940
p = PycodeSimilarity()
6041
results = p.run(s_dict, reference=ref)
6142
s_bool = dict(zip(s_dict.keys(), s_matrix[:, 0]))
62-
check_result(results, s_dict, ref, expected_match=s_bool)
43+
44+
for key, _ in s_dict.items():
45+
if key != ref:
46+
assert key in results
47+
f = results[key]
48+
if s_bool[key]:
49+
assert f.similarity_fraction == 1.0
50+
else:
51+
assert f.similarity_fraction < 1.0
6352

6453

6554
def test_pycode_similarity_permuted(sample_source_dict):

pyproject.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,15 @@ docs = [
6767
"nbsphinx",
6868
]
6969

70+
[tool.mypy]
71+
files = ["inheritance_explorer",]
72+
python_version = "3.10"
73+
warn_unused_configs = true
74+
strict = true
75+
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
76+
warn_unreachable = true
77+
disallow_untyped_defs = false
78+
disallow_incomplete_defs = false
79+
implicit_optional = true
80+
disable_error_code = ["import-untyped", "import-not-found"]
81+
no_implicit_reexport = false

requirements/typecheck.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
mypy

0 commit comments

Comments
 (0)