|
1 | 1 | from __future__ import annotations |
| 2 | +from math import exp |
2 | 3 |
|
3 | 4 | import random |
4 | 5 | import logging |
@@ -148,13 +149,41 @@ def expand_prompt(model_name: str, text: str, seed: int, log_prompt: bool): |
148 | 149 | seed = abs(seed) |
149 | 150 | elif seed > max_seed: |
150 | 151 | seed = seed % max_seed |
151 | | - |
152 | | - expansion_text = expansion(prompt, seed) |
153 | | - expanded_prompt = join_prompts(prompt, expansion_text) |
| 152 | + |
| 153 | + prompt_parts = [] |
| 154 | + expanded_parts = [] |
| 155 | + |
| 156 | + # Split prompt if longer than 256 |
| 157 | + if len(prompt) > 256: |
| 158 | + prompt_lines = prompt.splitlines() |
| 159 | + # Fill part until 256 |
| 160 | + prompt_parts = [""] |
| 161 | + filled_chars = 0 |
| 162 | + for line in prompt_lines: |
| 163 | + # When adding the line would exceed 256, start a new part |
| 164 | + if filled_chars + len(line) > 256: |
| 165 | + prompt_parts.append(line) |
| 166 | + filled_chars = len(line) |
| 167 | + else: |
| 168 | + prompt_parts[-1] += line |
| 169 | + filled_chars += len(line) |
| 170 | + else: |
| 171 | + prompt_parts = [prompt] |
154 | 172 |
|
| 173 | + for i, part in enumerate(prompt_parts): |
| 174 | + expansion_part = expansion(part, seed) |
| 175 | + full_part = join_prompts(part, expansion_part) |
| 176 | + expanded_parts.append(full_part) |
| 177 | + |
| 178 | + expanded_prompt = "\n".join(expanded_parts) |
| 179 | + |
155 | 180 | if log_prompt: |
156 | | - logger.info(f"Prompt: {prompt}") |
157 | | - logger.info(f"Expanded Prompt: {expanded_prompt}") |
| 181 | + if logger.isEnabledFor(logging.INFO): |
| 182 | + logger.info(f"Prompt: {prompt!r}") |
| 183 | + logger.info(f"Expanded Prompt: {expanded_prompt!r}") |
| 184 | + else: |
| 185 | + print(f"Prompt: {prompt!r}") |
| 186 | + print(f"Expanded Prompt: {expanded_prompt!r}") |
158 | 187 |
|
159 | 188 | return expanded_prompt, seed |
160 | 189 |
|
|
0 commit comments