Skip to content
This repository was archived by the owner on Jan 19, 2025. It is now read-only.

Commit 167d25c

Browse files
feat: creation of a differ abstract base class (#1116)
Closes #1113. ### Summary of Changes: Added an abstract differ base class called `AbstractDiffer` in the new package `package_parser.migration` with the methods: - `compute_class_similarity` - `compute_attribute_similarity` - `compute_function_similarity` - `compute_parameter_similarity` - `compute_result_similarity`. Each method should return a value between 0 and 1; the closer the returned value is to 1, the more similar the two compared elements are. ### Testing Instructions: Not (yet) used as an abstract class - see the rewritten class. Co-authored-by: Aclrian <[email protected]> Co-authored-by: Lars Reimann <[email protected]>
1 parent 6c30002 commit 167d25c

File tree

6 files changed

+9509
-2517
lines changed

6 files changed

+9509
-2517
lines changed

data/api/sklearn__api.json

Lines changed: 9340 additions & 2498 deletions
Large diffs are not rendered by default.

package-parser/package_parser/processing/api/_ast_visitor.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import re
3-
from typing import Optional, Union
3+
from typing import Any, Optional, Union
44

55
import astroid
66
from astroid import NodeNG
@@ -13,12 +13,15 @@
1313
Function,
1414
Import,
1515
Module,
16+
NamedType,
17+
UnionType,
1618
)
1719
from package_parser.utils import parent_qualified_name
1820

1921
from ._file_filters import _is_init_file
2022
from ._get_parameter_list import get_parameter_list
2123
from .documentation_parsing import AbstractDocumentationParser
24+
from .model._api import Attribute
2225

2326

2427
def trim_code(code, from_line_no, to_line_no, encoding):
@@ -141,9 +144,53 @@ def leave_module(self, _: astroid.Module) -> None:
141144

142145
self.api.add_module(module)
143146

147+
@staticmethod
def get_type_of_attribute(infered_value: Any) -> Optional[str]:
    """Map an astroid inference result to a type name.

    Returns None when no type name can be derived: the value is
    Uninferable, a constant holding None, or a kind of node that is not
    handled here.
    """
    if infered_value == astroid.Uninferable:
        return None

    # Constants report the class name of their literal value; a literal
    # None yields no type name.
    if isinstance(infered_value, astroid.Const):
        literal = infered_value.value
        return None if literal is None else literal.__class__.__name__

    # Node kinds that map to a fixed type name, checked in this order.
    fixed_names = (
        (astroid.List, "list"),
        (astroid.Dict, "dict"),
        (astroid.ClassDef, "type"),
        (astroid.Tuple, "tuple"),
        ((astroid.FunctionDef, astroid.Lambda), "Callable"),
    )
    for node_kinds, type_name in fixed_names:
        if isinstance(infered_value, node_kinds):
            return type_name

    # Instances are named after the class they were created from.
    if isinstance(infered_value, astroid.Instance):
        return infered_value.name
    return None
168+
169+
@staticmethod
def get_instance_attributes(instance_attributes: dict[str, Any]) -> list[Attribute]:
    """Build typed Attribute entries from a class's instance attributes.

    For each attribute name, every assignment node that is an AssignAttr
    directly inside an Assign is inferred and mapped to a type name via
    ``get_type_of_attribute``. A name with exactly one distinct type name
    becomes an Attribute with a NamedType, a name with several becomes an
    Attribute with a UnionType, and a name with none is left out.

    :param instance_attributes: mapping from attribute name to the nodes
        that assign it (``class_node.instance_attrs`` at the call site).
    :return: the attributes for which at least one type name was found.
    """
    attributes = []
    for name, assignments in instance_attributes.items():
        types = set()
        for assignment in assignments:
            if not (
                isinstance(assignment, astroid.AssignAttr)
                and isinstance(assignment.parent, astroid.Assign)
            ):
                continue

            # The None default keeps next() from raising StopIteration when
            # inference produces no result at all.
            inferred = next(
                astroid.inference.infer_attribute(self=assignment), None
            )
            if inferred is None:
                continue

            attribute_type = _AstVisitor.get_type_of_attribute(inferred)
            if attribute_type is not None:
                types.add(attribute_type)

        # Sort so union members come out in the same order on every run;
        # iteration order of a set of strings is not stable across runs.
        type_names = sorted(types)
        if len(type_names) == 1:
            attributes.append(Attribute(name, NamedType(type_names[0])))
        elif len(type_names) > 1:
            attributes.append(
                Attribute(name, UnionType([NamedType(type_) for type_ in type_names]))
            )
    return attributes
190+
144191
def enter_classdef(self, class_node: astroid.ClassDef) -> None:
145192
qname = class_node.qname()
146-
instance_attributes = list(class_node.instance_attrs)
193+
instance_attributes = self.get_instance_attributes(class_node.instance_attrs)
147194

