Skip to content

Commit fe0e4bb

Browse files
committed
Add sandbox code over fastchat
1 parent b2aabdb commit fe0e4bb

File tree

4 files changed

+381
-9
lines changed

4 files changed

+381
-9
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,15 @@ FastChat's core features include:
3838
- [Fine-tuning](#fine-tuning)
3939
- [Citation](#citation)
4040

41+
----
42+
43+
For Software Areana, please follow the following extra steps:
44+
1. Set your E2B API Key: `export E2B_API_KEY=<YOUR_API_KEY>`
45+
2. Custom Component Build: Follow https://www.gradio.app/guides/custom-components-in-five-minutes to set up environment. Go into `custom_components/sandboxcomponent` and run `gradio cc build`.
46+
3. Use `pip install custom_components/sandboxcomponent/dist/gradio_sandboxcomponent-xxx-py3-none-any.whl` to install the custom components.
47+
48+
----
49+
4150
## Install
4251

4352
### Method 1: With pip

fastchat/serve/gradio_block_arena_named.py

Lines changed: 88 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import time
88

99
import gradio as gr
10+
from gradio_sandboxcomponent import SandboxComponent
1011
import numpy as np
1112

1213
from fastchat.constants import (
@@ -30,6 +31,7 @@
3031
get_model_description_md,
3132
)
3233
from fastchat.serve.remote_logger import get_remote_logger
34+
from fastchat.serve.sandbox.code_runner import DEFAULT_SANDBOX_INSTRUCTION, create_chatbot_sandbox_state, on_click_run_code, update_sandbox_config
3335
from fastchat.utils import (
3436
build_logger,
3537
moderation_filter,
@@ -152,7 +154,10 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
152154

153155

154156
def add_text(
155-
state0, state1, model_selector0, model_selector1, text, request: gr.Request
157+
state0, state1,
158+
model_selector0, model_selector1,
159+
sandbox_state0, sandbox_state1,
160+
text, request: gr.Request
156161
):
157162
ip = get_ip(request)
158163
logger.info(f"add_text (named). ip: {ip}. len: {len(text)}")
@@ -204,6 +209,10 @@ def add_text(
204209
* 6
205210
)
206211

212+
# add snadbox instructions if enabled
213+
if sandbox_state0['enable_sandbox']:
214+
text = f"> {sandbox_state0['sandbox_instruction']}\n\n" + text
215+
207216
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
208217
for i in range(num_sides):
209218
states[i].conv.append_message(states[i].conv.roles[0], text)
@@ -227,6 +236,8 @@ def bot_response_multi(
227236
temperature,
228237
top_p,
229238
max_new_tokens,
239+
sandbox_state0,
240+
sandbox_state1,
230241
request: gr.Request,
231242
):
232243
logger.info(f"bot_response_multi (named). ip: {get_ip(request)}")
@@ -251,6 +262,7 @@ def bot_response_multi(
251262
top_p,
252263
max_new_tokens,
253264
request,
265+
sandbox_state=sandbox_state0,
254266
)
255267
)
256268

@@ -327,7 +339,7 @@ def build_side_by_side_ui_named(models):
327339

328340
states = [gr.State() for _ in range(num_sides)]
329341
model_selectors = [None] * num_sides
330-
chatbots = [None] * num_sides
342+
chatbots: list[gr.Chatbot | None] = [None] * num_sides
331343

332344
notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
333345

@@ -366,6 +378,38 @@ def build_side_by_side_ui_named(models):
366378
],
367379
)
368380

