1414from parea .helpers import timezone_aware_now
1515from parea .schemas .models import TraceLog , UpdateTraceScenario
1616from parea .utils .trace_utils import call_eval_funcs_then_log , fill_trace_data , trace_context , trace_data
17- from parea .wrapper .utils import skip_decorator_if_func_in_stack
17+ from parea .wrapper .utils import safe_format_template_to_prompt , skip_decorator_if_func_in_stack
1818
1919logger = logging .getLogger ()
2020
@@ -72,13 +72,18 @@ def _get_decorator(self, unwrapped_func: Callable, original_func: Callable):
7272 else :
7373 return self .sync_decorator (original_func )
7474
75- def _init_trace (self ) -> Tuple [str , datetime , contextvars .Token ]:
75+ def _init_trace (self , kwargs ) -> Tuple [str , datetime , contextvars .Token ]:
7676 start_time = timezone_aware_now ()
7777 trace_id = str (uuid4 ())
7878
7979 new_trace_context = trace_context .get () + [trace_id ]
8080 token = trace_context .set (new_trace_context )
8181
82+ if template_inputs := kwargs .pop ("template_inputs" , None ):
83+ for m in kwargs ["messages" ] or []:
84+ if isinstance (m , dict ) and "content" in m :
85+ m ["content" ] = safe_format_template_to_prompt (m ["content" ], ** template_inputs )
86+
8287 if TURN_OFF_PAREA_LOGGING :
8388 return trace_id , start_time , token
8489 try :
@@ -93,7 +98,7 @@ def _init_trace(self) -> Tuple[str, datetime, contextvars.Token]:
9398 metadata = None ,
9499 target = None ,
95100 tags = None ,
96- inputs = {} ,
101+ inputs = template_inputs ,
97102 experiment_uuid = os .getenv (PAREA_OS_ENV_EXPERIMENT_UUID , None ),
98103 )
99104
@@ -109,7 +114,7 @@ def _init_trace(self) -> Tuple[str, datetime, contextvars.Token]:
109114 def async_decorator (self , orig_func : Callable ) -> Callable :
110115 @functools .wraps (orig_func )
111116 async def wrapper (* args , ** kwargs ):
112- trace_id , start_time , context_token = self ._init_trace ()
117+ trace_id , start_time , context_token = self ._init_trace (kwargs )
113118 response = None
114119 exception = None
115120 error = None
@@ -141,7 +146,7 @@ async def wrapper(*args, **kwargs):
141146 def sync_decorator (self , orig_func : Callable ) -> Callable :
142147 @functools .wraps (orig_func )
143148 def wrapper (* args , ** kwargs ):
144- trace_id , start_time , context_token = self ._init_trace ()
149+ trace_id , start_time , context_token = self ._init_trace (kwargs )
145150 response = None
146151 error = None
147152 cache_hit = False
0 commit comments