Skip to content

Commit 05e98be

Browse files
author
oravidov
committed
Added typing and fixed unit tests
1 parent 6807d77 commit 05e98be

File tree

2 files changed

+27
-19
lines changed

2 files changed

+27
-19
lines changed

lineage/query_context.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from datetime import datetime
2-
from typing import Optional
2+
from typing import Optional, Union
33
import dateutil.parser
44

55

66
class QueryContext(object):
7-
def __init__(self, queried_database: str, queried_schema: str, query_time: Optional[datetime] = None,
8-
query_volume: Optional[int] = None, query_type: Optional[str] = None, user_name: Optional[str] = None,
9-
role_name: Optional[str] = None):
7+
def __init__(self, queried_database: Optional[str] = None, queried_schema: Optional[str] = None,
8+
query_time: Optional[datetime] = None, query_volume: Optional[int] = None,
9+
query_type: Optional[str] = None, user_name: Optional[str] = None,
10+
role_name: Optional[str] = None) -> None:
1011
self.queried_database = queried_database
1112
self.queried_schema = queried_schema
1213
self.query_time = query_time
@@ -15,32 +16,38 @@ def __init__(self, queried_database: str, queried_schema: str, query_time: Optio
1516
self.user_name = user_name
1617
self.role_name = role_name
1718

18-
def to_dict(self):
19-
query_time_str = self.query_time.isoformat() if self.query_time is not None else None
19+
def to_dict(self) -> dict:
2020
return {'queried_database': self.queried_database,
2121
'queried_schema': self.queried_schema,
22-
'query_time': query_time_str,
22+
'query_time': self._query_time_to_str(self.query_time),
2323
'query_volume': self.query_volume,
2424
'query_type': self.query_type,
2525
'user_name': self.user_name,
2626
'role_name': self.role_name}
2727

2828
@staticmethod
29-
def _html_param_with_default(param, default='unknown'):
29+
def _query_time_to_str(query_time: Optional[datetime], fmt: str = None) -> Optional[str]:
30+
if query_time is None:
31+
return None
32+
33+
if fmt is None:
34+
return query_time.isoformat()
35+
36+
return query_time.strftime(fmt)
37+
38+
@staticmethod
39+
def _html_param_with_default(param: Union[str, int], default: Union[str, int] = 'unknown') -> Union[str, int]:
3040
return default if param is None else param
3141

32-
def to_html(self):
42+
def to_html(self) -> str:
3343
query_type = self._html_param_with_default(self.query_type)
3444
user_name = self._html_param_with_default(self.user_name)
3545
role_name = self._html_param_with_default(self.role_name)
36-
query_time = self.query_time.strftime('%Y-%m-%d %H:%M:%S')
37-
38-
if self.query_volume is not None and self.query_volume > 0:
39-
volume_color = "DarkSlateGrey"
40-
query_volume = self.query_volume
41-
else:
46+
query_time = self._query_time_to_str(self.query_time, fmt='%Y-%m-%d %H:%M:%S')
47+
query_volume = self._html_param_with_default(self.query_volume, 0)
48+
volume_color = "DarkSlateGrey"
49+
if query_volume == 0:
4250
volume_color = "tomato"
43-
query_volume = 0
4451

4552
return f"""
4653
<html>
@@ -60,7 +67,7 @@ def to_html(self):
6067
"""
6168

6269
@staticmethod
63-
def from_dict(query_context_dict):
70+
def from_dict(query_context_dict: dict) -> 'QueryContext':
6471
if 'query_time' in query_context_dict and query_context_dict['query_time'] is not None:
6572
query_context_dict['query_time'] = dateutil.parser.parse(query_context_dict['query_time'])
6673
return QueryContext(**query_context_dict)

tests/test_lineage_graph.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,14 @@ def test_lineage_graph_add_nodes_and_edges(sources, targets, edges, show_isolate
7575
di_graph_mock = mock.create_autospec(nx.DiGraph)
7676
reference._lineage_graph = di_graph_mock
7777

78-
reference._add_nodes_and_edges(sources, targets)
78+
empty_query_context = QueryContext()
79+
reference._add_nodes_and_edges(sources, targets, empty_query_context)
7980

8081
node_calls = []
8182
if len(sources) > 0:
8283
node_calls.append(mock.call(sources))
8384
if len(targets) > 0:
84-
node_calls.append(mock.call(targets))
85+
node_calls.append(mock.call(targets, title=empty_query_context.to_html()))
8586

8687
edge_calls = []
8788
for edge in edges:

0 commit comments

Comments
 (0)