381+
# sandbox states and components
382+
sandbox_states: list[gr.State | None] = [None for _ in range(num_sides)]
383+
sandboxes_components: list[tuple[
384+
gr.Markdown, # sandbox_output
385+
SandboxComponent, # sandbox_ui
386+
gr.Code, # sandbox_code
387+
] | None] = [None for _ in range(num_sides)]
388+
389+
with gr.Group():
390+
with gr.Row():
391+
for chatbotIdx in range(num_sides):
392+
with gr.Column(scale=1):
393+
sandbox_state = gr.State(create_chatbot_sandbox_state())
394+
# Add containers for the sandbox output
395+
sandbox_title = gr.Markdown(value=f"### Model {chatbotIdx + 1} Sandbox", visible=True)
396+
with gr.Tab(label="Output"):
397+
sandbox_output = gr.Markdown(value="", visible=False)
398+
sandbox_ui = SandboxComponent(
399+
value=("", ""),
400+
show_label=True,
401+
visible=False,
402+
)
403+
with gr.Tab(label="Code"):
404+
sandbox_code = gr.Code(value="", interactive=False, visible=False)
405+
406+
sandbox_states[chatbotIdx] = sandbox_state
407+
sandboxes_components[chatbotIdx] = (
408+
sandbox_output,
409+
sandbox_ui,
410+
sandbox_code,
411+
)
412+
369413
with gr.Row():
370414
leftvote_btn = gr.Button(
371415
value="👈 A is better", visible=False, interactive=False
@@ -378,6 +422,30 @@ def build_side_by_side_ui_named(models):
378422
value="👎 Both are bad", visible=False, interactive=False
379423
)
380424

