-
Notifications
You must be signed in to change notification settings - Fork 56
Implement JobHandle and use it in WCC #1057
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,5 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import Any | ||
|
|
||
| from pandas import DataFrame | ||
|
|
@@ -13,6 +15,12 @@ | |
| ) | ||
| from graphdatascience.procedure_surface.api.default_values import ALL_LABELS, ALL_TYPES | ||
| from graphdatascience.procedure_surface.api.estimation_result import EstimationResult | ||
| from graphdatascience.procedure_surface.arrow.job_handle import ( | ||
| JobHandle, | ||
| NodePropertyMutateHandle, | ||
| NodePropertyStreamHandle, | ||
| NodePropertyWriteHandle, | ||
| ) | ||
| from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper | ||
|
|
||
|
|
||
|
|
@@ -171,6 +179,39 @@ def write( | |
|
|
||
| return WccWriteResult(**result) | ||
|
|
||
| def compute( | ||
| self, | ||
| G: GraphV2, | ||
| threshold: float = 0.0, | ||
| relationship_types: list[str] = ALL_TYPES, | ||
| node_labels: list[str] = ALL_LABELS, | ||
| sudo: bool = False, | ||
| log_progress: bool = True, | ||
| username: str | None = None, | ||
| concurrency: int | None = None, | ||
| job_id: str | None = None, | ||
| seed_property: str | None = None, | ||
| consecutive_ids: bool = False, | ||
| relationship_weight_property: str | None = None, | ||
| ) -> WccJobHandle: | ||
| config = self._node_property_endpoints.create_base_config( | ||
| G, | ||
| concurrency=concurrency, | ||
| consecutive_ids=consecutive_ids, | ||
| job_id=job_id, | ||
| log_progress=log_progress, | ||
| node_labels=node_labels, | ||
| relationship_types=relationship_types, | ||
| relationship_weight_property=relationship_weight_property, | ||
| seed_property=seed_property, | ||
| sudo=sudo, | ||
| threshold=threshold, | ||
| ) | ||
|
|
||
| job_id = self._node_property_endpoints.run_job("v2/community.wcc", config) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NIT: maybe we want to name the method |
||
|
|
||
| return WccJobHandle(self._node_property_endpoints._arrow_client, job_id) | ||
|
|
||
| def estimate( | ||
| self, | ||
| G: GraphV2 | dict[str, Any], | ||
|
|
@@ -192,3 +233,18 @@ def estimate( | |
| relationship_weight_property=relationship_weight_property, | ||
| ) | ||
| return self._node_property_endpoints.estimate("v2/community.wcc.estimate", G, config) | ||
|
|
||
|
|
||
| class WccJobHandle( | ||
| JobHandle[WccStatsResult], NodePropertyMutateHandle, NodePropertyWriteHandle, NodePropertyStreamHandle | ||
| ): | ||
| def __init__( | ||
| self, | ||
| arrow_client: AuthenticatedArrowClient, | ||
| job_id: str, | ||
| remote_write_back_client: RemoteWriteBackClient | None = None, | ||
| ): | ||
| super().__init__(arrow_client, job_id, remote_write_back_client=remote_write_back_client) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think due to multiple inheritance in this case only the init method of |
||
|
|
||
| def _parse_result(self, raw_result: dict[str, Any]) -> WccStatsResult: | ||
| return WccStatsResult(**raw_result) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| from abc import ABC, abstractmethod | ||
| from typing import Any, Generic, TypeVar | ||
|
|
||
| from pandas import DataFrame | ||
|
|
||
| from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient | ||
| from graphdatascience.arrow_client.v2.api_types import MutateResult, JobStatus | ||
| from graphdatascience.arrow_client.v2.job_client import JobClient | ||
| from graphdatascience.arrow_client.v2.mutation_client import MutationClient | ||
| from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient, WriteBackResult | ||
| from graphdatascience.procedure_surface.api.base_result import BaseResult | ||
| from graphdatascience.procedure_surface.api.catalog import GraphV2 | ||
| from graphdatascience.procedure_surface.arrow.endpoints_helper_base import EndpointsHelperBase | ||
|
|
||
| T = TypeVar("T", bound=BaseResult) | ||
|
|
||
|
|
||
| class JobHandle(ABC, Generic[T]): | ||
| def __init__(self, client: AuthenticatedArrowClient, job_id: str, **kwargs: Any): | ||
| self._client = client | ||
| self._job_id = job_id | ||
|
|
||
| def progress(self) -> float: | ||
| return JobClient.get_job_status(self._client, self._job_id).progress | ||
|
|
||
| def status(self) -> JobStatus: | ||
| return JobClient.get_job_status(self._client, self._job_id) | ||
|
|
||
| def result(self) -> T: | ||
| summary = JobClient.get_summary(self._client, self._job_id) | ||
| EndpointsHelperBase.drop_write_internals(summary) | ||
| return self._parse_result(summary) | ||
|
|
||
| def wait(self, show_progress: bool = False) -> None: | ||
| JobClient().wait_for_job(self._client, self._job_id, show_progress=show_progress) | ||
|
|
||
| @abstractmethod | ||
| def _parse_result(self, summary: dict[str, Any]) -> T: | ||
| """Parse the raw summary dictionary into the specific result type.""" | ||
| pass | ||
|
|
||
|
|
||
| class NodePropertyMutateHandle(ABC): | ||
| def __init__(self, client: AuthenticatedArrowClient, job_id: str, **kwargs: Any): | ||
| self._job_id = job_id | ||
| self._client = client | ||
|
|
||
| def mutate(self, mutate_property: str) -> MutateResult: | ||
| return MutationClient().mutate_node_property(self._client, self._job_id, mutate_property) | ||
|
|
||
|
|
||
| class NodePropertyWriteHandle(ABC): | ||
| def __init__( | ||
| self, | ||
| client: AuthenticatedArrowClient, | ||
| job_id: str, | ||
| remote_write_back_client: RemoteWriteBackClient | None = None, | ||
| **kwargs: Any, | ||
| ): | ||
| self._job_id = job_id | ||
| self._client = client | ||
| self._remote_write_back_client = remote_write_back_client | ||
|
|
||
| def write( | ||
| self, G: GraphV2, write_property: str, write_concurrency: int | None = None, log_progress: bool = True | ||
| ) -> WriteBackResult: | ||
|
|
||
| if self._remote_write_back_client is None: | ||
| raise Exception("Write back client is not initialized") | ||
|
|
||
| return self._remote_write_back_client.write( | ||
| G.name(), self._job_id, write_concurrency, write_property, None, log_progress | ||
| ) | ||
|
|
||
|
|
||
| class NodePropertyStreamHandle(ABC): | ||
| def __init__(self, client: AuthenticatedArrowClient, job_id: str, **kwargs: Any): | ||
| self._job_id = job_id | ||
| self._client = client | ||
|
|
||
| def stream(self, G: GraphV2) -> DataFrame: | ||
| return JobClient().stream_results(self._client, G.name(), self._job_id) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dont we need to check the status of the job before streaming? |
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And if user provides
dict[str, str]are keys also ignored?