Skip to content

Commit a85f259

Browse files
authored
[Fix] Fix lambda node failures (#67)
Fix lambda node failures
1 parent d18cb3f commit a85f259

File tree

5 files changed

+10
-13
lines changed

5 files changed

+10
-13
lines changed

sygra/core/graph/backend_factory.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,12 @@ class BackendFactory(ABC):
1010
"""
1111

1212
@abstractmethod
13-
def create_lambda_runnable(self, function_to_execute, node_config):
13+
def create_lambda_runnable(self, exec_wrapper):
1414
"""
1515
Abstract method to create a Lambda runnable.
1616
1717
Args:
18-
function_to_execute: Python function to execute, if it is a class it should be callable(__call__)
19-
node_config:node config dictionary
18+
exec_wrapper: Async function to execute
2019
2120
Returns:
2221
Any: backend specific runnable object like Runnable for backend=Langgraph

sygra/core/graph/langgraph/langgraph_factory.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from functools import partial
21
from typing import Any
32

43
from langchain_core.messages import BaseMessage
@@ -18,18 +17,17 @@ class LangGraphFactory(BackendFactory):
1817
A factory class to convert Nodes into Runnable objects which LangGraph framework can execute.
1918
"""
2019

21-
def create_lambda_runnable(self, function_to_execute, node_config):
20+
def create_lambda_runnable(self, exec_wrapper):
2221
"""
2322
Abstract method to create a Lambda runnable.
2423
2524
Args:
26-
function_to_execute: Python function to execute, if it is a class it should be callable(__call__)
27-
node_config:node config dictionary
25+
exec_wrapper: Async function to execute
2826
2927
Returns:
3028
Any: backend specific runnable object like Runnable for backend=Langgraph
3129
"""
32-
return RunnableLambda(partial(function_to_execute, node_config))
30+
return RunnableLambda(lambda x: x, afunc=exec_wrapper)
3331

3432
def create_llm_runnable(self, exec_wrapper):
3533
"""

sygra/core/graph/nodes/lambda_node.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ async def _exec_wrapper(self, state: dict[str, Any]) -> dict[str, Any]:
3838
success = True
3939

4040
try:
41-
result: dict[str, Any] = self.func_to_execute(state)
41+
result: dict[str, Any] = self.func_to_execute(self.node_config, state)
4242
return result
4343
except Exception:
4444
success = False
@@ -53,7 +53,7 @@ def to_backend(self) -> Any:
5353
Returns:
5454
Any: platform specific runnable object like Runnable in LangGraph.
5555
"""
56-
return utils.backend_factory.create_lambda_runnable(self._exec_wrapper, self.node_config)
56+
return utils.backend_factory.create_lambda_runnable(self._exec_wrapper)
5757

5858
def validate_node(self):
5959
"""

sygra/core/models/lite_llm/vllm_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, model_config: dict[str, Any]) -> None:
2222
utils.validate_required_keys(["url", "auth_token"], model_config, "model")
2323
self.model_config = model_config
2424
self.auth_token = str(model_config.get("auth_token")).replace("Bearer ", "")
25-
self.model_serving_name = model_config.get("model_serving_name", self.name())
25+
self.model_name = model_config.get("model_serving_name", self.name())
2626

2727
def _validate_completions_api_model_support(self) -> None:
2828
logger.info(f"Model {self.name()} supports completion API.")

tests/core/models/lite_llm/test_litellm_vllm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ def test_init(self):
5656
self.assertEqual(model.generation_params, self.base_config["parameters"])
5757
self.assertEqual(model.name(), "vllm_model")
5858
self.assertEqual(model.auth_token, "test_token_123")
59-
self.assertEqual(model.model_serving_name, "vllm_model")
59+
self.assertEqual(model.model_name, "vllm_model")
6060

6161
def test_init_with_custom_serving_name(self):
6262
model = LiteLLMVLLM(self.serving_name_config)
63-
self.assertEqual(model.model_serving_name, "custom_serving_name")
63+
self.assertEqual(model.model_name, "custom_serving_name")
6464

6565
@patch("sygra.core.models.lite_llm.vllm_model.logger")
6666
def test_init_with_completions_api(self, mock_logger):

0 commit comments

Comments (0)