425+
426+
# chatbox sandbox global config
427+
with gr.Group():
428+
with gr.Row():
429+
enable_sandbox_checkbox = gr.Checkbox(value=False, label="Enable Sandbox", interactive=True)
430+
sandbox_env_choice = gr.Dropdown(choices=["React", "Auto"], label="Sandbox Environment", interactive=True)
431+
with gr.Group():
432+
with gr.Accordion("Sandbox Instructions", open=False):
433+
sandbox_instruction_textarea = gr.TextArea(
434+
value=DEFAULT_SANDBOX_INSTRUCTION
435+
)
436+
437+
# update sandbox global config
438+
enable_sandbox_checkbox.change(
439+
fn=update_sandbox_config,
440+
inputs=[
441+
enable_sandbox_checkbox,
442+
sandbox_env_choice,
443+
sandbox_instruction_textarea,
444+
*sandbox_states
445+
],
446+
outputs=[*sandbox_states]
447+
)
448+
381449
with gr.Row():
382450
textbox = gr.Textbox(
383451
show_label=False,
@@ -452,7 +520,7 @@ def build_side_by_side_ui_named(models):
452520
regenerate, states, states + chatbots + [textbox] + btn_list
453521
).then(
454522
bot_response_multi,
455-
states + [temperature, top_p, max_output_tokens],
523+
states + [temperature, top_p, max_output_tokens] + sandbox_states,
456524
states + chatbots + btn_list,
457525
).then(
458526
flash_buttons, [], btn_list
@@ -488,25 +556,38 @@ def build_side_by_side_ui_named(models):
488556

489557
textbox.submit(
490558
add_text,
491-
states + model_selectors + [textbox],
559+
states + model_selectors + sandbox_states + [textbox],
492560
states + chatbots + [textbox] + btn_list,
493561
).then(
494562
bot_response_multi,
495-
states + [temperature, top_p, max_output_tokens],
563+
states + [temperature, top_p, max_output_tokens] + sandbox_states,
496564
states + chatbots + btn_list,
497565
).then(
498566
flash_buttons, [], btn_list
499567
)
500568
send_btn.click(
501569
add_text,
502-
states + model_selectors + [textbox],
570+
states + model_selectors + sandbox_states + [textbox],
503571
states + chatbots + [textbox] + btn_list,
504572
).then(
505573
bot_response_multi,
506-
states + [temperature, top_p, max_output_tokens],
574+
states + [temperature, top_p, max_output_tokens] + sandbox_states,
507575
states + chatbots + btn_list,
508576
).then(
509577
flash_buttons, [], btn_list
510578
)
511579

580+
for chatbotIdx in range(num_sides):
581+
chatbot = chatbots[chatbotIdx]
582+
state = states[chatbotIdx]
583+
sandbox_state = sandbox_states[chatbotIdx]
584+
sandbox_components = sandboxes_components[chatbotIdx]
585+
586+
# trigger sandbox run
587+
chatbot.select(
588+
fn=on_click_run_code,
589+
inputs=[state, sandbox_state, *sandbox_components],
590+
outputs=[*sandbox_components],
591+
)
592+
512593
return states + model_selectors

fastchat/serve/gradio_web_server.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import time
1313
import uuid
1414
from typing import List
15+
from gradio_sandboxcomponent import SandboxComponent
1516

1617
import gradio as gr
1718
import requests
@@ -29,13 +30,15 @@
2930
SESSION_EXPIRATION_TIME,
3031
SURVEY_LINK,
3132
)
33+
from fastchat.conversation import Conversation
3234
from fastchat.model.model_adapter import (
3335
get_conversation_template,
3436
)
3537
from fastchat.model.model_registry import get_model_info, model_info
3638
from fastchat.serve.api_provider import get_api_provider_stream_iter
3739
from fastchat.serve.gradio_global_state import Context
3840
from fastchat.serve.remote_logger import get_remote_logger
41+
from fastchat.serve.sandbox.code_runner import RUN_CODE_BUTTON_HTML, ChatbotSandboxState
3942
from fastchat.utils import (
4043
build_logger,
4144
get_window_url_params_js,
@@ -427,7 +430,11 @@ def bot_response(
427430
request: gr.Request,
428431
apply_rate_limit=True,
429432
use_recommended_config=False,
433+
sandbox_state: ChatbotSandboxState | None = None,
430434
):
435+
'''
436+
The main function for generating responses from the model.
437+
'''
431438
ip = get_ip(request)
432439
logger.info(f"bot_response. ip: {ip}")
433440
start_tstamp = time.time()
@@ -450,7 +457,9 @@ def bot_response(
450457
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
451458
return
452459

453-
conv, model_name = state.conv, state.model_name
460+
conv: Conversation = state.conv
461+
model_name: str = state.model_name
462+
454463
model_api_dict = (
455464
api_endpoint_info[model_name] if model_name in api_endpoint_info else None
456465
)
@@ -550,6 +559,14 @@ def bot_response(
550559
return
551560
output = data["text"].strip()
552561
conv.update_last_message(output)
562+
563+
# Add a "Run in Sandbox" button to the last message if code is detected
564+
if sandbox_state is not None and sandbox_state["enable_sandbox"]:
565+
last_message = conv.messages[-1]
566+
if "```" in last_message[1]:
567+
if not last_message[1].endswith(RUN_CODE_BUTTON_HTML):
568+
last_message[1] += "\n\n" + RUN_CODE_BUTTON_HTML
569+
553570
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
554571
except requests.exceptions.RequestException as e:
555572
conv.update_last_message(
@@ -880,6 +897,17 @@ def build_single_model_ui(models, add_promotion_links=False):
880897
{"left": r"\[", "right": r"\]", "display": True},
881898
],
882899
)
900+
901+
# Add containers for the sandbox output and JavaScript
902+
# with gr.Column():
903+
# sandbox_output = gr.Markdown(value="", visible=False)
904+
# sandbox = SandboxComponent(
905+
# label="Sandbox",
906+
# value=("", ""),
907+
# show_label=True,
908+
# visible=False,
909+
# )
910+
883911
with gr.Row():
884912
textbox = gr.Textbox(
885913
show_label=False,
@@ -969,10 +997,15 @@ def build_single_model_ui(models, add_promotion_links=False):
969997
[state, chatbot] + btn_list,
970998
)
971999

1000+
# trigger sandbox run
1001+
# chatbot.select(on_click_run_code,
1002+
# inputs=[state, sandbox_output, sandbox],
1003+
# outputs=[sandbox_output, sandbox])
1004+
9721005
return [state, model_selector]
9731006

9741007

975-
def build_demo(models):
1008+
def build_demo(models) -> gr.Blocks:
9761009
with gr.Blocks(
9771010
title="Chatbot Arena (formerly LMSYS): Free AI Chat to Compare & Test Best AI Chatbots",
9781011
theme=gr.themes.Default(),

0 commit comments

Comments
 (0)