
Commit 28bc404

fix(enumerator): restore full pre-split functionality and remove test skips
- Restore source_node_match/destination_node_match filter support
- Restore WHERE + multi-hop path pruning logic
- Remove skip decorators that hid oracle feature gaps
- Keep only legitimate xfail for edge alias on multi-hop (oracle limitation)
- Remove conftest workaround for multi-hop + WHERE
1 parent 8fb926c commit 28bc404
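
For orientation, the WHERE clauses this commit restores are cross-step column comparisons. The sketch below rebuilds that data shape from the dataclasses the commit removes from the enumerator (they now live in graphistry.compute.gfql.same_path_types); the "a"/"b" aliases and the "amount" column are illustrative, not taken from this commit.

from dataclasses import dataclass
from typing import Literal

# Shapes copied from the removed enumerator-local definitions (see the first diff below).
ComparisonOp = Literal["==", "!=", "<", "<=", ">", ">="]

@dataclass(frozen=True)
class StepColumnRef:
    alias: str    # which aliased step of the chain the value comes from
    column: str   # which column on that step's frame

@dataclass(frozen=True)
class WhereComparison:
    left: StepColumnRef
    op: ComparisonOp
    right: StepColumnRef

# "amount seen at step 'a' must be less than amount seen at step 'b'"
where = [WhereComparison(StepColumnRef("a", "amount"), "<", StepColumnRef("b", "amount"))]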

File tree

- graphistry/gfql/ref/enumerator.py
- tests/gfql/ref/conftest.py
- tests/gfql/ref/test_df_executor_amplify.py
- tests/gfql/ref/test_df_executor_core.py
- tests/gfql/ref/test_df_executor_patterns.py

5 files changed: +83 −67 lines changed

graphistry/gfql/ref/enumerator.py

Lines changed: 78 additions & 29 deletions

@@ -1,9 +1,10 @@
 """Minimal GFQL reference enumerator used as the correctness oracle."""
+# ruff: noqa: E501

 from __future__ import annotations

 from dataclasses import dataclass
-from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple

 import pandas as pd

@@ -16,21 +17,7 @@
 from graphistry.compute.ast import ASTEdge, ASTNode, ASTObject
 from graphistry.compute.chain import Chain
 from graphistry.compute.filter_by_dict import filter_by_dict
-ComparisonOp = Literal["==", "!=", "<", "<=", ">", ">="]
-
-
-
-@dataclass(frozen=True)
-class StepColumnRef:
-    alias: str
-    column: str
-
-
-@dataclass(frozen=True)
-class WhereComparison:
-    left: StepColumnRef
-    op: ComparisonOp
-    right: StepColumnRef
+from graphistry.compute.gfql.same_path_types import ComparisonOp, WhereComparison


 @dataclass(frozen=True)
@@ -52,14 +39,6 @@ class OracleResult:
     edge_hop_labels: Optional[Dict[Any, int]] = None


-def col(alias: str, column: str) -> StepColumnRef:
-    return StepColumnRef(alias, column)
-
-
-def compare(left: StepColumnRef, op: ComparisonOp, right: StepColumnRef) -> WhereComparison:
-    return WhereComparison(left, op, right)
-
-
 def enumerate_chain(
     g: Plottable,
     ops: Sequence[ASTObject],
@@ -140,11 +119,9 @@ def enumerate_chain(
             paths = paths.drop(columns=[current])
             current = node_step["id_col"]
         else:
-            if where:
-                raise ValueError("WHERE clauses not supported for multi-hop edges in enumerator")
-            if edge_step["alias"] or node_step["alias"]:
-                # Alias tagging for multi-hop not yet supported in enumerator
-                raise ValueError("Aliases not supported for multi-hop edges in enumerator")
+            if edge_step["alias"]:
+                # Edge alias tagging for multi-hop not yet supported in enumerator
+                raise ValueError("Edge aliases not supported for multi-hop edges in enumerator")

             dest_allowed: Optional[Set[Any]] = None
             if not node_frame.empty:
@@ -164,6 +141,12 @@
                 for dst in bp_result.seed_to_nodes.get(seed_id, set()):
                     new_rows.append([*row, dst])
             paths = pd.DataFrame(new_rows, columns=[*base_cols, node_step["id_col"]])
+            paths = paths.merge(
+                node_frame,
+                on=node_step["id_col"],
+                how="inner",
+                validate="m:1",
+            )
             current = node_step["id_col"]

         # Stash edges/nodes and hop labels for final selection
@@ -182,6 +165,72 @@

     if where:
         paths = paths[_apply_where(paths, where)]
+
+    # After WHERE filtering, prune collected_nodes/edges to only those in surviving paths
+    # For multi-hop edges, we stored all reachable nodes/edges before WHERE filtering
+    # Now we need to keep only those that participate in valid paths
+    if len(paths) > 0:
+        for i, edge_step in enumerate(edge_steps):
+            if "collected_nodes" not in edge_step:
+                continue
+            start_col = node_steps[i]["id_col"]
+            end_col = node_steps[i + 1]["id_col"]
+            if start_col not in paths.columns or end_col not in paths.columns:
+                continue
+            valid_starts = set(paths[start_col].tolist())
+            valid_ends = set(paths[end_col].tolist())
+
+            # Re-trace paths from valid_starts to valid_ends to find valid nodes/edges
+            # Build adjacency from original edges, respecting direction
+            direction = edge_step.get("direction", "forward")
+            adjacency: Dict[Any, List[Tuple[Any, Any]]] = {}
+            for _, row in edges_df.iterrows():  # type: ignore[assignment]
+                src, dst, eid = row[edge_src], row[edge_dst], row[edge_id]  # type: ignore[call-overload]
+                if direction == "reverse":
+                    # Reverse: traverse dst -> src
+                    adjacency.setdefault(dst, []).append((eid, src))
+                elif direction == "undirected":
+                    # Undirected: traverse both ways
+                    adjacency.setdefault(src, []).append((eid, dst))
+                    adjacency.setdefault(dst, []).append((eid, src))
+                else:
+                    # Forward: traverse src -> dst
+                    adjacency.setdefault(src, []).append((eid, dst))
+
+            # BFS from valid_starts to find paths to valid_ends
+            valid_nodes: Set[Any] = set()
+            valid_edge_ids: Set[Any] = set()
+            min_hops = edge_step.get("min_hops", 1)
+            max_hops = edge_step.get("max_hops", 10)
+
+            for start in valid_starts:
+                # Track paths: (current_node, path_edges, path_nodes)
+                stack: List[Tuple[Any, List[Any], List[Any]]] = [(start, [], [start])]
+                while stack:
+                    node, path_edges, path_nodes = stack.pop()
+                    if len(path_edges) >= max_hops:
+                        continue
+                    for eid, dst in adjacency.get(node, []):
+                        new_edges = path_edges + [eid]
+                        new_nodes = path_nodes + [dst]
+                        # Only include paths within [min_hops, max_hops] range
+                        if dst in valid_ends and len(new_edges) >= min_hops:
+                            # This path reaches a valid end - include all nodes/edges
+                            valid_nodes.update(new_nodes)
+                            valid_edge_ids.update(new_edges)
+                        if len(new_edges) < max_hops:
+                            stack.append((dst, new_edges, new_nodes))
+
+            edge_step["collected_nodes"] = valid_nodes
+            edge_step["collected_edges"] = valid_edge_ids
+    else:
+        # No surviving paths - clear all collected nodes/edges
+        for edge_step in edge_steps:
+            if "collected_nodes" in edge_step:
+                edge_step["collected_nodes"] = set()
+            if "collected_edges" in edge_step:
+                edge_step["collected_edges"] = set()
+
     seq_cols: List[str] = []
     for i, node_step in enumerate(node_steps):
         seq_cols.append(node_step["id_col"])
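
To make the restored pruning pass concrete, here is a condensed, self-contained rerun of the re-tracing idea above on a toy adjacency map: only nodes and edge ids that lie on some path from a WHERE-surviving start to a WHERE-surviving end, within the hop bounds, are kept. Names mirror the diff, but this is an illustration, not the enumerator code.

from typing import Any, Dict, List, Set, Tuple

# Toy forward adjacency: a -e1-> b -e2-> c, plus a dead-end branch a -e3-> x.
adjacency: Dict[Any, List[Tuple[Any, Any]]] = {
    "a": [("e1", "b"), ("e3", "x")],
    "b": [("e2", "c")],
}
valid_starts: Set[Any] = {"a"}  # step-start values still present after WHERE
valid_ends: Set[Any] = {"c"}    # step-end values still present after WHERE
min_hops, max_hops = 1, 3

valid_nodes: Set[Any] = set()
valid_edge_ids: Set[Any] = set()
for start in valid_starts:
    # Depth-first walk over (node, edges so far, nodes so far), bounded by max_hops.
    stack: List[Tuple[Any, List[Any], List[Any]]] = [(start, [], [start])]
    while stack:
        node, path_edges, path_nodes = stack.pop()
        if len(path_edges) >= max_hops:
            continue
        for eid, dst in adjacency.get(node, []):
            new_edges, new_nodes = path_edges + [eid], path_nodes + [dst]
            if dst in valid_ends and len(new_edges) >= min_hops:
                # Keep everything along a path that reaches a surviving end node.
                valid_nodes.update(new_nodes)
                valid_edge_ids.update(new_edges)
            if len(new_edges) < max_hops:
                stack.append((dst, new_edges, new_nodes))

# The dead-end branch (x, e3) is pruned; only the surviving path remains.
assert valid_nodes == {"a", "b", "c"} and valid_edge_ids == {"e1", "e2"}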

tests/gfql/ref/conftest.py

Lines changed: 1 addition & 25 deletions

@@ -5,7 +5,6 @@
 import pytest

 from graphistry.Engine import Engine
-from graphistry.compute.ast import ASTEdge
 from graphistry.compute.gfql.df_executor import (
     build_same_path_inputs,
     DFSamePathExecutor,
@@ -17,17 +16,6 @@
 TEST_CUDF = "TEST_CUDF" in os.environ and os.environ["TEST_CUDF"] == "1"


-def _has_multihop(chain) -> bool:
-    """Check if chain has any multi-hop edges (oracle doesn't support multi-hop + WHERE)."""
-    for op in chain:
-        if isinstance(op, ASTEdge):
-            min_h = op.min_hops if op.min_hops is not None else (op.hops if isinstance(op.hops, int) else 1)
-            max_h = op.max_hops if op.max_hops is not None else (op.hops if isinstance(op.hops, int) else min_h)
-            if min_h != 1 or max_h != 1:
-                return True
-    return False
-
-
 def make_simple_graph():
     """Create a simple account->user graph for basic tests."""
     nodes = pd.DataFrame(
@@ -70,26 +58,14 @@ def make_hop_graph():


 def assert_executor_parity(graph, chain, where):
-    """Assert executor parity with oracle. Tests pandas, and cudf if TEST_CUDF=1.
-
-    For multi-hop + WHERE, oracle comparison is skipped (oracle doesn't support it).
-    We just verify the executor runs and produces valid output.
-    """
+    """Assert executor parity with oracle. Tests pandas, and cudf if TEST_CUDF=1."""
     inputs = build_same_path_inputs(graph, chain, where, Engine.PANDAS)
     executor = DFSamePathExecutor(inputs)
     executor._forward()
     result = executor._run_native()

     assert result._nodes is not None and result._edges is not None

-    # Oracle doesn't support multi-hop + WHERE, skip comparison
-    if where and _has_multihop(chain):
-        # Just verify executor produced valid output
-        assert "id" in result._nodes.columns
-        assert "src" in result._edges.columns
-        assert "dst" in result._edges.columns
-        return
-
     oracle = enumerate_chain(
         graph,
         chain,
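
A hypothetical test built on this helper might look as follows. make_simple_graph is the fixture above and n/e_forward are graphistry's GFQL constructors, but the "type" filter values and the empty WHERE are assumptions for illustration, not taken from this commit.

from graphistry import n, e_forward

def test_account_to_user_parity():
    graph = make_simple_graph()
    chain = [n({"type": "account"}), e_forward(), n({"type": "user"})]
    # Assumes the helper tolerates an empty WHERE; real tests pass their own comparisons.
    assert_executor_parity(graph, chain, where=[])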

tests/gfql/ref/test_df_executor_amplify.py

Lines changed: 0 additions & 3 deletions

@@ -979,7 +979,6 @@ class TestNodeEdgeMatchFilters:
     of the endpoint node filters or WHERE clauses.
     """

-    @pytest.mark.skip(reason="Oracle doesn't support destination_node_match correctly")
     def test_destination_node_match_single_hop(self):
         """
         destination_node_match restricts which nodes can be reached.
@@ -1012,7 +1011,6 @@ def test_destination_node_match_single_hop(self):
         assert "b" in result_nodes, "should reach target type node"
         assert "c" not in result_nodes, "should not reach other type node"

-    @pytest.mark.skip(reason="Oracle doesn't support source_node_match correctly")
     def test_source_node_match_single_hop(self):
         """
         source_node_match restricts which nodes can be traversed FROM.
@@ -1111,7 +1109,6 @@ def test_destination_node_match_multi_hop(self):
         assert "b" in result_nodes, "should reach b (target) at hop 1"
         assert "c" in result_nodes, "should reach c (target) at hop 2"

-    @pytest.mark.skip(reason="Oracle doesn't support source/destination_node_match correctly")
     def test_combined_source_and_dest_match(self):
         """
         Both source_node_match and destination_node_match together.
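
As a reminder of what these restored tests exercise, source_node_match and destination_node_match filter an edge step by its endpoints, independent of the surrounding node steps. A hedged sketch (the "type" values and chain are made up, not from this commit):

from graphistry import n, e_forward

chain = [
    n(),
    # Only traverse edges whose source node matches {"type": "account"}
    # and whose destination node matches {"type": "user"}.
    e_forward(
        source_node_match={"type": "account"},
        destination_node_match={"type": "user"},
    ),
    n(),
]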

tests/gfql/ref/test_df_executor_core.py

Lines changed: 4 additions & 9 deletions

@@ -1282,7 +1282,6 @@ def test_cycle_with_branch(self):

         _assert_parity(graph, chain, where)

-    @pytest.mark.skip(reason="Oracle doesn't support multi-hop + WHERE")
     def test_oracle_cudf_parity_comprehensive(self):
         """
         P0 Test 4: Oracle and cuDF executor must produce identical results.
@@ -1407,7 +1406,6 @@ class TestP1FeatureComposition:
     cuDF executor's handling of multi-hop + WHERE combinations.
     """

-    @pytest.mark.skip(reason="Oracle doesn't support multi-hop + WHERE")
     def test_multi_hop_edge_where_filtering(self):
         """
         P1 Test 5: WHERE must be applied even for multi-hop edges.
@@ -1597,7 +1595,6 @@ class TestUnfilteredStarts:
     instead of hop labels (which become ambiguous when all nodes can be starts).
     """

-    @pytest.mark.skip(reason="Oracle doesn't support multi-hop + WHERE")
     def test_unfiltered_start_node_multihop(self):
         """
         Unfiltered start node with multi-hop works via public API.
@@ -1663,7 +1660,6 @@ def test_unfiltered_start_single_hop(self):
         result = execute_same_path_chain(graph, chain, where, Engine.PANDAS)
         assert set(result._nodes["id"]) == set(oracle.nodes["id"])

-    @pytest.mark.skip(reason="Oracle doesn't support multi-hop + WHERE")
     def test_unfiltered_start_with_cycle(self):
         """
         Unfiltered start with cycle in graph.
@@ -1694,7 +1690,6 @@ def test_unfiltered_start_with_cycle(self):
         result = execute_same_path_chain(graph, chain, where, Engine.PANDAS)
         assert set(result._nodes["id"]) == set(oracle.nodes["id"])

-    @pytest.mark.skip(reason="Oracle doesn't support multi-hop + WHERE")
     def test_unfiltered_start_multihop_reverse(self):
         """
         Unfiltered start node with multi-hop REVERSE traversal + WHERE.
@@ -1729,7 +1724,6 @@ def test_unfiltered_start_multihop_reverse(self):
         result = execute_same_path_chain(graph, chain, where, Engine.PANDAS)
         assert set(result._nodes["id"]) == set(oracle.nodes["id"])

-    @pytest.mark.skip(reason="Oracle doesn't support multi-hop + WHERE")
     def test_unfiltered_start_multihop_undirected(self):
         """
         Unfiltered start node with multi-hop UNDIRECTED traversal + WHERE.
@@ -1762,7 +1756,6 @@ def test_unfiltered_start_multihop_undirected(self):
         result = execute_same_path_chain(graph, chain, where, Engine.PANDAS)
         assert set(result._nodes["id"]) == set(oracle.nodes["id"])

-    @pytest.mark.skip(reason="Oracle doesn't support multi-hop + WHERE")
     def test_filtered_start_multihop_reverse_where(self):
         """
         Filtered start node with multi-hop REVERSE + WHERE.
@@ -1796,7 +1789,6 @@ def test_filtered_start_multihop_reverse_where(self):
         result = execute_same_path_chain(graph, chain, where, Engine.PANDAS)
         assert set(result._nodes["id"]) == set(oracle.nodes["id"])

-    @pytest.mark.skip(reason="Oracle doesn't support multi-hop + WHERE")
     def test_filtered_start_multihop_undirected_where(self):
         """
         Filtered start with multi-hop UNDIRECTED + WHERE.
@@ -1841,7 +1833,10 @@ class TestOracleLimitations:
     These test features the oracle doesn't support.
     """

-    @pytest.mark.skip(reason="Oracle doesn't support edge aliases on multi-hop edges")
+    @pytest.mark.xfail(
+        reason="Oracle doesn't support edge aliases on multi-hop edges",
+        strict=True,
+    )
     def test_edge_alias_on_multihop(self):
         """
         ORACLE LIMITATION: Edge alias on multi-hop edge.
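
Unlike skip, a strict xfail still runs the test and treats an unexpected pass as a failure, so the marker above cannot silently outlive the oracle limitation it documents. Minimal illustration (the test body is a stand-in, not the real test):

import pytest

@pytest.mark.xfail(
    reason="Oracle doesn't support edge aliases on multi-hop edges",
    strict=True,
)
def test_edge_alias_on_multihop_stub():
    # Fails today and is reported as XFAIL; once the oracle gains support and this
    # starts passing, strict=True turns the XPASS into a test-suite failure.
    raise NotImplementedError("edge alias on multi-hop not yet supported")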

tests/gfql/ref/test_df_executor_patterns.py

Lines changed: 0 additions & 1 deletion

@@ -2429,7 +2429,6 @@ def test_string_equality(self):
         # Note: 'b' IS included because it's an intermediate node in the valid path a→b→c
         # The executor returns ALL nodes participating in valid paths, not just endpoints

-    @pytest.mark.skip(reason="Oracle doesn't support multi-hop + WHERE")
     def test_neq_with_nulls(self):
         """!= operator with null values - uses SQL-style semantics where NULL comparisons return False.
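
The SQL-style semantics that docstring refers to can be illustrated in plain pandas: a comparison involving a null counts as no match, whereas raw pandas treats NaN != x as True. This is only a sketch of the semantics, not the executor's implementation:

import pandas as pd

a = pd.Series([1.0, None, 3.0])
b = pd.Series([1.0, 2.0, None])

pandas_neq = a != b                               # [False, True, True]
sql_style_neq = (a != b) & a.notna() & b.notna()  # [False, False, False]: null never matches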
