 import torch
+from collections.abc import Generator

 from transformers import (
     AutoModelForCausalLM,
@@ -71,6 +72,24 @@ def generate(self, messages: MessageList, past_key_values: DynamicCache | None =
         else:
             return self._generate_with_cache(prompt, past_key_values)

+    def generate_stream(self, messages: MessageList, past_key_values: DynamicCache | None = None) -> Generator[str, None, None]:
+        """
+        Generate a streaming response from the model.
+        Args:
+            messages (MessageList): Chat messages for prompt construction.
+            past_key_values (DynamicCache | None): Optional KV cache for fast generation.
+        Yields:
+            str: Streaming model response chunks.
+        """
+        prompt = self.tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=self.config.add_generation_prompt
+        )
+        logger.info(f"HFLLM streaming prompt: {prompt}")
+        if past_key_values is None:
+            yield from self._generate_full_stream(prompt)
+        else:
+            yield from self._generate_with_cache_stream(prompt, past_key_values)
+
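For orientation, here is a minimal consumption sketch for the new streaming entry point. Only generate_stream itself comes from this diff; the instance name llm and the message contents are illustrative assumptions (the logging above suggests the wrapper class is called HFLLM).

    # Hypothetical usage; assumes `llm` is an already-constructed instance of the
    # wrapper class this diff patches, with tokenizer, model, and config set up.
    messages = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Explain KV caching in one sentence."},
    ]
    for chunk in llm.generate_stream(messages):
        print(chunk, end="", flush=True)  # emit chunks as they arrive
    print()
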
     def _generate_full(self, prompt: str) -> str:
         """
         Generate output from scratch using the full prompt.
@@ -104,6 +123,71 @@ def _generate_full(self, prompt: str) -> str:
             else response
         )

+    def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]:
+        """
+        Generate output from scratch using the full prompt with streaming.
+        Args:
+            prompt (str): The input prompt string.
+        Yields:
+            str: Streaming response chunks.
+        """
+        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
+
+        # Get generation parameters
+        max_new_tokens = getattr(self.config, "max_tokens", 128)
+        do_sample = getattr(self.config, "do_sample", True)
+        remove_think_prefix = getattr(self.config, "remove_think_prefix", False)
+
+        # Manual streaming generation
+        input_length = inputs.input_ids.shape[1]
+        generated_ids = inputs.input_ids.clone()
+        accumulated_text = ""
+
+        for _ in range(max_new_tokens):
+            # Forward pass
+            with torch.no_grad():
+                outputs = self.model(
+                    input_ids=generated_ids,
+                    use_cache=True,
+                    return_dict=True,
+                )
+
+            # Get next token logits
+            next_token_logits = outputs.logits[:, -1, :]
+
+            # Apply logits processors if sampling
+            if do_sample:
+                batch_size, _ = next_token_logits.size()
+                dummy_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=next_token_logits.device)
+                filtered_logits = self.logits_processors(dummy_ids, next_token_logits)
+                probs = torch.softmax(filtered_logits, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1)
+            else:
+                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+
+            # Check for EOS token
+            if self._should_stop(next_token):
+                break
+
+            # Append new token
+            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
+
+            # Decode and yield the new token
+            new_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
+            if new_token_text:  # Only yield non-empty tokens
+                accumulated_text += new_token_text
+
+                # Apply thinking tag removal if enabled
+                if remove_think_prefix:
+                    processed_text = remove_thinking_tags(accumulated_text)
+                    # Only yield the difference (new content)
+                    if len(processed_text) > len(accumulated_text) - len(new_token_text):
+                        yield processed_text[len(accumulated_text) - len(new_token_text):]
+                else:
+                    yield new_token_text
+            else:
+                yield new_token_text
+
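A design note on _generate_full_stream: the loop re-runs the model over the entire generated_ids sequence on every step (the cache returned by the forward pass is never fed back in), so each token costs a full-prefix forward pass; in exchange, the manual decode loop is what enables the incremental thinking-tag filtering. For comparison, a common alternative for the no-cache path is transformers.TextIteratorStreamer driving model.generate() from a worker thread. The sketch below is illustrative only, reusing the attribute and config names from the diff; it is not the repository's implementation and drops the thinking-tag handling.

    from threading import Thread

    from transformers import TextIteratorStreamer

    def _generate_full_stream_via_streamer(self, prompt: str):
        # Tokenize exactly as the diff does, then let generate() handle sampling.
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=getattr(self.config, "max_tokens", 128),
            do_sample=getattr(self.config, "do_sample", True),
        )
        # generate() blocks until completion, so it runs in a worker thread
        # while the caller iterates the streamer on this thread.
        Thread(target=self.model.generate, kwargs=generation_kwargs).start()
        for text in streamer:
            yield text
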
     def _generate_with_cache(self, query: str, kv: DynamicCache) -> str:
         """
         Generate output incrementally using an existing KV cache.
@@ -137,6 +221,68 @@ def _generate_with_cache(self, query: str, kv: DynamicCache) -> str:
             else response
         )

+    def _generate_with_cache_stream(self, query: str, kv: DynamicCache) -> Generator[str, None, None]:
+        """
+        Generate output incrementally using an existing KV cache with streaming.
+        Args:
+            query (str): The new user query string.
+            kv (DynamicCache): The prefilled KV cache.
+        Yields:
+            str: Streaming response chunks.
+        """
+        query_ids = self.tokenizer(
+            query, return_tensors="pt", add_special_tokens=False
+        ).input_ids.to(self.model.device)
+
+        max_new_tokens = getattr(self.config, "max_tokens", 128)
+        do_sample = getattr(self.config, "do_sample", True)
+        remove_think_prefix = getattr(self.config, "remove_think_prefix", False)
+
+        # Initial forward pass
+        logits, kv = self._prefill(query_ids, kv)
+        next_token = self._select_next_token(logits)
+
+        # Yield first token
+        first_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
+        accumulated_text = ""
+        if first_token_text:
+            accumulated_text += first_token_text
+            if remove_think_prefix:
+                processed_text = remove_thinking_tags(accumulated_text)
+                if len(processed_text) > len(accumulated_text) - len(first_token_text):
+                    yield processed_text[len(accumulated_text) - len(first_token_text):]
+            else:
+                yield first_token_text
+        else:
+            yield first_token_text
+
+        generated = [next_token]
+
+        # Continue generation
+        for _ in range(max_new_tokens - 1):
+            if self._should_stop(next_token):
+                break
+            logits, kv = self._prefill(next_token, kv)
+            next_token = self._select_next_token(logits)

+            # Decode and yield the new token
+            new_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
+            if new_token_text:
+                accumulated_text += new_token_text
+
+                # Apply thinking tag removal if enabled
+                if remove_think_prefix:
+                    processed_text = remove_thinking_tags(accumulated_text)
+                    # Only yield the difference (new content)
+                    if len(processed_text) > len(accumulated_text) - len(new_token_text):
+                        yield processed_text[len(accumulated_text) - len(new_token_text):]
+                else:
+                    yield new_token_text
+            else:
+                yield new_token_text
+
+            generated.append(next_token)
+
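One caveat worth noting on both streaming paths: each new token id is decoded in isolation via tokenizer.decode(next_token[0], ...). With BPE/SentencePiece vocabularies, characters that span multiple tokens (emoji, many non-Latin scripts) and leading-space markers may not round-trip cleanly when decoded one id at a time. A common remedy is to re-decode the accumulated ids each step and yield only the unseen suffix; the helper below is an illustrative sketch of that idea, not code from this commit.

    def decode_delta(tokenizer, generated_ids: list[int], emitted: str) -> str:
        """Decode the whole generated sequence and return only the new suffix."""
        full_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        return full_text[len(emitted):]

    # Inside a streaming loop (hypothetical usage):
    #   delta = decode_delta(self.tokenizer, generated_token_ids, emitted)
    #   emitted += delta
    #   if delta:
    #       yield delta
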
     @torch.no_grad()
     def _prefill(
         self, input_ids: torch.Tensor, kv: DynamicCache