Merged
Changes from 1 commit
@@ -30,6 +30,7 @@ nodes:
execution_params:
replicas: 1
batch_size: 128
save_output: true
params:
method: schema_guided
schema_path: graphgen/templates/extraction/schemas/legal_contract.json
@@ -74,6 +74,7 @@ nodes:
execution_params:
replicas: 1
batch_size: 128
save_output: true # save output
params:
method: aggregated # atomic, aggregated, multi_hop, cot, vqa
data_format: ChatML # Alpaca, Sharegpt, ChatML
1 change: 1 addition & 0 deletions examples/generate/generate_atomic_qa/atomic_config.yaml
@@ -50,6 +50,7 @@ nodes:
execution_params:
replicas: 1
batch_size: 128
save_output: true
params:
method: atomic
data_format: Alpaca
1 change: 1 addition & 0 deletions examples/generate/generate_cot_qa/cot_config.yaml
@@ -52,6 +52,7 @@ nodes:
execution_params:
replicas: 1
batch_size: 128
save_output: true
params:
method: cot
data_format: Sharegpt
@@ -53,6 +53,7 @@ nodes:
execution_params:
replicas: 1
batch_size: 128
save_output: true
params:
method: multi_hop
data_format: ChatML
1 change: 1 addition & 0 deletions examples/generate/generate_vqa/vqa_config.yaml
@@ -54,6 +54,7 @@ nodes:
execution_params:
replicas: 1
batch_size: 128
save_output: true
params:
method: vqa
data_format: ChatML
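The new key appears in each example config at the node level, alongside execution_params and params. Below is a minimal sketch of how such an entry parses, assuming only PyYAML and a hypothetical node id; the node-level placement is inferred from the new Node.save_output field shown further down, since the flattened diff does not preserve indentation.

```python
import yaml  # PyYAML; an assumption, the project may use its own loader

snippet = """
nodes:
  - id: generate             # hypothetical id
    execution_params:
      replicas: 1
      batch_size: 128
    save_output: true        # the flag added in this commit
    params:
      method: atomic
      data_format: Alpaca
"""

node = yaml.safe_load(snippet)["nodes"][0]
print(node["save_output"])   # True
```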
6 changes: 5 additions & 1 deletion graphgen/bases/datatypes.py
@@ -63,7 +63,11 @@ class Node(BaseModel):
default_factory=list, description="list of dependent node ids"
)
execution_params: dict = Field(
default_factory=dict, description="execution parameters like replicas, batch_size"
default_factory=dict,
description="execution parameters like replicas, batch_size",
)
save_output: bool = Field(
default=False, description="whether to save the output of this node"
)

@classmethod
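For context, here is a self-contained sketch of a pydantic model shaped like the updated Node; fields not visible in the diff, such as the type of id, are assumptions. Nodes that omit the key keep the False default, so only configs that opt in, like the examples above, are affected.

```python
from pydantic import BaseModel, Field

class Node(BaseModel):
    id: str  # assumed type; the diff only shows the fields below
    dependencies: list = Field(
        default_factory=list, description="list of dependent node ids"
    )
    execution_params: dict = Field(
        default_factory=dict,
        description="execution parameters like replicas, batch_size",
    )
    save_output: bool = Field(
        default=False, description="whether to save the output of this node"
    )

quiet = Node(id="chunk")                       # flag omitted: defaults to False
saved = Node(id="generate", save_output=True)  # opted in, as in the example configs
print(quiet.save_output, saved.save_output)    # False True
```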
26 changes: 7 additions & 19 deletions graphgen/engine.py
@@ -1,21 +1,22 @@
import os
import inspect
import logging
import os
from collections import defaultdict, deque
from functools import wraps
from typing import Any, Callable, Dict, List, Set
from dotenv import load_dotenv

import ray
import ray.data
from dotenv import load_dotenv
from ray.data import DataContext
Comment on lines 8 to 11 (Contributor, medium):

According to PEP 8, imports within a group should be sorted alphabetically. This improves readability and makes it easier to find imports.

Suggested change (original):
import ray
import ray.data
from dotenv import load_dotenv
from ray.data import DataContext

Suggested change (proposed):
from dotenv import load_dotenv
import ray
import ray.data
from ray.data import DataContext
from graphgen.bases import Config, Node
from graphgen.utils import logger
from graphgen.common import init_llm, init_storage
from graphgen.utils import logger

load_dotenv()


class Engine:
def __init__(
self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
@@ -42,7 +43,7 @@ def __init__(
existing_env_vars = ray_init_kwargs["runtime_env"].get("env_vars", {})
ray_init_kwargs["runtime_env"]["env_vars"] = {
**all_env_vars,
**existing_env_vars
**existing_env_vars,
}

if not ray.is_initialized():
@@ -265,24 +266,11 @@ def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
f"Unsupported node type {node.type} for node {node.id}"
)

@staticmethod
def _find_leaf_nodes(nodes: List[Node]) -> Set[str]:
all_ids = {n.id for n in nodes}
deps_set = set()
for n in nodes:
deps_set.update(n.dependencies)
return all_ids - deps_set

def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
sorted_nodes = self._topo_sort(self.config.nodes)

for node in sorted_nodes:
self._execute_node(node, initial_ds)

leaf_nodes = self._find_leaf_nodes(sorted_nodes)

@ray.remote
def _fetch_result(ds: ray.data.Dataset) -> List[Any]:
return ds.take_all()

return {node_id: self.datasets[node_id] for node_id in leaf_nodes}
output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)]
return {node.id: self.datasets[node.id] for node in output_nodes}
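This hunk changes which datasets execute() returns: previously it scanned for leaf nodes (nodes no other node depends on) and returned their datasets; now it returns the datasets of nodes explicitly flagged with save_output. The following standalone sketch contrasts the two selection rules with hypothetical node ids, independent of the Engine class and Ray datasets.

```python
from dataclasses import dataclass, field
from typing import List, Set

@dataclass
class Node:
    id: str
    dependencies: List[str] = field(default_factory=list)
    save_output: bool = False

def find_leaf_nodes(nodes: List[Node]) -> Set[str]:
    # Old rule: every node that nothing else depends on.
    all_ids = {n.id for n in nodes}
    deps = {d for n in nodes for d in n.dependencies}
    return all_ids - deps

def find_output_nodes(nodes: List[Node]) -> Set[str]:
    # New rule: only nodes that opted in via save_output.
    return {n.id for n in nodes if getattr(n, "save_output", False)}

nodes = [
    Node("read"),
    Node("generate", dependencies=["read"], save_output=True),
    Node("evaluate", dependencies=["generate"]),
]
print(find_leaf_nodes(nodes))    # {'evaluate'}
print(find_output_nodes(nodes))  # {'generate'}
```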