Skip to content

Commit 17a9b66

Browse files
committed
Logging and prompt chunking
1 parent f16c605 commit 17a9b66

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

src/inference_core_nodes/prompt_expansion/prompt_expansion.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
from math import exp
23

34
import random
45
import logging
@@ -148,13 +149,41 @@ def expand_prompt(model_name: str, text: str, seed: int, log_prompt: bool):
148149
seed = abs(seed)
149150
elif seed > max_seed:
150151
seed = seed % max_seed
151-
152-
expansion_text = expansion(prompt, seed)
153-
expanded_prompt = join_prompts(prompt, expansion_text)
152+
153+
prompt_parts = []
154+
expanded_parts = []
155+
156+
# Split prompt if longer than 256
157+
if len(prompt) > 256:
158+
prompt_lines = prompt.splitlines()
159+
# Fill part until 256
160+
prompt_parts = [""]
161+
filled_chars = 0
162+
for line in prompt_lines:
163+
# When adding the line would exceed 256, start a new part
164+
if filled_chars + len(line) > 256:
165+
prompt_parts.append(line)
166+
filled_chars = len(line)
167+
else:
168+
prompt_parts[-1] += line
169+
filled_chars += len(line)
170+
else:
171+
prompt_parts = [prompt]
154172

173+
for i, part in enumerate(prompt_parts):
174+
expansion_part = expansion(part, seed)
175+
full_part = join_prompts(part, expansion_part)
176+
expanded_parts.append(full_part)
177+
178+
expanded_prompt = "\n".join(expanded_parts)
179+
155180
if log_prompt:
156-
logger.info(f"Prompt: {prompt}")
157-
logger.info(f"Expanded Prompt: {expanded_prompt}")
181+
if logger.isEnabledFor(logging.INFO):
182+
logger.info(f"Prompt: {prompt!r}")
183+
logger.info(f"Expanded Prompt: {expanded_prompt!r}")
184+
else:
185+
print(f"Prompt: {prompt!r}")
186+
print(f"Expanded Prompt: {expanded_prompt!r}")
158187

159188
return expanded_prompt, seed
160189

0 commit comments

Comments
 (0)