Commit c447936

refactor: refactor build_kg to accommodate Ray Data
1 parent c844d65 commit c447936

File tree

9 files changed: +258 −226 lines
Lines changed: 2 additions & 0 deletions

@@ -79,3 +79,5 @@ def init_llm(model_type: str) -> Optional[BaseLLMWrapper]:
     backend = config.pop("backend")
     llm_wrapper = LLMFactory.create_llm_wrapper(backend, config)
     return llm_wrapper
+
+# TODO: use ray serve when loading large models to avoid re-loading in each actor
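The TODO above flags a real cost of the actor-based design: when an operator is a class, every actor replica loads its own copy of the model. A minimal sketch of the Ray Serve idea, assuming a hypothetical LLMDeployment wrapper and an async generate() on BaseLLMWrapper (neither exists in GraphGen today):

# Hypothetical sketch: host one shared copy of a large model behind Ray Serve
# so Ray Data actors call into it instead of each loading their own copy.
from ray import serve

@serve.deployment(num_replicas=1, ray_actor_options={"num_gpus": 1})
class LLMDeployment:
    def __init__(self, model_type: str):
        # the heavy model is loaded once, in this deployment's process
        self.llm = init_llm(model_type)

    async def generate(self, prompt: str) -> str:
        # assumes BaseLLMWrapper exposes an async generate(); adjust to the real API
        return await self.llm.generate(prompt)

handle = serve.run(LLMDeployment.bind(model_type="synthesizer"))
# From inside a Ray Data actor, call the shared model instead of loading it:
# text = handle.generate.remote("some prompt").result()

Actors would then hold only the lightweight handle, so scaling the actor pool no longer multiplies GPU memory for the model itself.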

graphgen/engine.py

Lines changed: 165 additions & 104 deletions
@@ -1,125 +1,186 @@
-"""
-orchestration engine for GraphGen
-"""
+import inspect
+import logging
+from collections import defaultdict, deque
+from functools import wraps
+from typing import Any, Callable, Dict, List, Set
 
-import threading
-import traceback
-from typing import Any, Callable, List
+import ray
+import ray.data
 
+from graphgen.bases import Config, Node
 
-class Context(dict):
-    _lock = threading.Lock()
 
-    def set(self, k, v):
-        with self._lock:
-            self[k] = v
+class Engine:
+    def __init__(
+        self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
+    ):
+        self.config = Config(**config)
+        self.functions = functions
+        self.datasets: Dict[str, ray.data.Dataset] = {}
+
+        if not ray.is_initialized():
+            context = ray.init(
+                ignore_reinit_error=True,
+                logging_level=logging.ERROR,
+                log_to_driver=True,
+                **ray_init_kwargs,
+            )
+            print(f"Ray Dashboard URL: {context.dashboard_url}")
 
-    def get(self, k, default=None):
-        with self._lock:
-            return super().get(k, default)
+    @staticmethod
+    def _topo_sort(nodes: List[Node]) -> List[Node]:
+        id_to_node: Dict[str, Node] = {}
+        for n in nodes:
+            id_to_node[n.id] = n
+
+        indeg: Dict[str, int] = {nid: 0 for nid in id_to_node}
+        adj: Dict[str, List[str]] = defaultdict(list)
+
+        for n in nodes:
+            nid = n.id
+            deps: List[str] = n.dependencies
+            uniq_deps: Set[str] = set(deps)
+            for d in uniq_deps:
+                if d not in id_to_node:
+                    raise ValueError(
+                        f"The dependency node id {d} of node {nid} is not defined in the configuration."
+                    )
+                indeg[nid] += 1
+                adj[d].append(nid)
+
+        zero_deg: deque = deque(
+            [id_to_node[nid] for nid, deg in indeg.items() if deg == 0]
+        )
+        sorted_nodes: List[Node] = []
+
+        while zero_deg:
+            cur = zero_deg.popleft()
+            sorted_nodes.append(cur)
+            cur_id = cur.id
+            for nb_id in adj.get(cur_id, []):
+                indeg[nb_id] -= 1
+                if indeg[nb_id] == 0:
+                    zero_deg.append(id_to_node[nb_id])
+
+        if len(sorted_nodes) != len(nodes):
+            remaining = [nid for nid, deg in indeg.items() if deg > 0]
+            raise ValueError(
+                f"The configuration contains cycles, unable to execute. Remaining nodes with indegree > 0: {remaining}"
+            )
 
+        return sorted_nodes
 
-class OpNode:
-    def __init__(
-        self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any]
-    ):
-        self.name, self.deps, self.func = name, deps, func
+    def _get_input_dataset(
+        self, node: Node, initial_ds: ray.data.Dataset
+    ) -> ray.data.Dataset:
+        deps = node.dependencies
 
+        if not deps:
+            return initial_ds
 
-class Engine:
-    def __init__(self, max_workers: int = 4):
-        self.max_workers = max_workers
-
-    def run(self, ops: List[OpNode], ctx: Context):
-        self._validate(ops)
-        name2op = {operation.name: operation for operation in ops}
-
-        # topological sort
-        graph = {n: set(name2op[n].deps) for n in name2op}
-        topo = []
-        q = [n for n, d in graph.items() if not d]
-        while q:
-            cur = q.pop(0)
-            topo.append(cur)
-            for child in [c for c, d in graph.items() if cur in d]:
-                graph[child].remove(cur)
-                if not graph[child]:
-                    q.append(child)
-
-        if len(topo) != len(ops):
+        if len(deps) == 1:
+            return self.datasets[deps[0]]
+
+        main_ds = self.datasets[deps[0]]
+        other_dss = [self.datasets[d] for d in deps[1:]]
+        if not all(ds.schema() == main_ds.schema() for ds in other_dss):
             raise ValueError(
-                "Cyclic dependencies detected among operations."
-                "Please check your configuration."
+                f"Union requires all datasets to have the same schema for node {node.id}"
             )
+        return main_ds.union(*other_dss)
+
+    def _execute_node(self, node: Node, initial_ds: ray.data.Dataset):
+        if node.op_name not in self.functions:
+            raise ValueError(f"Operator {node.op_name} not found for node {node.id}")
+
+        if node.type == "source":
+            op_handler = self.functions[node.op_name]
+            node_params = node.params
+            self.datasets[node.id] = op_handler(**node_params)
+            return
 
-        # semaphore for max_workers
-        sem = threading.Semaphore(self.max_workers)
-        done = {n: threading.Event() for n in name2op}
-        exc = {}
-
-        def _exec(n: str):
-            with sem:
-                for d in name2op[n].deps:
-                    done[d].wait()
-                if any(d in exc for d in name2op[n].deps):
-                    exc[n] = Exception("Skipped due to failed dependencies")
-                    done[n].set()
-                    return
-                try:
-                    name2op[n].func(name2op[n], ctx)
-                except Exception:
-                    exc[n] = traceback.format_exc()
-                done[n].set()
-
-        ts = [threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
-        if exc:
-            raise RuntimeError(
-                "Some operations failed:\n"
-                + "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items())
+        input_ds = self._get_input_dataset(node, initial_ds)
+
+        op_handler = self.functions[node.op_name]
+        node_params = node.params
+
+        if inspect.isclass(op_handler):
+            replicas = node_params.pop("replicas", 1)
+            batch_size = (
+                int(node_params.pop("batch_size"))
+                if "batch_size" in node_params
+                else "default"
             )
+            compute_resources = node_params.pop("compute_resources", {})
+
+            if node.type == "aggregate":
+                self.datasets[node.id] = input_ds.repartition(1).map_batches(
+                    op_handler,
+                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
+                    batch_size=None,  # aggregate processes the whole dataset at once
+                    num_gpus=compute_resources.get("num_gpus", 0)
+                    if compute_resources
+                    else 0,
+                    fn_constructor_kwargs=node_params,
+                    batch_format="pandas",
+                )
+            else:
+                # others like map, filter, flatmap, map_batch let actors process data inside batches
+                self.datasets[node.id] = input_ds.map_batches(
+                    op_handler,
+                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
+                    batch_size=batch_size,
+                    num_gpus=compute_resources.get("num_gpus", 0)
+                    if compute_resources
+                    else 0,
+                    fn_constructor_kwargs=node_params,
+                    batch_format="pandas",
+                )
 
-    @staticmethod
-    def _validate(ops: List[OpNode]):
-        name_set = set()
-        for op in ops:
-            if op.name in name_set:
-                raise ValueError(f"Duplicate operation name: {op.name}")
-            name_set.add(op.name)
-        for op in ops:
-            for dep in op.deps:
-                if dep not in name_set:
-                    raise ValueError(
-                        f"Operation {op.name} has unknown dependency: {dep}"
-                    )
+        else:
 
+            @wraps(op_handler)
+            def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
+                return op_handler(row_or_batch, **node_params)
+
+            if node.type == "map":
+                self.datasets[node.id] = input_ds.map(func_wrapper)
+            elif node.type == "filter":
+                self.datasets[node.id] = input_ds.filter(func_wrapper)
+            elif node.type == "flatmap":
+                self.datasets[node.id] = input_ds.flat_map(func_wrapper)
+            elif node.type == "aggregate":
+                self.datasets[node.id] = input_ds.repartition(1).map_batches(
+                    func_wrapper, batch_format="default"
+                )
+            elif node.type == "map_batch":
+                self.datasets[node.id] = input_ds.map_batches(func_wrapper)
+            else:
+                raise ValueError(
+                    f"Unsupported node type {node.type} for node {node.id}"
+                )
 
-def collect_ops(config: dict, graph_gen) -> List[OpNode]:
-    """
-    build operation nodes from yaml config
-    :param config
-    :param graph_gen
-    """
-    ops: List[OpNode] = []
-    for stage in config["pipeline"]:
-        name = stage["name"]
-        method_name = stage.get("op_key")
-        method = getattr(graph_gen, method_name)
-        deps = stage.get("deps", [])
+    @staticmethod
+    def _find_leaf_nodes(nodes: List[Node]) -> Set[str]:
+        all_ids = {n.id for n in nodes}
+        deps_set = set()
+        for n in nodes:
+            deps_set.update(n.dependencies)
+        return all_ids - deps_set
 
-        if "params" in stage:
+    def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, List[Any]]:
+        sorted_nodes = self._topo_sort(self.config.nodes)
 
-            def func(self, ctx, _method=method, _params=stage.get("params", {})):
-                return _method(_params)
+        for node in sorted_nodes:
+            self._execute_node(node, initial_ds)
 
-        else:
+        leaf_nodes = self._find_leaf_nodes(sorted_nodes)
 
-            def func(self, ctx, _method=method):
-                return _method()
+        @ray.remote
+        def _fetch_result(ds: ray.data.Dataset) -> List[Any]:
+            return ds.take_all()
 
-        op_node = OpNode(name=name, deps=deps, func=func)
-        ops.append(op_node)
-    return ops
+        results = ray.get(
+            [_fetch_result.remote(self.datasets[node_id]) for node_id in leaf_nodes]
+        )
+        return dict(zip(leaf_nodes, results))
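For orientation, a minimal driver for the refactored Engine might look like the sketch below. The node fields (id, op_name, type, params, dependencies) are inferred from _topo_sort and _execute_node above; every operator name here is hypothetical, and the real Config/Node classes in graphgen.bases may expect a different shape:

import ray.data

config = {
    "nodes": [
        {"id": "clean", "op_name": "clean_text", "type": "map",
         "params": {}, "dependencies": []},
        {"id": "keep_long", "op_name": "min_length", "type": "filter",
         "params": {"min_chars": 12}, "dependencies": ["clean"]},
        {"id": "tag", "op_name": "tagger", "type": "map_batch",
         "params": {"prefix": "[kg] ", "replicas": 2}, "dependencies": ["keep_long"]},
    ]
}

def clean_text(row):  # plain function: wrapped by func_wrapper, called per row
    row["text"] = row["text"].strip()
    return row

def min_length(row, min_chars):  # extra kwargs are injected from node params
    return len(row["text"]) >= min_chars

class Tagger:  # class operator: runs in an actor pool of up to `replicas` actors
    def __init__(self, prefix):  # built once per actor via fn_constructor_kwargs
        self.prefix = prefix

    def __call__(self, batch):  # a pandas DataFrame, per batch_format="pandas"
        batch["text"] = self.prefix + batch["text"]
        return batch

engine = Engine(
    config,
    functions={"clean_text": clean_text, "min_length": min_length, "tagger": Tagger},
)
initial_ds = ray.data.from_items([{"text": "  knowledge graphs  "}, {"text": "hi"}])
print(engine.execute(initial_ds))  # only leaf datasets (here "tag") are materialized

Because Ray Data transformations are lazy, intermediate nodes remain unexecuted plans; work is only triggered when execute() calls take_all() on the leaf datasets.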
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-from .build_kg import build_kg
+from .build_kg_service import BuildKGService

graphgen/operators/build_kg/build_kg.py

Lines changed: 0 additions & 59 deletions
This file was deleted.
