Skip to content

Commit 91fe759

Browse files
committed
fix(gfql): use domain helpers for same-path ids
1 parent 0acddd1 commit 91fe759

File tree

10 files changed

+410
-259
lines changed

10 files changed

+410
-259
lines changed

graphistry/compute/gfql/df_executor.py

Lines changed: 81 additions & 52 deletions
Large diffs are not rendered by default.

graphistry/compute/gfql/same_path/bfs.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@
33
Contains pure functions for building edge pairs and computing BFS reachability.
44
"""
55

6-
from typing import Any, Set
7-
8-
import pandas as pd
6+
from typing import Any, Sequence
97

108
from graphistry.compute.typing import DataFrameT
119
from .edge_semantics import EdgeSemantics
12-
from .df_utils import concat_frames, df_cons
10+
from .df_utils import (
11+
concat_frames,
12+
series_values,
13+
domain_from_values,
14+
domain_diff,
15+
domain_union,
16+
domain_is_empty,
17+
domain_to_frame,
18+
)
1319

1420

1521
def build_edge_pairs(
@@ -23,23 +29,22 @@ def build_edge_pairs(
2329
For undirected edges, both directions are included.
2430
For directed edges, direction follows sem.join_cols().
2531
"""
26-
is_cudf = edges_df.__class__.__module__.startswith("cudf")
2732
if sem.is_undirected:
2833
fwd = edges_df[[src_col, dst_col]].copy()
29-
fwd.columns = pd.Index(['__from__', '__to__'])
34+
fwd.columns = ['__from__', '__to__']
3035
rev = edges_df[[dst_col, src_col]].copy()
31-
rev.columns = pd.Index(['__from__', '__to__'])
36+
rev.columns = ['__from__', '__to__']
3237
result = concat_frames([fwd, rev])
3338
return result.drop_duplicates() if result is not None else fwd.iloc[:0]
3439
else:
3540
join_col, result_col = sem.join_cols(src_col, dst_col)
3641
pairs = edges_df[[join_col, result_col]].copy()
37-
pairs.columns = pd.Index(['__from__', '__to__'])
42+
pairs.columns = ['__from__', '__to__']
3843
return pairs
3944

4045

4146
def bfs_reachability(
42-
edge_pairs: DataFrameT, start_nodes: Set[Any], max_hops: int, hop_col: str
47+
edge_pairs: DataFrameT, start_nodes: Sequence[Any], max_hops: int, hop_col: str
4348
) -> DataFrameT:
4449
"""Compute BFS reachability with hop distance tracking.
4550
@@ -48,19 +53,18 @@ def bfs_reachability(
4853
4954
Args:
5055
edge_pairs: DataFrame with ['__from__', '__to__'] columns
51-
start_nodes: Set of starting node IDs (hop 0)
56+
start_nodes: Starting node domain (hop 0)
5257
max_hops: Maximum number of hops to traverse
5358
hop_col: Name for the hop distance column in output
5459
5560
Returns:
5661
DataFrame with all reachable nodes and their hop distances
5762
"""
58-
from .df_utils import series_values
59-
import pandas as pd
60-
6163
# Use same DataFrame type as input
62-
result = df_cons(edge_pairs, {'__node__': list(start_nodes), hop_col: 0})
63-
visited_idx = pd.Index(start_nodes) if not isinstance(start_nodes, pd.Index) else start_nodes
64+
start_domain = domain_from_values(start_nodes, edge_pairs)
65+
result = domain_to_frame(edge_pairs, start_domain, '__node__')
66+
result[hop_col] = 0
67+
visited_idx = start_domain
6468

