Skip to content

Commit 0aca3e6

Browse files
committed
[feat] Added count property to WorkflowSearchResponse
1 parent 89faa8b commit 0aca3e6

File tree

4 files changed

+29
-30
lines changed

4 files changed

+29
-30
lines changed

pyatlan/model/workflow.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,10 @@ def __init__(self, **data: Any):
234234
self._size = data.get("size") # type: ignore[assignment]
235235
self._start = data.get("start") # type: ignore[assignment]
236236

237+
@property
238+
def count(self):
239+
return self.hits.total.get("value", 0) if self.hits and self.hits.total else 0
240+
237241
def current_page(self) -> Optional[List[WorkflowSearchResult]]:
238242
return self.hits.hits # type: ignore
239243

tests/integration/test_workflow_client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,10 @@ def test_workflow_get_runs_and_stop(client: AtlanClient, workflow: WorkflowRespo
134134
runs = client.workflow.get_runs(
135135
workflow_name=workflow.metadata.name, workflow_phase=AtlanWorkflowPhase.RUNNING
136136
)
137-
assert runs
138-
total_runs = runs.hits.hits # type: ignore
139-
assert len(total_runs) == 1 # type: ignore
140-
run = total_runs[0] # type: ignore
137+
assert runs and runs.count == 1
138+
current_page = runs.current_page()
139+
assert current_page is not None and len(current_page) == 1
140+
run = current_page[0]
141141
assert run and run.id
142142
assert workflow.metadata.name and (workflow.metadata.name in run.id)
143143

@@ -178,7 +178,7 @@ def test_workflow_get_runs_and_stop(client: AtlanClient, workflow: WorkflowRespo
178178
[AtlanWorkflowPhase.FAILED], started_at="now-1h"
179179
)
180180
assert runs_status
181-
workflow_run_status = runs_status.hits.hits[0] # type: ignore
181+
workflow_run_status = runs_status.current_page()[0] # type: ignore
182182
start_time = workflow_run_status.source.status.startedAt # type: ignore
183183
start_datetime = datetime.strptime(start_time, "%Y-%m-%dT%H:%M:%SZ") # type: ignore
184184
start_datetime = start_datetime.replace(tzinfo=timezone.utc)

tests/unit/test_client.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2563,9 +2563,9 @@ def test_get_all_pagation(group_client, mock_api_caller):
25632563
]
25642564

25652565
groups = group_client.get_all(limit=2)
2566-
assert len(groups.records) == 2
2567-
assert groups.records[0].id == "1"
2568-
assert groups.records[1].id == "2"
2566+
assert len(groups.current_page()) == 2
2567+
assert groups.current_page()[0].id == "1"
2568+
assert groups.current_page()[1].id == "2"
25692569
assert mock_api_caller._call_api.call_count == 1
25702570
mock_api_caller.reset_mock()
25712571

@@ -2577,7 +2577,7 @@ def test_get_all_empty_response_with_raw_records(group_client, mock_api_caller):
25772577
]
25782578

25792579
groups = group_client.get_all()
2580-
assert len(groups.records) == 0
2580+
assert len(groups.current_page()) == 0
25812581
mock_api_caller.reset_mock()
25822582

25832583

@@ -2593,9 +2593,9 @@ def test_get_all_with_columns(group_client, mock_api_caller):
25932593
columns = ["alias"]
25942594
groups = group_client.get_all(limit=10, columns=columns)
25952595

2596-
assert len(groups.records) == 2
2597-
assert groups.records[0].id == "1"
2598-
assert groups.records[0].alias == "Group1"
2596+
assert len(groups.current_page()) == 2
2597+
assert groups.current_page()[0].id == "1"
2598+
assert groups.current_page()[0].alias == "Group1"
25992599
mock_api_caller._call_api.assert_called_once()
26002600
query_params = mock_api_caller._call_api.call_args.kwargs["query_params"]
26012601
assert query_params["columns"] == columns
@@ -2613,9 +2613,9 @@ def test_get_all_sorting(group_client, mock_api_caller):
26132613

26142614
groups = group_client.get_all(limit=10, sort="alias")
26152615

2616-
assert len(groups.records) == 2
2617-
assert groups.records[0].id == "1"
2618-
assert groups.records[0].alias == "Group1"
2616+
assert len(groups.current_page()) == 2
2617+
assert groups.current_page()[0].id == "1"
2618+
assert groups.current_page()[0].alias == "Group1"
26192619
mock_api_caller._call_api.assert_called_once()
26202620
query_params = mock_api_caller._call_api.call_args.kwargs["query_params"]
26212621
assert query_params["sort"] == "alias"

tests/unit/test_workflow_client.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -273,23 +273,19 @@ def test_find_by_type(client: WorkflowClient, mock_api_caller):
273273

274274

275275
def test_find_runs_by_status_and_time_range(client: WorkflowClient, mock_api_caller):
276-
raw_json = {"shards": {"dummy": None}, "hits": {"total": {"dummy": None}}}
276+
raw_json = {"_shards": {"dummy": None}, "hits": {"total": {"dummy": None}}}
277277
mock_api_caller._call_api.return_value = raw_json
278278

279279
status = [AtlanWorkflowPhase.SUCCESS, AtlanWorkflowPhase.FAILED]
280280
started_at = "now-2h"
281281
finished_at = "now-1h"
282-
283-
assert (
284-
client.find_runs_by_status_and_time_range(
285-
status=status,
286-
started_at=started_at,
287-
finished_at=finished_at,
288-
from_=10,
289-
size=5,
290-
)
291-
== []
292-
)
282+
assert client.find_runs_by_status_and_time_range(
283+
status=status,
284+
started_at=started_at,
285+
finished_at=finished_at,
286+
from_=10,
287+
size=5,
288+
) == WorkflowSearchResponse(**raw_json)
293289
mock_api_caller._call_api.assert_called_once()
294290
assert isinstance(
295291
mock_api_caller._call_api.call_args.kwargs["request_obj"], WorkflowSearchRequest
@@ -556,14 +552,13 @@ def test_workflow_get_runs(
556552
mock_api_caller,
557553
search_response: WorkflowSearchResponse,
558554
):
559-
mock_api_caller._call_api.return_value = search_response.dict()
555+
mock_api_caller._call_api.return_value = search_response.dict(by_alias=True)
560556
response = client.get_runs(
561557
workflow_name="test-workflow",
562558
workflow_phase=AtlanWorkflowPhase.RUNNING,
563559
)
564560

565-
assert search_response.hits
566-
assert response == search_response.hits.hits
561+
assert response == WorkflowSearchResponse(**search_response.dict())
567562
assert mock_api_caller._call_api.call_count == 1
568563
mock_api_caller.reset_mock()
569564

0 commit comments

Comments
 (0)