@@ -26,6 +26,19 @@ def __init__(self) -> None:
2626 description = "Check tool call format including think, answer and tool_call tags with JSON validation." ,
2727 )
2828
29+ # patterns for identifying tags
30+ self ._think_pattern = re .compile (r"<think>(.*?)</think>" , re .DOTALL )
31+ self ._answer_pattern = re .compile (r"<answer>(.*?)</answer>" , re .DOTALL )
32+ self ._tool_call_pattern = re .compile (r"<tool_call>(.*?)</tool_call>" , re .DOTALL )
33+
34+ self ._think_answer_pattern = re .compile (r"^\s*<think>.*?</think>\s*<answer>.*?</answer>\s*$" , re .DOTALL )
35+ self ._think_tool_call_pattern = re .compile (
36+ r"^\s*<think>.*?</think>\s*(?:<tool_call>.*?</tool_call>\s*)+$" , re .DOTALL
37+ )
38+
39+ self ._consecutive_start_tool_call_tag_pattern = re .compile (r"<tool_call>\s*<tool_call>" )
40+ self ._consecutive_end_tool_call_tag_pattern = re .compile (r"</tool_call>\s*</tool_call>" )
41+
2942 # pylint: disable=too-many-statements
3043 async def aevaluate (self , response : str , ** kwargs : Any ) -> GraderScore :
3144 """
@@ -69,13 +82,9 @@ async def aevaluate(self, response: str, **kwargs: Any) -> GraderScore:
6982 """
7083
7184 # Extract tag contents
72- think_pattern = r"<think>(.*?)</think>"
73- answer_pattern = r"<answer>(.*?)</answer>"
74- tool_call_pattern = r"<tool_call>(.*?)</tool_call>"
75-
76- think_matches = re .search (think_pattern , response , re .DOTALL )
77- answer_matches = re .search (answer_pattern , response , re .DOTALL )
78- tool_call_matches = re .findall (tool_call_pattern , response , re .DOTALL )
85+ think_matches = self ._think_pattern .search (response )
86+ answer_matches = self ._answer_pattern .search (response )
87+ tool_call_matches = self ._tool_call_pattern .findall (response )
7988
8089 has_think_tag = think_matches is not None
8190 has_answer_tag = answer_matches is not None
@@ -89,9 +98,8 @@ async def aevaluate(self, response: str, **kwargs: Any) -> GraderScore:
8998 # Case 1: <think></think> + <answer></answer>
9099 if has_answer_tag and not has_tool_call_tag :
91100 # Check overall format
92- format_pattern = r"^\s*<think>.*?</think>\s*<answer>.*?</answer>\s*$"
93101 valid_format = bool (
94- re . match (format_pattern , response , re . DOTALL ),
102+ self . _think_answer_pattern . match (response ),
95103 )
96104
97105 # Check tag occurrence count
@@ -115,9 +123,8 @@ async def aevaluate(self, response: str, **kwargs: Any) -> GraderScore:
115123 # Case 2: <think></think> + <tool_call></tool_call>
116124 elif has_tool_call_tag and not has_answer_tag :
117125 # Check overall format
118- format_pattern = r"^\s*<think>.*?</think>\s*(?:<tool_call>.*?</tool_call>\s*)+$"
119126 valid_format = bool (
120- re . match (format_pattern , response , re . DOTALL ),
127+ self . _think_tool_call_pattern . match (response ),
121128 )
122129
123130 # Check <think> tag occurrence count
@@ -133,11 +140,9 @@ async def aevaluate(self, response: str, **kwargs: Any) -> GraderScore:
133140
134141 # Check for consecutive duplicate tags
135142 if valid_format :
136- if re .search (
137- r"</tool_call>\s*</tool_call>" ,
143+ if self ._consecutive_end_tool_call_tag_pattern .search (
138144 response ,
139- ) or re .search (
140- r"<tool_call>\s*<tool_call>" ,
145+ ) or self ._consecutive_start_tool_call_tag_pattern .search (
141146 response ,
142147 ):
143148 valid_format = False
0 commit comments