@@ -324,12 +324,86 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
324324 tracking .TRACKER .instance (input_tokens , output_tokens , cost )
325325
326326 if n_samples == 1 :
327- res = AIMessage (completion .choices [0 ].message .content )
327+ think , action = self ._extract_thinking_content_from_response (completion )
328+ res_think = AIMessage (think or "" )
329+ res_action = AIMessage (action or "" )
328330 if self .log_probs :
329- res ["log_probs" ] = completion .choices [0 ].log_probs
330- return res
331+ res_think ["log_probs" ] = completion .choices [0 ].logprobs
332+ return res_think , res_action
331333 else :
332- return [AIMessage (c .message .content ) for c in completion .choices ]
334+ return [
335+ self ._build_think_action_pair (choice )
336+ for choice in completion .choices
337+ ]
338+
339+ def _extract_thinking_content_from_response (self , response , wrap_tag = "think" ) -> tuple [str , str ]:
340+ """Extract reasoning and action content from an API response.
341+
342+ Handles multiple formats:
343+ 1. OpenAI/DeepSeek: reasoning in 'reasoning_content' or 'reasoning' field
344+ 2. Apriel: reasoning before [BEGIN FINAL RESPONSE]...[END FINAL RESPONSE] tags
345+ 3. Standard: content as-is
346+
347+ Args:
348+ response: The API response object.
349+ wrap_tag: Tag name to wrap reasoning content (default: "think").
350+
351+ Returns:
352+ tuple: (reasoning_wrapped, action_wrapped)
353+ """
354+ message = response .choices [0 ].message
355+ msg_dict = message .to_dict () if hasattr (message , 'to_dict' ) else dict (message )
356+
357+ reasoning = msg_dict .get ("reasoning_content" ) or msg_dict .get ("reasoning" )
358+ content = msg_dict .get ("content" , "" ) or msg_dict .get ("text" , "" )
359+
360+ # Case 1: Explicit reasoning field from API
361+ if reasoning :
362+ reasoning_wrapped = f"<{ wrap_tag } >{ reasoning } </{ wrap_tag } >\n "
363+ if "[BEGIN FINAL RESPONSE]" in content and "[END FINAL RESPONSE]" in content :
364+ action = self ._extract_last_action_from_tags (content )
365+ action_wrapped = f"<action>\n { action } \n </action>"
366+ else :
367+ action_wrapped = content
368+ return reasoning_wrapped , action_wrapped
369+
370+ # Case 2: Apriel-style format in content
371+ if "[BEGIN FINAL RESPONSE]" in content :
372+ reasoning_text , action_text = self ._parse_apriel_format (content )
373+ reasoning_wrapped = f"<{ wrap_tag } >\n { reasoning_text } \n </{ wrap_tag } >" if reasoning_text else ""
374+ action_wrapped = f"<action>\n { action_text } \n </action>" if action_text else ""
375+ return reasoning_wrapped , action_wrapped
376+
377+ # Case 3: No special format
378+ return "" , content
379+
380+ def _extract_last_action_from_tags (self , content : str ) -> str :
381+ """Extract content from the LAST [BEGIN FINAL RESPONSE]...[END FINAL RESPONSE] block."""
382+ pattern = r'\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]'
383+ matches = re .findall (pattern , content , re .DOTALL )
384+ return matches [- 1 ].strip () if matches else ""
385+
def _parse_apriel_format(self, content: str) -> tuple[str, str]:
    """Split Apriel-format content into (reasoning, action).

    Reasoning is whatever precedes the last [BEGIN FINAL RESPONSE] marker,
    minus a leading "Here are my reasoning steps:" preamble when present;
    the action comes from the last tagged block.
    """
    marker_pos = content.rfind("[BEGIN FINAL RESPONSE]")
    if marker_pos == -1:
        # No marker at all: treat the whole content as the action.
        return "", content

    preamble = "Here are my reasoning steps:"
    reasoning = content[:marker_pos].strip()
    if reasoning.startswith(preamble):
        reasoning = reasoning[len(preamble):].strip()

    return reasoning, self._extract_last_action_from_tags(content)
398+
def _build_think_action_pair(self, choice) -> tuple[AIMessage, AIMessage]:
    """Wrap a single completion choice into (think, action) AIMessage objects."""

    # Minimal stand-in exposing only the .choices attribute that the
    # extraction helper reads from a real response object.
    class _SingleChoiceResponse:
        choices = [choice]

    think, action = self._extract_thinking_content_from_response(_SingleChoiceResponse())
    return AIMessage(think or ""), AIMessage(action or "")
333407
334408 def get_stats (self ):
335409 return {
@@ -484,6 +558,55 @@ def __init__(
484558 )
485559
486560
class AprielChatModel(ChatModel):
    """Chat model for Apriel models hosted on DGX Cloud.

    Endpoint and credentials fall back to the APRIEL_API_URL and
    APRIEL_API_KEY environment variables when not passed explicitly.
    """

    def __init__(
        self,
        model_name="Slam-15B",
        api_key=None,
        base_url=None,
        temperature=0.5,
        max_tokens=15000,
        max_retry=4,
        min_retry_wait_time=60,
    ):
        # Explicit arguments win; otherwise read the environment.
        # NOTE(review): an unset APRIEL_API_URL yields "" as base_url --
        # confirm the OpenAI client accepts/handles that downstream.
        base_url = base_url or os.getenv("APRIEL_API_URL", "")
        api_key = api_key or os.getenv("APRIEL_API_KEY")

        super().__init__(
            model_name=model_name,
            api_key=api_key,
            temperature=temperature,
            max_tokens=max_tokens,
            max_retry=max_retry,
            min_retry_wait_time=min_retry_wait_time,
            client_class=OpenAI,
            client_args={"base_url": base_url},
            pricing_func=None,
        )
592+
@dataclass
class AprielModelArgs(BaseModelArgs):
    """Serializable args for Apriel models."""

    # Endpoint and credentials; None defers to the APRIEL_API_URL /
    # APRIEL_API_KEY environment fallbacks in AprielChatModel.
    base_url: str = None
    api_key: str = None

    def make_model(self):
        """Instantiate the AprielChatModel described by these args."""
        return AprielChatModel(
            model_name=self.model_name,
            api_key=self.api_key,
            base_url=self.base_url,
            temperature=self.temperature,
            max_tokens=self.max_new_tokens,
        )
608+
609+
487610class AnthropicChatModel (AbstractChatModel ):
488611 def __init__ (
489612 self ,
0 commit comments