from threading import Thread
import gc
import torch
from transformers import TextIteratorStreamer

def generate_stream_gemma3(
    model,
    tokenizer,
    params,
    device,
    context_len,
    stream_interval=2,
    judge_sent_end=False,
):
    """Custom generate stream function for Gemma-3 models."""
    # Get parameters from the request
    prompt = params.get("prompt", "")
    messages = params.get("messages", None)
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # -1 means disabled
    max_new_tokens = int(params.get("max_new_tokens", 256))
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop", None)
    stop_token_ids = params.get("stop_token_ids", None) or []
    model_name = params.get("model", None)

    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

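    # Heuristic: checkpoints whose name contains "pt" or "base" are treated as
    # plain pretrained models without a chat template; everything else is assumed
    # to be instruction-tuned and goes through apply_chat_template below.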
    is_base_model = bool(model_name) and (
        "pt" in model_name.lower() or "base" in model_name.lower()
    )

    if not is_base_model:
        # Format input based on whether we have messages or a plain prompt
        if messages:
            inputs = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
            ).to(model.device)
        else:
            messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
            inputs = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
            ).to(model.device)
    else:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

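    # Both tokenization paths above return a BatchEncoding; the prompt length in
    # tokens is taken from input_ids and reported as prompt_tokens in the usage
    # blocks below.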
    input_ids = inputs["input_ids"]
    input_echo_len = input_ids.shape[1]

    # Configure generation parameters
    generate_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature > 0.0,
        "temperature": temperature if temperature > 0.0 else 1.0,
    }

    if top_p < 1.0:
        generate_kwargs["top_p"] = top_p
    if top_k > 0:
        generate_kwargs["top_k"] = top_k
    if repetition_penalty > 1.0:
        generate_kwargs["repetition_penalty"] = repetition_penalty

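    # TextIteratorStreamer decodes tokens on the generation thread and exposes
    # them through a blocking iterator; skip_prompt=not echo controls whether the
    # prompt text is echoed back in the stream.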
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=not echo, skip_special_tokens=True)
    generate_kwargs["streamer"] = streamer

    # Start generation in a separate thread
    thread = Thread(target=lambda: model.generate(input_ids=input_ids, **generate_kwargs))
    thread.start()

    # Track generation progress
    generated_tokens = 0
    output_text = ""
    should_stop = False

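    # Note: the streamer yields decoded text chunks rather than individual tokens,
    # so generated_tokens below is an approximation of the completion token count.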
    # Stream tokens
    for new_text in streamer:
        output_text += new_text
        generated_tokens += 1

        # Check for stop strings
        should_stop = False
        if stop_str:
            if isinstance(stop_str, str):
                if stop_str in output_text:
                    output_text = output_text[: output_text.find(stop_str)]
                    should_stop = True
            elif isinstance(stop_str, list):
                for stop in stop_str:
                    if stop in output_text:
                        output_text = output_text[: output_text.find(stop)]
                        should_stop = True
                        break

        # Stream at intervals or when stopping
        if generated_tokens % stream_interval == 0 or should_stop:
            yield {
                "text": output_text,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": generated_tokens,
                    "total_tokens": input_echo_len + generated_tokens,
                },
                "finish_reason": "stop" if should_stop else None,
            }

        if should_stop:
            break

    # Wait for the generation thread to finish before emitting the final chunk
    if thread.is_alive():
        # Generous timeout; if generation has not finished by now, something is wrong
        thread.join(timeout=3600)

    # Final output with the overall finish reason
    yield {
        "text": output_text,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": generated_tokens,
            "total_tokens": input_echo_len + generated_tokens,
        },
        "finish_reason": "stop" if should_stop else "length",
    }

    # Clean up
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
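
# Example usage (a minimal sketch; the checkpoint name, device, and context length
# below are assumptions, not part of this module):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model_id = "google/gemma-3-1b-it"  # hypothetical choice of checkpoint
#   tokenizer = AutoTokenizer.from_pretrained(model_id)
#   model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
#
#   params = {
#       "prompt": "Explain streaming generation in one sentence.",
#       "temperature": 0.7,
#       "max_new_tokens": 64,
#       "model": model_id,
#       "echo": False,
#   }
#   for chunk in generate_stream_gemma3(model, tokenizer, params, "cuda", 8192):
#       print(chunk["text"])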