148195
decorators: Optional[astroid.Decorators] = class_node.decorators
149196
if decorators is not None:

package-parser/package_parser/processing/api/model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ._api import (
22
API,
33
API_SCHEMA_VERSION,
4+
Attribute,
45
Class,
56
FromImport,
67
Function,

package-parser/package_parser/processing/api/model/_api.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from ._documentation import ClassDocumentation, FunctionDocumentation
99
from ._parameters import Parameter
10+
from ._types import AbstractType
1011

1112
API_SCHEMA_VERSION = 1
1213

@@ -207,7 +208,10 @@ def from_json(json: Any) -> Class:
207208
full_docstring=json.get("docstring", ""),
208209
),
209210
json.get("code", ""),
210-
json.get("instance_attributes", []),
211+
[
212+
Attribute.from_json(instance_attribute)
213+
for instance_attribute in json.get("instance_attributes", [])
214+
],
211215
)
212216

213217
for method_id in json["methods"]:
@@ -225,7 +229,7 @@ def __init__(
225229
reexported_by: list[str],
226230
documentation: ClassDocumentation,
227231
code: str,
228-
instance_attributes: list[str],
232+
instance_attributes: list[Attribute],
229233
) -> None:
230234
self.id: str = id_
231235
self.qname: str = qname
@@ -258,10 +262,26 @@ def to_json(self) -> Any:
258262
"description": self.documentation.description,
259263
"docstring": self.documentation.full_docstring,
260264
"code": self.code,
261-
"instance_attributes": self.instance_attributes,
265+
"instance_attributes": [
266+
attribute.to_json() for attribute in self.instance_attributes
267+
],
262268
}
263269

264270

271+
@dataclass
class Attribute:
    """A named attribute together with an optional type."""

    name: str
    # None when the attribute has no type attached.
    types: Optional[AbstractType]

    def to_json(self) -> dict[str, Any]:
        """Serialize to a dict; a missing type is written as ""."""
        types_json = self.types.to_json() if self.types is not None else ""
        return {"name": self.name, "types": types_json}

    @staticmethod
    def from_json(json: Any) -> Attribute:
        """Rebuild an Attribute from the dict produced by ``to_json``.

        An empty or absent "types" entry is read back as None. ``to_json``
        writes "" for a missing type, and handing that to
        ``AbstractType.from_json`` would fail because it is not a mapping
        with a "kind" entry.

        :param json: mapping with a "name" entry and an optional "types" entry.
        :return: the reconstructed Attribute.
        """
        types_json = json.get("types")
        types = AbstractType.from_json(types_json) if types_json else None
        return Attribute(json["name"], types)
283+
284+
265285
@dataclass
266286
class Function:
267287
id: str

package-parser/package_parser/processing/api/model/_types.py

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,31 @@ class AbstractType(metaclass=ABCMeta):
1313
def to_json(self):
1414
pass
1515

16+
@classmethod
def from_json(cls, json: Any) -> Optional[AbstractType]:
    """Deserialize a type from its JSON form.

    Each concrete type is offered the payload in a fixed order and the
    first one that accepts it wins; None is returned when none of them
    does.
    """
    for concrete_type in (NamedType, EnumType, BoundaryType, UnionType):
        parsed: Optional[AbstractType] = concrete_type.from_json(json)
        if parsed is not None:
            return parsed
    return None
29+
1630

1731
@dataclass
1832
class NamedType(AbstractType):
1933
name: str
2034

35+
@classmethod
def from_json(cls, json: Any) -> Optional[NamedType]:
    """Build an instance from its JSON form.

    :param json: mapping with a "kind" entry and, for this type, a "name"
        entry.
    :return: the new instance, or None when "kind" is not this class's name.
    """
    # In a classmethod, cls is the class itself, so cls.__name__ is the
    # class name; cls.__class__ would be the metaclass instead.
    if json["kind"] != cls.__name__:
        return None
    return cls(json["name"])
40+
2141
@classmethod
2242
def from_string(cls, string: str) -> NamedType:
2343
return NamedType(string)
@@ -31,6 +51,12 @@ class EnumType(AbstractType):
3151
values: set[str] = field(default_factory=set)
3252
full_match: str = ""
3353

54+
@classmethod
def from_json(cls, json: Any) -> Optional[EnumType]:
    """Build an instance from its JSON form.

    :param json: mapping with a "kind" entry and, for this type, a "values"
        entry.
    :return: the new instance, or None when "kind" is not this class's name.
    """
    # In a classmethod, cls is the class itself, so cls.__name__ is the
    # class name; cls.__class__ would be the metaclass instead.
    if json["kind"] != cls.__name__:
        return None
    # The values field is declared as a set, while decoded JSON carries a
    # list, so convert explicitly.
    return cls(set(json["values"]))
