diff --git a/literalai/client.py b/literalai/client.py index c3743cd..8401e96 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, overload, TypeVar, Union from literalai.api import AsyncLiteralAPI, LiteralAPI from literalai.callback.langchain_callback import get_langchain_callback @@ -26,6 +26,8 @@ from literalai.requirements import check_all_requirements +T = TypeVar("T", bound=Callable) + class BaseLiteralClient: """ @@ -144,9 +146,15 @@ def langchain_callback( **kwargs, ) + @overload + def thread(self, original_function: Callable[..., T], **kwargs) -> Callable[..., T]: ... + + @overload + def thread(self, **kwargs) -> ThreadContextManager: ... + def thread( self, - original_function=None, + original_function: Optional[Callable] = None, *, thread_id: Optional[str] = None, name: Optional[str] = None,