Skip to content

Commit f2de0a4

Browse files
authored
Merge pull request #731 from shangyian/add-linked-nodes-client
2 parents c588f52 + 781eb7d commit f2de0a4

File tree

7 files changed

+135
-85
lines changed

7 files changed

+135
-85
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""
22
Version for Hatch
33
"""
4-
__version__ = "0.0.1a15"
4+
__version__ = "0.0.1a18"

datajunction-clients/python/datajunction/_internal.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,16 @@ def _set_column_attributes(
407407
)
408408
return response.json()
409409

410+
def _find_nodes_with_dimension(
411+
self,
412+
node_name,
413+
):
414+
"""
415+
Find all nodes with this dimension
416+
"""
417+
response = self._session.get(f"/dimensions/{node_name}/nodes/")
418+
return response.json()
419+
410420

411421
class ClientEntity(BaseModel):
412422
"""

datajunction-clients/python/datajunction/builder.py

Lines changed: 5 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -101,21 +101,6 @@ def restore_node(self, node_name: str) -> None:
101101
#
102102
# Nodes: SOURCE
103103
#
104-
def source(self, node_name: str) -> "Source":
105-
"""
106-
Retrieves a source node with that name if one exists.
107-
"""
108-
node_dict = self._verify_node_exists(
109-
node_name,
110-
type_=models.NodeType.SOURCE.value,
111-
)
112-
node = Source(
113-
**node_dict,
114-
dj_client=self,
115-
)
116-
node.primary_key = self._primary_key_from_columns(node_dict["columns"])
117-
return node
118-
119104
def create_source( # pylint: disable=too-many-arguments
120105
self,
121106
name: str,
@@ -145,6 +130,7 @@ def create_source( # pylint: disable=too-many-arguments
145130
columns=columns,
146131
)
147132
self._create_node(node=new_node, mode=mode)
133+
new_node.refresh()
148134
return new_node
149135

150136
def register_table(self, catalog: str, schema: str, table: str) -> Source:
@@ -162,21 +148,6 @@ def register_table(self, catalog: str, schema: str, table: str) -> Source:
162148
#
163149
# Nodes: TRANSFORM
164150
#
165-
def transform(self, node_name: str) -> "Transform":
166-
"""
167-
Retrieves a transform node with that name if one exists.
168-
"""
169-
node_dict = self._verify_node_exists(
170-
node_name,
171-
type_=models.NodeType.TRANSFORM.value,
172-
)
173-
node = Transform(
174-
**node_dict,
175-
dj_client=self,
176-
)
177-
node.primary_key = self._primary_key_from_columns(node_dict["columns"])
178-
return node
179-
180151
def create_transform( # pylint: disable=too-many-arguments
181152
self,
182153
name: str,
@@ -200,26 +171,12 @@ def create_transform( # pylint: disable=too-many-arguments
200171
query=query,
201172
)
202173
self._create_node(node=new_node, mode=mode)
174+
new_node.refresh()
203175
return new_node
204176

205177
#
206178
# Nodes: DIMENSION
207179
#
208-
def dimension(self, node_name: str) -> "Dimension":
209-
"""
210-
Retrieves a Dimension node with that name if one exists.
211-
"""
212-
node_dict = self._verify_node_exists(
213-
node_name,
214-
type_=models.NodeType.DIMENSION.value,
215-
)
216-
node = Dimension(
217-
**node_dict,
218-
dj_client=self,
219-
)
220-
node.primary_key = self._primary_key_from_columns(node_dict["columns"])
221-
return node
222-
223180
def create_dimension( # pylint: disable=too-many-arguments
224181
self,
225182
name: str,
@@ -243,26 +200,12 @@ def create_dimension( # pylint: disable=too-many-arguments
243200
query=query,
244201
)
245202
self._create_node(node=new_node, mode=mode)
203+
new_node.refresh()
246204
return new_node
247205

248206
#
249207
# Nodes: METRIC
250208
#
251-
def metric(self, node_name: str) -> "Metric":
252-
"""
253-
Retrieves a Metric node with that name if one exists.
254-
"""
255-
node_dict = self._verify_node_exists(
256-
node_name,
257-
type_=models.NodeType.METRIC.value,
258-
)
259-
node = Metric(
260-
**node_dict,
261-
dj_client=self,
262-
)
263-
node.primary_key = self._primary_key_from_columns(node_dict["columns"])
264-
return node
265-
266209
def create_metric( # pylint: disable=too-many-arguments
267210
self,
268211
name: str,
@@ -286,35 +229,12 @@ def create_metric( # pylint: disable=too-many-arguments
286229
query=query,
287230
)
288231
self._create_node(node=new_node, mode=mode)
232+
new_node.refresh()
289233
return new_node
290234

291235
#
292236
# Nodes: CUBE
293237
#
294-
def cube(self, node_name: str) -> "Cube": # pragma: no cover
295-
"""
296-
Retrieves a Cube node with that name if one exists.
297-
"""
298-
node_dict = self._get_cube(node_name)
299-
if "name" not in node_dict:
300-
raise DJClientException(f"Cube `{node_name}` does not exist")
301-
dimensions = [
302-
f'{col["node_name"]}.{col["name"]}'
303-
for col in node_dict["cube_elements"]
304-
if col["type"] != "metric"
305-
]
306-
metrics = [
307-
f'{col["node_name"]}.{col["name"]}'
308-
for col in node_dict["cube_elements"]
309-
if col["type"] == "metric"
310-
]
311-
return Cube(
312-
**node_dict,
313-
metrics=metrics,
314-
dimensions=dimensions,
315-
dj_client=self,
316-
)
317-
318238
def create_cube( # pylint: disable=too-many-arguments
319239
self,
320240
name: str,
@@ -338,4 +258,5 @@ def create_cube( # pylint: disable=too-many-arguments
338258
display_name=display_name,
339259
)
340260
self._create_node(node=new_node, mode=mode) # pragma: no cover
261+
new_node.refresh()
341262
return new_node # pragma: no cover

datajunction-clients/python/datajunction/client.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from datajunction import _internal, models
1010
from datajunction.exceptions import DJClientException
11+
from datajunction.nodes import Cube, Dimension, Metric, Source, Transform
1112

1213

1314
class DJClient(_internal.DJClient):
@@ -280,3 +281,88 @@ def list_engines(self) -> List[dict]:
280281
{"name": engine["name"], "version": engine["version"]}
281282
for engine in json_response
282283
]
284+
285+
# Read nodes
286+
def source(self, node_name: str) -> Source:
287+
"""
288+
Retrieves a source node with that name if one exists.
289+
"""
290+
node_dict = self._verify_node_exists(
291+
node_name,
292+
type_=models.NodeType.SOURCE.value,
293+
)
294+
node = Source(
295+
**node_dict,
296+
dj_client=self,
297+
)
298+
node.primary_key = self._primary_key_from_columns(node_dict["columns"])
299+
return node
300+
301+
def transform(self, node_name: str) -> Transform:
302+
"""
303+
Retrieves a transform node with that name if one exists.
304+
"""
305+
node_dict = self._verify_node_exists(
306+
node_name,
307+
type_=models.NodeType.TRANSFORM.value,
308+
)
309+
node = Transform(
310+
**node_dict,
311+
dj_client=self,
312+
)
313+
node.primary_key = self._primary_key_from_columns(node_dict["columns"])
314+
return node
315+
316+
def dimension(self, node_name: str) -> "Dimension":
317+
"""
318+
Retrieves a Dimension node with that name if one exists.
319+
"""
320+
node_dict = self._verify_node_exists(
321+
node_name,
322+
type_=models.NodeType.DIMENSION.value,
323+
)
324+
node = Dimension(
325+
**node_dict,
326+
dj_client=self,
327+
)
328+
node.primary_key = self._primary_key_from_columns(node_dict["columns"])
329+
return node
330+
331+
def metric(self, node_name: str) -> "Metric":
332+
"""
333+
Retrieves a Metric node with that name if one exists.
334+
"""
335+
node_dict = self._verify_node_exists(
336+
node_name,
337+
type_=models.NodeType.METRIC.value,
338+
)
339+
node = Metric(
340+
**node_dict,
341+
dj_client=self,
342+
)
343+
node.primary_key = self._primary_key_from_columns(node_dict["columns"])
344+
return node
345+
346+
def cube(self, node_name: str) -> "Cube": # pragma: no cover
347+
"""
348+
Retrieves a Cube node with that name if one exists.
349+
"""
350+
node_dict = self._get_cube(node_name)
351+
if "name" not in node_dict:
352+
raise DJClientException(f"Cube `{node_name}` does not exist")
353+
dimensions = [
354+
f'{col["node_name"]}.{col["name"]}'
355+
for col in node_dict["cube_elements"]
356+
if col["type"] != "metric"
357+
]
358+
metrics = [
359+
f'{col["node_name"]}.{col["name"]}'
360+
for col in node_dict["cube_elements"]
361+
if col["type"] == "metric"
362+
]
363+
return Cube(
364+
**node_dict,
365+
metrics=metrics,
366+
dimensions=dimensions,
367+
dj_client=self,
368+
)

datajunction-clients/python/datajunction/nodes.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,15 @@ class Dimension(NodeWithQuery):
327327
query: str
328328
columns: Optional[List[models.Column]]
329329

330+
def linked_nodes(self):
331+
"""
332+
Find all nodes linked to this dimension
333+
"""
334+
return [
335+
node["name"]
336+
for node in self.dj_client._find_nodes_with_dimension(self.name)
337+
]
338+
330339

331340
class Cube(Node): # pylint: disable=abstract-method
332341
"""

datajunction-clients/python/tests/test_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ def test_create_nodes(self, client): # pylint: disable=unused-argument
378378
mode=NodeMode.PUBLISHED,
379379
)
380380
assert account_type_dim.name == "default.account_type"
381+
assert len(account_type_dim.columns) == 3
381382
assert "default.account_type" in client.list_dimensions(namespace="default")
382383

383384
# transform nodes
@@ -395,6 +396,8 @@ def test_create_nodes(self, client): # pylint: disable=unused-argument
395396
"default.large_revenue_payments_only"
396397
in client.namespace("default").transforms()
397398
)
399+
assert len(large_revenue_payments_only.columns) == 4
400+
398401
client.transform("default.large_revenue_payments_only")
399402

400403
result = large_revenue_payments_only.add_materialization(

datajunction-clients/python/tests/test_client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,30 @@ def test_list_nodes(self, client):
280280
"foo.bar.repair_orders_thin",
281281
]
282282

283+
def test_find_nodes_with_dimension(self, client):
284+
"""
285+
Check that `dimension.linked_nodes()` works as expected.
286+
"""
287+
repair_order_dim = client.dimension("default.repair_order")
288+
assert repair_order_dim.linked_nodes() == [
289+
"default.repair_order_details",
290+
"default.avg_repair_price",
291+
"default.total_repair_cost",
292+
"default.total_repair_order_discounts",
293+
"default.avg_repair_order_discounts",
294+
]
295+
283296
#
284297
# Get common metrics and dimensions
285298
#
299+
def test_common_dimensions(self, client):
300+
"""
301+
Test that getting common dimensions for metrics works
302+
"""
303+
dims = client.common_dimensions(
304+
metrics=["default.num_repair_orders", "default.avg_repair_price"],
305+
)
306+
assert len(dims) == 8
286307

287308
#
288309
# SQL and data

0 commit comments

Comments
 (0)