59+
3460
@classmethod
3561
def from_string(cls, string: str) -> Optional[EnumType]:
3662
def remove_backslash(e: str):
@@ -88,12 +114,23 @@ class BoundaryType(AbstractType):
88114

89115
@classmethod
def _is_inclusive(cls, bracket: str) -> bool:
    """Tell whether an interval bracket includes its endpoint.

    Square brackets are inclusive and round brackets are exclusive; any
    other value raises an Exception.
    """
    square_brackets = ("[", "]")
    round_brackets = ("(", ")")
    if bracket in square_brackets:
        return True
    if bracket in round_brackets:
        return False
    raise Exception(f"{bracket} is not one of []()")
122+
123+
@classmethod
def from_json(cls, json: Any) -> Optional[BoundaryType]:
    """Build an instance from its JSON form.

    :param json: mapping with a "kind" entry and, for this type, the
        entries "base_type", "min", "max", "min_inclusive" and
        "max_inclusive".
    :return: the new instance, or None when "kind" is not this class's name.
    """
    # In a classmethod, cls is the class itself, so cls.__name__ is the
    # class name; cls.__class__ would be the metaclass instead.
    if json["kind"] != cls.__name__:
        return None
    return cls(
        json["base_type"],
        json["min"],
        json["max"],
        json["min_inclusive"],
        json["max_inclusive"],
    )
97134

98135
@classmethod
99136
def from_string(cls, string: str) -> Optional[BoundaryType]:
@@ -154,12 +191,8 @@ def __eq__(self, __o: object) -> bool:
154191
if eq:
155192
if self.max == BoundaryType.INFINITY:
156193
return True
157-
else:
158-
return self.max_inclusive == __o.max_inclusive
159-
else:
160-
return False
161-
else:
162-
return False
194+
return self.max_inclusive == __o.max_inclusive
195+
return False
163196

164197
def to_json(self) -> dict[str, Any]:
165198
return {
@@ -176,6 +209,17 @@ def to_json(self) -> dict[str, Any]:
176209
class UnionType(AbstractType):
177210
types: list[AbstractType]
178211

212+
@classmethod
def from_json(cls, json: Any) -> Optional[UnionType]:
    """Build an instance from its JSON form.

    Every element of "types" is deserialized through
    ``AbstractType.from_json``; elements that come back as None are
    dropped.

    :param json: mapping with a "kind" entry and, for this type, a "types"
        entry holding the serialized member types.
    :return: the new instance, or None when "kind" is not this class's name.
    """
    # In a classmethod, cls is the class itself, so cls.__name__ is the
    # class name; cls.__class__ would be the metaclass instead.
    if json["kind"] != cls.__name__:
        return None
    members = (AbstractType.from_json(element) for element in json["types"])
    return cls([member for member in members if member is not None])
222+
179223
def to_json(self) -> dict[str, Any]:
180224
type_list = []
181225
for t in self.types:
@@ -188,7 +232,7 @@ def create_type(
188232
parameter_documentation: ParameterDocumentation,
189233
) -> Optional[AbstractType]:
190234
type_string = parameter_documentation.type
191-
types: list[AbstractType] = list()
235+
types: list[AbstractType] = []
192236

193237
# Collapse whitespaces
194238
type_string = re.sub(r"\s+", " ", type_string)
@@ -248,7 +292,6 @@ def create_type(
248292

249293
if len(types) == 1:
250294
return types[0]
251-
elif len(types) == 0:
295+
if len(types) == 0:
252296
return None
253-
else:
254-
return UnionType(types)
297+
return UnionType(types)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from abc import ABC, abstractmethod
2+
3+
from package_parser.processing.api.model import (
4+
Attribute,
5+
Class,
6+
Function,
7+
Parameter,
8+
Result,
9+
)
10+
11+
12+
class AbstractDiffer(ABC):
    """Interface for computing how similar two API elements are.

    Every method returns a float. According to the change description that
    introduced this class, the value is expected to lie between 0 and 1,
    and the closer it is to 1, the more similar the two elements are.
    """

    @abstractmethod
    def compute_attribute_similarity(
        self,
        attributes_a: Attribute,
        attributes_b: Attribute,
    ) -> float:
        """Return the similarity of the two given attributes."""

    @abstractmethod
    def compute_class_similarity(
        self,
        class_a: Class,
        class_b: Class,
    ) -> float:
        """Return the similarity of the two given classes."""

    @abstractmethod
    def compute_function_similarity(
        self,
        function_a: Function,
        function_b: Function,
    ) -> float:
        """Return the similarity of the two given functions."""

    @abstractmethod
    def compute_parameter_similarity(
        self,
        parameter_a: Parameter,
        parameter_b: Parameter,
    ) -> float:
        """Return the similarity of the two given parameters."""

    @abstractmethod
    def compute_result_similarity(
        self,
        result_a: Result,
        result_b: Result,
    ) -> float:
        """Return the similarity of the two given results."""

0 commit comments

Comments
 (0)