Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,14 @@ def write(
graph_name: str,
job_id: str,
concurrency: int | None = None,
property_overwrites: dict[str, str] | None = None,
property_overwrites: str | dict[str, str] | None = None,
relationship_type_overwrite: str | None = None,
log_progress: bool = True,
) -> WriteBackResult:
if isinstance(property_overwrites, str):
# The remote write back procedure allows specifying a single overwrite. The key is ignored.
Copy link
Contributor

@RafalSkolasinski RafalSkolasinski Jan 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And if the user provides a dict[str, str], are the keys also ignored?

property_overwrites = {property_overwrites: property_overwrites}

arrow_config = self._arrow_configuration()

configuration: dict[str, Any] = {}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any

from pandas import DataFrame
Expand All @@ -13,6 +15,12 @@
)
from graphdatascience.procedure_surface.api.default_values import ALL_LABELS, ALL_TYPES
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
from graphdatascience.procedure_surface.arrow.job_handle import (
JobHandle,
NodePropertyMutateHandle,
NodePropertyStreamHandle,
NodePropertyWriteHandle,
)
from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


Expand Down Expand Up @@ -171,6 +179,39 @@ def write(

return WccWriteResult(**result)

def compute(
    self,
    G: GraphV2,
    threshold: float = 0.0,
    relationship_types: list[str] = ALL_TYPES,
    node_labels: list[str] = ALL_LABELS,
    sudo: bool = False,
    log_progress: bool = True,
    username: str | None = None,
    concurrency: int | None = None,
    job_id: str | None = None,
    seed_property: str | None = None,
    consecutive_ids: bool = False,
    relationship_weight_property: str | None = None,
) -> WccJobHandle:
    """Start a WCC computation asynchronously and return a handle to the running job.

    Unlike the blocking endpoints, this does not wait for the job to finish;
    use the returned handle to poll progress(), wait(), and then fetch the
    result()/mutate()/write()/stream() as needed.
    """
    config = self._node_property_endpoints.create_base_config(
        G,
        concurrency=concurrency,
        consecutive_ids=consecutive_ids,
        job_id=job_id,
        log_progress=log_progress,
        node_labels=node_labels,
        relationship_types=relationship_types,
        relationship_weight_property=relationship_weight_property,
        seed_property=seed_property,
        sudo=sudo,
        threshold=threshold,
    )

    # NOTE(review): consider renaming run_job to start_job to make the
    # fire-and-forget semantics explicit.
    job_id = self._node_property_endpoints.run_job("v2/community.wcc", config)

    # Pass the helper's write-back client through; without it the handle's
    # write() would always raise "Write back client is not initialized".
    return WccJobHandle(
        self._node_property_endpoints._arrow_client,
        job_id,
        remote_write_back_client=self._node_property_endpoints._write_back_client,
    )

def estimate(
self,
G: GraphV2 | dict[str, Any],
Expand All @@ -192,3 +233,18 @@ def estimate(
relationship_weight_property=relationship_weight_property,
)
return self._node_property_endpoints.estimate("v2/community.wcc.estimate", G, config)


class WccJobHandle(
    JobHandle[WccStatsResult], NodePropertyMutateHandle, NodePropertyWriteHandle, NodePropertyStreamHandle
):
    """Handle to an asynchronously running WCC job."""

    def __init__(
        self,
        arrow_client: AuthenticatedArrowClient,
        job_id: str,
        remote_write_back_client: RemoteWriteBackClient | None = None,
    ):
        # A single cooperative super().__init__ call would only reach
        # JobHandle.__init__, which does not chain further, so the mixin
        # initializers — in particular NodePropertyWriteHandle, which stores
        # the write-back client — would never run. Initialize each base
        # explicitly instead.
        JobHandle.__init__(self, arrow_client, job_id)
        NodePropertyMutateHandle.__init__(self, arrow_client, job_id)
        NodePropertyWriteHandle.__init__(
            self, arrow_client, job_id, remote_write_back_client=remote_write_back_client
        )
        NodePropertyStreamHandle.__init__(self, arrow_client, job_id)

    def _parse_result(self, raw_result: dict[str, Any]) -> WccStatsResult:
        """Deserialize the raw job summary into a WccStatsResult."""
        return WccStatsResult(**raw_result)
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ def run_job_and_get_summary(self, endpoint: str, config: dict[str, Any]) -> dict
job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress)
result = JobClient.get_summary(self._arrow_client, job_id)
if nested_config := result.get("configuration", None):
self._drop_write_internals(nested_config)
self.drop_write_internals(nested_config)
return result

def run_job(self, endpoint: str, config: dict[str, Any]) -> str:
    """Submit *config* to *endpoint* and return the job id without waiting for completion."""
    job_client = JobClient()
    return job_client.run_job(self._arrow_client, endpoint, config)

def _run_job_and_mutate(
self,
endpoint: str,
Expand Down Expand Up @@ -68,7 +72,7 @@ def _run_job_and_mutate(
nested_config["mutateProperty"] = mutate_property
if mutate_relationship_type is not None:
nested_config["mutateRelationshipType"] = mutate_relationship_type
self._drop_write_internals(nested_config)
self.drop_write_internals(nested_config)

return computation_result

Expand Down Expand Up @@ -97,10 +101,6 @@ def _run_job_and_write(
if self._write_back_client is None:
raise Exception("Write back client is not initialized")

if isinstance(property_overwrites, str):
# The remote write back procedure allows specifying a single overwrite. The key is ignored.
property_overwrites = {property_overwrites: property_overwrites}

write_result = self._write_back_client.write(
G.name(),
job_id,
Expand Down Expand Up @@ -148,7 +148,8 @@ def estimate(

return EstimationResult(**deserialize_single(res))

def _drop_write_internals(self, config: dict[str, Any]) -> None:
@staticmethod
def drop_write_internals(config: dict[str, Any]) -> None:
config.pop("writeConcurrency", None)
config.pop("writeToResultStore", None)
config.pop("writeProperty", None)
Expand Down
82 changes: 82 additions & 0 deletions src/graphdatascience/procedure_surface/arrow/job_handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar

from pandas import DataFrame

from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
from graphdatascience.arrow_client.v2.api_types import MutateResult, JobStatus
from graphdatascience.arrow_client.v2.job_client import JobClient
from graphdatascience.arrow_client.v2.mutation_client import MutationClient
from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient, WriteBackResult
from graphdatascience.procedure_surface.api.base_result import BaseResult
from graphdatascience.procedure_surface.api.catalog import GraphV2
from graphdatascience.procedure_surface.arrow.endpoints_helper_base import EndpointsHelperBase

T = TypeVar("T", bound=BaseResult)


class JobHandle(ABC, Generic[T]):
    """Client-side handle to a server-side job, parameterized by its result type T."""

    def __init__(self, client: AuthenticatedArrowClient, job_id: str, **kwargs: Any):
        # Extra kwargs are accepted (and ignored) for multiple-inheritance use.
        self._client = client
        self._job_id = job_id

    def _fetch_status(self) -> JobStatus:
        # Single point of contact for status polling.
        return JobClient.get_job_status(self._client, self._job_id)

    def progress(self) -> float:
        """Return the job's current progress value."""
        return self._fetch_status().progress

    def status(self) -> JobStatus:
        """Return the job's full status record."""
        return self._fetch_status()

    def result(self) -> T:
        """Fetch the job summary, strip write-internal keys, and parse it into T."""
        raw = JobClient.get_summary(self._client, self._job_id)
        EndpointsHelperBase.drop_write_internals(raw)
        return self._parse_result(raw)

    def wait(self, show_progress: bool = False) -> None:
        """Block until the job has finished."""
        JobClient().wait_for_job(self._client, self._job_id, show_progress=show_progress)

    @abstractmethod
    def _parse_result(self, summary: dict[str, Any]) -> T:
        """Parse the raw summary dictionary into the specific result type."""


class NodePropertyMutateHandle(ABC):
    """Mixin that adds mutate() support to a job handle."""

    def __init__(self, client: AuthenticatedArrowClient, job_id: str, **kwargs: Any):
        self._client = client
        self._job_id = job_id

    def mutate(self, mutate_property: str) -> MutateResult:
        """Store the job's result on the in-memory graph under *mutate_property*."""
        mutation_client = MutationClient()
        return mutation_client.mutate_node_property(self._client, self._job_id, mutate_property)


class NodePropertyWriteHandle(ABC):
    """Mixin that adds database write-back support to a job handle."""

    def __init__(
        self,
        client: AuthenticatedArrowClient,
        job_id: str,
        remote_write_back_client: RemoteWriteBackClient | None = None,
        **kwargs: Any,
    ):
        self._client = client
        self._job_id = job_id
        self._remote_write_back_client = remote_write_back_client

    def write(
        self, G: GraphV2, write_property: str, write_concurrency: int | None = None, log_progress: bool = True
    ) -> WriteBackResult:
        """Write the job's result back to the database as *write_property* on graph *G*."""
        write_back_client = self._remote_write_back_client
        if write_back_client is None:
            raise Exception("Write back client is not initialized")

        # Positional order follows RemoteWriteBackClient.write:
        # (graph_name, job_id, concurrency, property_overwrites,
        #  relationship_type_overwrite, log_progress).
        return write_back_client.write(G.name(), self._job_id, write_concurrency, write_property, None, log_progress)


class NodePropertyStreamHandle(ABC):
    """Mixin that adds stream() support to a job handle."""

    def __init__(self, client: AuthenticatedArrowClient, job_id: str, **kwargs: Any):
        self._client = client
        self._job_id = job_id

    def stream(self, G: GraphV2) -> DataFrame:
        """Stream the job's results for graph *G* as a DataFrame."""
        # NOTE(review): no job-status check is performed before streaming —
        # presumably callers must wait() first; confirm server behavior for
        # jobs that have not yet completed.
        return JobClient().stream_results(self._client, G.name(), self._job_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need to check the status of the job before streaming?

Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ class NodePropertyEndpointsHelper(EndpointsHelperBase):
def run_job_and_mutate(self, endpoint: str, config: dict[str, Any], mutate_property: str) -> dict[str, Any]:
    """Run the job at *endpoint* with *config*, mutating the graph with *mutate_property*, and return the summary."""
    return self._run_job_and_mutate(endpoint, config, mutate_property=mutate_property)

def run_job(self, endpoint: str, config: dict[str, Any]) -> str:
    """Submit *config* to *endpoint* and return the job id; re-exposes the base-helper method on the public surface."""
    return super().run_job(endpoint, config)

def run_job_and_write(
self,
endpoint: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,21 @@ def test_wcc_estimate(wcc_endpoints: WccArrowEndpoints, sample_graph: GraphV2) -
assert result.bytes_max > 0
assert result.heap_percentage_min > 0
assert result.heap_percentage_max > 0


def test_wcc_compute_and_mutate(wcc_endpoints: WccArrowEndpoints, sample_graph: GraphV2) -> None:
    # NOTE(review): despite the name, no mutate() step is exercised here —
    # either add one or rename the test to test_wcc_compute.
    handle = wcc_endpoints.compute(
        G=sample_graph,
    )

    assert handle.progress() >= 0

    handle.wait()

    stats = handle.result()

    assert stats.component_count == 2
    assert "p10" in stats.component_distribution
    assert stats.pre_processing_millis >= 0
    assert stats.compute_millis >= 0
    assert stats.post_processing_millis >= 0