1111import random
1212import time
1313import uuid
14- from typing import List
14+ from typing import List , Dict
1515
1616import gradio as gr
1717import requests
@@ -119,6 +119,8 @@ def __init__(self, model_name, is_vision=False):
119119 self .model_name = model_name
120120 self .oai_thread_id = None
121121 self .is_vision = is_vision
122+ self .ans_models = []
123+ self .router_outputs = []
122124
123125 # NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes.
124126 self .has_csam_image = False
@@ -128,6 +130,12 @@ def __init__(self, model_name, is_vision=False):
128130 self .regen_support = False
129131 self .init_system_prompt (self .conv , is_vision )
130132
133+ def update_ans_models (self , ans : str ) -> None :
134+ self .ans_models .append (ans )
135+
136+ def update_router_outputs (self , outputs : Dict [str , float ]) -> None :
137+ self .router_outputs .append (outputs )
138+
131139 def init_system_prompt (self , conv , is_vision ):
132140 system_prompt = conv .get_system_message (is_vision )
133141 if len (system_prompt ) == 0 :
@@ -154,6 +162,20 @@ def dict(self):
154162 }
155163 )
156164
165+ if self .ans_models :
166+ base .update (
167+ {
168+ "ans_models" : self .ans_models ,
169+ }
170+ )
171+
172+ if self .router_outputs :
173+ base .update (
174+ {
175+ "router_outputs" : self .router_outputs ,
176+ }
177+ )
178+
157179 if self .is_vision :
158180 base .update ({"has_csam_image" : self .has_csam_image })
159181 return base
@@ -420,7 +442,7 @@ def is_limit_reached(model_name, ip):
420442
421443
422444def bot_response (
423- state ,
445+ state : State ,
424446 temperature ,
425447 top_p ,
426448 max_new_tokens ,
@@ -504,6 +526,8 @@ def bot_response(
504526 if not custom_system_prompt :
505527 conv .set_system_message ("" )
506528
529+ extra_body = None
530+
507531 if use_recommended_config :
508532 recommended_config = model_api_dict .get ("recommended_config" , None )
509533 if recommended_config is not None :
@@ -512,6 +536,7 @@ def bot_response(
512536 max_new_tokens = recommended_config .get (
513537 "max_new_tokens" , max_new_tokens
514538 )
539+ extra_body = recommended_config .get ("extra_body" , None )
515540
516541 stream_iter = get_api_provider_stream_iter (
517542 conv ,
@@ -521,6 +546,7 @@ def bot_response(
521546 top_p ,
522547 max_new_tokens ,
523548 state ,
549+ extra_body = extra_body ,
524550 )
525551
526552 html_code = ' <span class="cursor"></span> '
@@ -532,6 +558,18 @@ def bot_response(
532558 try :
533559 data = {"text" : "" }
534560 for i , data in enumerate (stream_iter ):
561+ # Change for P2L:
562+ if i == 0 :
563+ if "ans_model" in data :
564+ ans_model = data .get ("ans_model" )
565+
566+ state .update_ans_models (ans_model )
567+
568+ if "router_outputs" in data :
569+ router_outputs = data .get ("router_outputs" )
570+
571+ state .update_router_outputs (router_outputs )
572+
535573 if data ["error_code" ] == 0 :
536574 output = data ["text" ].strip ()
537575 conv .update_last_message (output + "▌" )
@@ -688,6 +726,22 @@ def bot_response(
688726.block {
689727 overflow-y: hidden !important;
690728}
729+
730+ .visualizer {
731+ overflow: hidden;
732+ height: 60vw;
733+ border: 1px solid lightgrey;
734+ border-radius: 10px;
735+ }
736+
737+ @media screen and (max-width: 769px) {
738+ .visualizer {
739+ height: 180vw;
740+ overflow-y: scroll;
741+ width: 100%;
742+ overflow-x: hidden;
743+ }
744+ }
691745"""
692746
693747
0 commit comments