Skip to content

Commit 3185749

Browse files
authored
[Enhancement] Refactor graph configuration management and custom config support (#53)
* fix: apply overrides to the graph config * fix: apply overrides for config keys * fix: assign to executor config * fix: avoid graph config reload * fix: add custom config * fix: resolve mypy issues * fix: resolve formatting * docs: add test file and documentation * refactor: fix lint and formatting * docs: update Sygra lib docs structure
1 parent 1f06c8d commit 3185749

File tree

10 files changed

+2421
-74
lines changed

10 files changed

+2421
-74
lines changed
File renamed without changes.

docs/library/sygra_library_examples.md

Lines changed: 1783 additions & 0 deletions
Large diffs are not rendered by default.

mkdocs.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ nav:
5757
- Image to QnA: tutorials/image_to_qna_tutorial.md
5858
- Structured Output: tutorials/structured_output_tutorial.md
5959
- Structured Output with Multi-LLM: tutorials/structured_output_with_multi_llm_tutorial.md
60-
- SyGra Library: sygra_library.md
60+
- SyGra Library:
61+
- API Reference: library/sygra_library.md
62+
- Examples: library/sygra_library_examples.md
6163
- Contribute: development.md
6264
plugins:
6365
- search

sygra/configuration/loader.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import os
22
from pathlib import Path
3-
from typing import Any, Union, cast
3+
from typing import Any, Union
44

55
import yaml
66

77
try:
88
from sygra.core.dataset.dataset_config import DataSourceConfig, OutputConfig # noqa: F401
99
from sygra.core.graph.graph_config import GraphConfig # noqa: F401
1010
from sygra.utils import utils
11+
from sygra.workflow import AutoNestedDict
1112

1213
UTILS_AVAILABLE = True
1314
except ImportError:
@@ -34,7 +35,10 @@ def load(self, config_path: Union[str, Path, dict[str, Any]]) -> dict[str, Any]:
3435
raise FileNotFoundError(f"Configuration file not found: {config_path}")
3536

3637
with open(config_path, "r") as f:
37-
config = cast(dict[str, Any], yaml.safe_load(f))
38+
loaded_config = yaml.safe_load(f)
39+
if not isinstance(loaded_config, dict):
40+
raise TypeError(f"Expected dict from YAML, got {type(loaded_config)}")
41+
config: dict[str, Any] = loaded_config
3842

3943
return config
4044

@@ -46,7 +50,7 @@ def load_and_create(self, config_path: Union[str, Path, dict[str, Any]]):
4650
from ..workflow import Workflow
4751

4852
workflow = Workflow()
49-
workflow._config = config
53+
workflow._config = AutoNestedDict.convert_dict(config)
5054

5155
workflow._supports_subgraphs = True
5256
workflow._supports_multimodal = True

sygra/core/base_task_executor.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,14 @@ def __init__(self, args: Any, graph_config_dict: Optional[dict] = None):
5050

5151
self.dataset = self.init_dataset()
5252
output_transform_args = {"oasst": args.oasst, "quality": args.quality}
53+
54+
graph_props = self.config.get("graph_config", {}).get("graph_properties", {})
55+
5356
self.graph_config = GraphConfig(
54-
utils.get_file_in_task_dir(self.task_name, "graph_config.yaml"),
57+
self.config,
5558
self.dataset,
5659
output_transform_args,
60+
graph_properties=graph_props,
5761
)
5862
self.output_generator: Optional[BaseOutputGenerator] = self._init_output_generator()
5963

@@ -662,6 +666,6 @@ class DefaultTaskExecutor(BaseTaskExecutor):
662666
we fall back to this class by default.
663667
"""
664668

665-
def __init__(self, args):
666-
super().__init__(args)
669+
def __init__(self, args, graph_config_dict=None):
670+
super().__init__(args, graph_config_dict)
667671
logger.info("Using DefaultTaskExecutor for task: %s", self.task_name)

sygra/core/graph/graph_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
output_transform_args: dict,
2424
parent_node: str = "",
2525
override_config=None, # New parameter for overrides
26+
graph_properties: Optional[dict] = None,
2627
) -> None:
2728
"""
2829
Initialize a GraphConfig.
@@ -42,6 +43,7 @@ def __init__(
4243
self.state_variables: set = set()
4344
self.graph_state_config: dict[str, Any] = {}
4445
self.pattern_to_extract_variables = r"(?<!\{)\{([^{}]+)\}(?!\})"
46+
self.graph_properties = graph_properties or {}
4547

4648
if isinstance(config, str):
4749
config = utils.load_yaml_file(filepath=config)

sygra/core/graph/nodes/llm_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(self, node_name: str, config: dict):
5959
self._initialize_model()
6060

6161
self.task_name = utils.current_task
62-
self.graph_properties = utils.get_graph_properties(self.task_name)
62+
self.graph_properties = getattr(utils, "_current_graph_properties", {})
6363

6464
def _initialize_model(self):
6565
"""

sygra/utils/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sygra.utils import constants
2525

2626

27-
def load_model_config() -> Any:
27+
def load_model_config(config_path: Optional[str] = None) -> Any:
2828
"""
2929
Load model configurations from both models.yaml and environment variables.
3030
@@ -39,6 +39,10 @@ def load_model_config() -> Any:
3939
Example: "http://url1.com|http://url2.com|http://url3.com"
4040
- SYGRA_{MODEL_NAME}_TOKEN: Authentication token or API key for the model
4141
42+
Args:
43+
config_path: Optional path to custom config file.
44+
Custom configs override default models.yaml values.
45+
4246
Returns:
4347
Dict containing combined model configurations
4448
"""
@@ -50,6 +54,11 @@ def load_model_config() -> Any:
5054
# Load base configurations from models.yaml
5155
base_configs = load_yaml_file(constants.MODEL_CONFIG_YAML)
5256

57+
# Load and merge custom config if provided
58+
if config_path and os.path.exists(config_path):
59+
custom_configs = load_yaml_file(config_path)
60+
base_configs = {**base_configs, **custom_configs}
61+
5362
# For each model, look for corresponding environment variables
5463
for model_name, config in base_configs.items():
5564
# Convert model name to uppercase for environment variable lookup

0 commit comments

Comments
 (0)