Skip to content

Commit 7925e4e

Browse files
committed
Judge layer for tool calling models
1 parent 8f74e5b commit 7925e4e

File tree

2 files changed

+53
-5
lines changed

2 files changed

+53
-5
lines changed

interpreter/core/llm/run_function_calling_llm.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,30 @@ def run_function_calling_llm(llm, request_params):
5757
if "content" in delta and delta["content"]:
5858
if function_call_detected:
5959
# More content after a code block? This is a code review by a judge layer.
60+
6061
# print("Code safety review:", delta["content"])
61-
accumulated_review += delta["content"]
6262

6363
if review_category == None:
64-
if "<UNSAFE>" in accumulated_review:
64+
accumulated_review += delta["content"]
65+
66+
if "<unsafe>" in accumulated_review:
6567
review_category = "unsafe"
66-
if "<WARNING>" in accumulated_review:
68+
if "<warning>" in accumulated_review:
6769
review_category = "warning"
68-
if "<SAFE>" in accumulated_review:
70+
if "<safe>" in accumulated_review:
6971
review_category = "safe"
7072

7173
if review_category != None:
74+
for tag in [
75+
"<safe>",
76+
"</safe>",
77+
"<warning>",
78+
"</warning>",
79+
"<unsafe>",
80+
"</unsafe>",
81+
]:
82+
delta["content"] = delta["content"].replace(tag, "")
83+
7284
yield {
7385
"type": "review",
7486
"format": review_category,

interpreter/core/llm/run_tool_calling_llm.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def run_tool_calling_llm(llm, request_params):
6363
accumulated_deltas = {}
6464
language = None
6565
code = ""
66+
function_call_detected = False
6667

6768
for chunk in llm.completions(**request_params):
6869
if "choices" not in chunk or len(chunk["choices"]) == 0:
@@ -73,6 +74,8 @@ def run_tool_calling_llm(llm, request_params):
7374

7475
# Convert tool call into function call, which we have great parsing logic for below
7576
if "tool_calls" in delta and delta["tool_calls"]:
77+
function_call_detected = True
78+
7679
# import pdb; pdb.set_trace()
7780
if len(delta["tool_calls"]) > 0 and delta["tool_calls"][0].function:
7881
delta = {
@@ -87,7 +90,40 @@ def run_tool_calling_llm(llm, request_params):
8790
accumulated_deltas = merge_deltas(accumulated_deltas, delta)
8891

8992
if "content" in delta and delta["content"]:
90-
yield {"type": "message", "content": delta["content"]}
93+
if function_call_detected:
94+
# More content after a code block? This is a code review by a judge layer.
95+
96+
# print("Code safety review:", delta["content"])
97+
98+
if review_category == None:
99+
accumulated_review += delta["content"]
100+
101+
if "<unsafe>" in accumulated_review:
102+
review_category = "unsafe"
103+
if "<warning>" in accumulated_review:
104+
review_category = "warning"
105+
if "<safe>" in accumulated_review:
106+
review_category = "safe"
107+
108+
if review_category != None:
109+
for tag in [
110+
"<safe>",
111+
"</safe>",
112+
"<warning>",
113+
"</warning>",
114+
"<unsafe>",
115+
"</unsafe>",
116+
]:
117+
delta["content"] = delta["content"].replace(tag, "")
118+
119+
yield {
120+
"type": "review",
121+
"format": review_category,
122+
"content": delta["content"],
123+
}
124+
125+
else:
126+
yield {"type": "message", "content": delta["content"]}
91127

92128
if (
93129
accumulated_deltas.get("function_call")

0 commit comments

Comments
 (0)