1111import random
1212import time
1313import uuid
14- from typing import List
14+ from typing import List , Dict
1515
1616import gradio as gr
1717import requests
@@ -119,6 +119,8 @@ def __init__(self, model_name, is_vision=False):
119119 self .model_name = model_name
120120 self .oai_thread_id = None
121121 self .is_vision = is_vision
122+ self .ans_models = []
123+ self .router_outputs = []
122124
123125 # NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes.
124126 self .has_csam_image = False
@@ -128,6 +130,15 @@ def __init__(self, model_name, is_vision=False):
128130 self .regen_support = False
129131 self .init_system_prompt (self .conv , is_vision )
130132
133+ def update_ans_models (self , ans : str ) -> None :
134+
135+ self .ans_models .append (ans )
136+
137+ def update_router_outputs (self , outputs : Dict [str , float ]) -> None :
138+
139+ self .router_outputs .append (outputs )
140+
141+
131142 def init_system_prompt (self , conv , is_vision ):
132143 system_prompt = conv .get_system_message (is_vision )
133144 if len (system_prompt ) == 0 :
@@ -154,6 +165,20 @@ def dict(self):
154165 }
155166 )
156167
168+ if self .ans_models :
169+ base .update (
170+ {
171+ "ans_models" : self .ans_models ,
172+ }
173+ )
174+
175+ if self .router_outputs :
176+ base .update (
177+ {
178+ "router_outputs" : self .router_outputs ,
179+ }
180+ )
181+
157182 if self .is_vision :
158183 base .update ({"has_csam_image" : self .has_csam_image })
159184 return base
@@ -420,7 +445,7 @@ def is_limit_reached(model_name, ip):
420445
421446
422447def bot_response (
423- state ,
448+ state : State ,
424449 temperature ,
425450 top_p ,
426451 max_new_tokens ,
@@ -532,6 +557,23 @@ def bot_response(
532557 try :
533558 data = {"text" : "" }
534559 for i , data in enumerate (stream_iter ):
560+
561+ # Change for P2L:
562+ if i == 0 :
563+
564+ if "ans_model" in data :
565+
566+ ans_model = data .get ("ans_model" )
567+
568+ state .update_ans_models (ans_model )
569+
570+ if "router_outputs" in data :
571+
572+ router_outputs = data .get ("router_outputs" )
573+
574+ state .update_router_outputs (router_outputs )
575+
576+
535577 if data ["error_code" ] == 0 :
536578 output = data ["text" ].strip ()
537579 conv .update_last_message (output + "▌" )
0 commit comments