6569
for hop in range(1, max_hops + 1):
6670
frontier = result[result[hop_col] == hop - 1][['__node__']].rename(columns={'__node__': '__from__'})
@@ -69,14 +73,15 @@ def bfs_reachability(
6973
next_df = edge_pairs.merge(frontier, on='__from__', how='inner')[['__to__']].drop_duplicates()
7074
next_df = next_df.rename(columns={'__to__': '__node__'})
7175

72-
# Filter out already visited nodes using pd.Index operations
76+
# Filter out already visited nodes using domain operations
7377
candidate_nodes = series_values(next_df['__node__'])
74-
new_node_ids = candidate_nodes.difference(visited_idx)
75-
if len(new_node_ids) == 0:
78+
new_node_ids = domain_diff(candidate_nodes, visited_idx)
79+
if domain_is_empty(new_node_ids):
7680
break
7781

78-
new_nodes = df_cons(edge_pairs, {'__node__': list(new_node_ids), hop_col: hop})
79-
visited_idx = visited_idx.union(new_node_ids)
82+
new_nodes = domain_to_frame(edge_pairs, new_node_ids, '__node__')
83+
new_nodes[hop_col] = hop
84+
visited_idx = domain_union(visited_idx, new_node_ids)
8085

8186
result = concat_frames([result, new_nodes])
8287
if result is None:

graphistry/compute/gfql/same_path/df_utils.py

Lines changed: 94 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,25 @@
33
Contains pure functions for series/dataframe operations used across the executor.
44
"""
55

6-
from typing import Any, Optional, Sequence, Set
6+
from typing import Any, Optional, Sequence
77

88
import pandas as pd
99

1010
from graphistry.compute.typing import DataFrameT
1111

1212

13+
def _is_cudf_obj(obj: Any) -> bool:
14+
return hasattr(obj, "__class__") and obj.__class__.__module__.startswith("cudf")
15+
16+
17+
def _cudf_index_op(left: Any, right: Any, op: str) -> Any:
18+
method = getattr(left, op)
19+
try:
20+
return method(right, sort=False)
21+
except TypeError:
22+
return method(right)
23+
24+
1325
def df_cons(template_df: DataFrameT, data: dict) -> DataFrameT:
1426
"""Construct a DataFrame of the same type as template_df.
1527
@@ -59,26 +71,99 @@ def series_unique(series: Any) -> Any:
5971
6072
For set operations (intersection, union), use series_values() instead.
6173
"""
74+
if _is_cudf_obj(series):
75+
return series.dropna().unique()
76+
if isinstance(series, pd.Index):
77+
return series.dropna().unique()
6278
if hasattr(series, 'dropna'):
6379
return series.dropna().unique()
6480
pandas_series = to_pandas_series(series)
6581
return pandas_series.dropna().unique()
6682

6783

68-
def series_values(series: Any) -> pd.Index:
69-
"""Extract unique non-null values from a series as a pd.Index.
70-
71-
Returns pd.Index which supports:
72-
- .intersection() for & operations
73-
- .union() for | operations
74-
- Direct use in .isin() (no conversion needed)
84+
def series_values(series: Any) -> Any:
85+
"""Extract unique non-null values from a series as an Index-like domain.
7586
76-
This is ~9x faster than the previous set-based approach.
87+
Returns a pandas.Index for pandas objects, and cudf.Index for cuDF objects.
88+
These Index types support .intersection/.union/.difference and are safe to
89+
pass into .isin() without host syncs.
7790
"""
91+
if _is_cudf_obj(series):
92+
import cudf # type: ignore
93+
if isinstance(series, cudf.Index):
94+
return series.dropna().unique()
95+
return cudf.Index(series.dropna().unique())
96+
if isinstance(series, pd.Index):
97+
return series.dropna().unique()
7898
pandas_series = to_pandas_series(series)
7999
return pd.Index(pandas_series.dropna().unique())
80100

81101

102+
def domain_empty(template: Optional[Any] = None) -> Any:
    """Build an empty Index whose backend follows ``template``.

    Args:
        template: Optional object used only to pick the backend.

    Returns:
        ``cudf.Index([])`` when ``template`` is a cudf object, otherwise
        ``pd.Index([])`` (including when ``template`` is None).
    """
    if not _is_cudf_obj(template):
        return pd.Index([])
    # cudf is imported lazily so pandas-only environments never need it.
    import cudf  # type: ignore
    return cudf.Index([])
107+
108+
109+
def domain_is_empty(domain: Any) -> bool:
    """Report whether ``domain`` holds no values.

    Args:
        domain: None, or any object supporting ``len()``.

    Returns:
        True when ``domain`` is None or has length zero, False otherwise.
    """
    if domain is None:
        return True
    return len(domain) == 0
111+
112+
113+
def domain_from_values(values: Any, template: Optional[Any] = None) -> Any:
    """Wrap ``values`` in an Index, choosing the backend from the inputs.

    Dispatch order:
      1. empty/None ``values`` -> empty domain matching ``template``
      2. cudf ``values`` -> returned as-is if already a ``cudf.Index``,
         otherwise wrapped in ``cudf.Index``
      3. ``pd.Index`` ``values`` -> returned as-is
      4. anything else -> ``cudf.Index`` when ``template`` is a cudf object,
         otherwise ``pd.Index``

    Values are passed to the Index constructor unchanged; no explicit
    deduplication or null filtering happens here.

    Args:
        values: None, an existing Index, or a collection of values.
        template: Optional object consulted only to pick the backend.

    Returns:
        An Index over ``values``.
    """
    if domain_is_empty(values):
        return domain_empty(template)

    if _is_cudf_obj(values):
        import cudf  # type: ignore
        already_index = isinstance(values, cudf.Index)
        return values if already_index else cudf.Index(values)

    if isinstance(values, pd.Index):
        return values

    # Host-side values: follow the template's backend.
    if _is_cudf_obj(template):
        import cudf  # type: ignore
        return cudf.Index(values)
    return pd.Index(values)
127+
128+
129+
def domain_intersect(left: Any, right: Any) -> Any:
    """Intersect two domains.

    Args:
        left: Left domain (Index-like) or None.
        right: Right domain (Index-like) or None.

    Returns:
        An empty domain if either side is None/empty (backend taken from
        ``left`` unless it is None, in which case from ``right``); otherwise
        the result of the intersection of ``left`` with ``right``.
    """
    either_empty = domain_is_empty(left) or domain_is_empty(right)
    if either_empty:
        backend_hint = right if left is None else left
        return domain_empty(backend_hint)

    # pandas Index objects always take the plain method call; only non-pandas
    # cudf objects go through the sort-tolerant helper.
    via_cudf_helper = not isinstance(left, pd.Index) and _is_cudf_obj(left)
    if via_cudf_helper:
        return _cudf_index_op(left, right, "intersection")
    return left.intersection(right)
137+
138+
139+
def domain_union(left: Any, right: Any) -> Any:
    """Union two domains.

    Args:
        left: Left domain (Index-like) or None.
        right: Right domain (Index-like) or None.

    Returns:
        ``right`` unchanged when ``left`` is None/empty, ``left`` unchanged
        when ``right`` is None/empty, otherwise the union of ``left`` with
        ``right``.
    """
    if domain_is_empty(left):
        return right
    if domain_is_empty(right):
        return left

    # pandas Index objects always take the plain method call; only non-pandas
    # cudf objects go through the sort-tolerant helper.
    via_cudf_helper = not isinstance(left, pd.Index) and _is_cudf_obj(left)
    if via_cudf_helper:
        return _cudf_index_op(left, right, "union")
    return left.union(right)
149+
150+
151+
def domain_diff(left: Any, right: Any) -> Any:
    """Remove the members of ``right`` from ``left``.

    Args:
        left: Domain to subtract from (Index-like) or None.
        right: Domain of values to remove (Index-like) or None.

    Returns:
        ``left`` unchanged when either side is None/empty, otherwise the
        difference of ``left`` and ``right``.
    """
    nothing_to_remove = domain_is_empty(left) or domain_is_empty(right)
    if nothing_to_remove:
        return left

    # pandas Index objects always take the plain method call; only non-pandas
    # cudf objects go through the sort-tolerant helper.
    via_cudf_helper = not isinstance(left, pd.Index) and _is_cudf_obj(left)
    if via_cudf_helper:
        return _cudf_index_op(left, right, "difference")
    return left.difference(right)
159+
160+
161+
def domain_to_frame(template_df: DataFrameT, domain: Any, col: str) -> DataFrameT:
    """Build a single-column DataFrame from a domain via ``df_cons``.

    Args:
        template_df: DataFrame forwarded to ``df_cons`` as the template.
        domain: Values for the column; None is treated as no values.
        col: Name of the column to create.

    Returns:
        The frame produced by ``df_cons`` with ``col`` holding ``domain``
        (or an empty list when ``domain`` is None).
    """
    column_values = [] if domain is None else domain
    return df_cons(template_df, {col: column_values})
165+
166+
82167
# Standard column name for ID DataFrames used in semi-joins
83168
_ID_COL = "__id__"
84169

graphistry/compute/gfql/same_path/edge_semantics.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
"""
55

66
from dataclasses import dataclass
7-
from typing import Tuple, TYPE_CHECKING
7+
from typing import Any, Tuple, TYPE_CHECKING
88

99
from graphistry.compute.ast import ASTEdge
10-
from .df_utils import series_values
10+
from .df_utils import series_values, domain_union
1111

1212
if TYPE_CHECKING:
1313
pass
@@ -96,7 +96,7 @@ def endpoint_cols(self, src_col: str, dst_col: str) -> Tuple[str, str]:
9696

9797
def start_nodes(
9898
self, edges_df, src_col: str, dst_col: str
99-
) -> set:
99+
) -> Any:
100100
"""Get starting nodes for edge traversal (for backward propagation).
101101
102102
For forward: returns src nodes (where traversal starts)
@@ -109,10 +109,13 @@ def start_nodes(
109109
dst_col: Destination column name
110110
111111
Returns:
112-
pd.Index of node IDs where traversal starts
112+
Index-like domain of node IDs where traversal starts
113113
"""
114114
if self.is_undirected:
115-
return series_values(edges_df[src_col]).union(series_values(edges_df[dst_col]))
115+
return domain_union(
116+
series_values(edges_df[src_col]),
117+
series_values(edges_df[dst_col]),
118+
)
116119
elif self.is_reverse:
117120
return series_values(edges_df[dst_col])
118121
else:

0 commit comments

Comments
 (0)