Skip to content
Merged
File renamed without changes.
1,783 changes: 1,783 additions & 0 deletions docs/library/sygra_library_examples.md

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ nav:
- Image to QnA: tutorials/image_to_qna_tutorial.md
- Structured Output: tutorials/structured_output_tutorial.md
- Structured Output with Multi-LLM: tutorials/structured_output_with_multi_llm_tutorial.md
- SyGra Library: sygra_library.md
- SyGra Library:
- API Reference: library/sygra_library.md
- Examples: library/sygra_library_examples.md
- Contribute: development.md
plugins:
- search
Expand Down
10 changes: 7 additions & 3 deletions sygra/configuration/loader.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import os
from pathlib import Path
from typing import Any, Union, cast
from typing import Any, Union

import yaml

try:
from sygra.core.dataset.dataset_config import DataSourceConfig, OutputConfig # noqa: F401
from sygra.core.graph.graph_config import GraphConfig # noqa: F401
from sygra.utils import utils
from sygra.workflow import AutoNestedDict

UTILS_AVAILABLE = True
except ImportError:
Expand All @@ -34,7 +35,10 @@ def load(self, config_path: Union[str, Path, dict[str, Any]]) -> dict[str, Any]:
raise FileNotFoundError(f"Configuration file not found: {config_path}")

with open(config_path, "r") as f:
config = cast(dict[str, Any], yaml.safe_load(f))
loaded_config = yaml.safe_load(f)
if not isinstance(loaded_config, dict):
raise TypeError(f"Expected dict from YAML, got {type(loaded_config)}")
config: dict[str, Any] = loaded_config

return config

Expand All @@ -46,7 +50,7 @@ def load_and_create(self, config_path: Union[str, Path, dict[str, Any]]):
from ..workflow import Workflow

workflow = Workflow()
workflow._config = config
workflow._config = AutoNestedDict.convert_dict(config)

workflow._supports_subgraphs = True
workflow._supports_multimodal = True
Expand Down
10 changes: 7 additions & 3 deletions sygra/core/base_task_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,14 @@ def __init__(self, args: Any, graph_config_dict: Optional[dict] = None):

self.dataset = self.init_dataset()
output_transform_args = {"oasst": args.oasst, "quality": args.quality}

graph_props = self.config.get("graph_config", {}).get("graph_properties", {})

self.graph_config = GraphConfig(
utils.get_file_in_task_dir(self.task_name, "graph_config.yaml"),
self.config,
self.dataset,
output_transform_args,
graph_properties=graph_props,
)
self.output_generator: Optional[BaseOutputGenerator] = self._init_output_generator()

Expand Down Expand Up @@ -662,6 +666,6 @@ class DefaultTaskExecutor(BaseTaskExecutor):
we fall back to this class by default.
"""

def __init__(self, args):
super().__init__(args)
def __init__(self, args, graph_config_dict=None):
super().__init__(args, graph_config_dict)
logger.info("Using DefaultTaskExecutor for task: %s", self.task_name)
2 changes: 2 additions & 0 deletions sygra/core/graph/graph_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
output_transform_args: dict,
parent_node: str = "",
override_config=None, # New parameter for overrides
graph_properties: Optional[dict] = None,
) -> None:
"""
Initialize a GraphConfig.
Expand All @@ -42,6 +43,7 @@ def __init__(
self.state_variables: set = set()
self.graph_state_config: dict[str, Any] = {}
self.pattern_to_extract_variables = r"(?<!\{)\{([^{}]+)\}(?!\})"
self.graph_properties = graph_properties or {}

if isinstance(config, str):
config = utils.load_yaml_file(filepath=config)
Expand Down
2 changes: 1 addition & 1 deletion sygra/core/graph/nodes/llm_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(self, node_name: str, config: dict):
self._initialize_model()

self.task_name = utils.current_task
self.graph_properties = utils.get_graph_properties(self.task_name)
self.graph_properties = getattr(utils, "_current_graph_properties", {})

def _initialize_model(self):
"""
Expand Down
11 changes: 10 additions & 1 deletion sygra/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from sygra.utils import constants


def load_model_config() -> Any:
def load_model_config(config_path: Optional[str] = None) -> Any:
"""
Load model configurations from both models.yaml and environment variables.

Expand All @@ -39,6 +39,10 @@ def load_model_config() -> Any:
Example: "http://url1.com|http://url2.com|http://url3.com"
- SYGRA_{MODEL_NAME}_TOKEN: Authentication token or API key for the model

Args:
config_path: Optional path to custom config file.
Custom configs override default models.yaml values.

Returns:
Dict containing combined model configurations
"""
Expand All @@ -50,6 +54,11 @@ def load_model_config() -> Any:
# Load base configurations from models.yaml
base_configs = load_yaml_file(constants.MODEL_CONFIG_YAML)

# Load and merge custom config if provided
if config_path and os.path.exists(config_path):
custom_configs = load_yaml_file(config_path)
base_configs = {**base_configs, **custom_configs}

# For each model, look for corresponding environment variables
for model_name, config in base_configs.items():
# Convert model name to uppercase for environment variable lookup
Expand Down
Loading