Skip to content

Commit a66e865

Browse files
authored
feature: add llm tokens limit (#327)
* fix: typo in available keys * docs: update parameters * chore: rename config to configloader * fix: add correct number parsing * feat: add context window parameter * chore: rename config to config_loader * build: bump project version
1 parent bd62994 commit a66e865

File tree

16 files changed

+122
-69
lines changed

16 files changed

+122
-69
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,10 @@ documentation, see the [Workflow Generator README](./osa_tool/workflow/README.md
196196
| `--api` | LLM API service provider | `itmo` |
197197
| `--base-url` | URL of the provider compatible with API OpenAI | `https://api.openai.com/v1` |
198198
| `--model` | Specific LLM model to use | `gpt-3.5-turbo` |
199-
| `--top_p` | Nucleus sampling probability | `None` |
200-
| `--temperature` | Sampling temperature to use for the LLM output (0 = deterministic, 1 = creative). | `None` |
201-
| `--max_tokens` | Maximum number of tokens the model can generate in a single response | `None` |
199+
| `--top_p` | Nucleus sampling probability | `0.95` |
200+
| `--temperature` | Sampling temperature to use for the LLM output (0 = deterministic, 1 = creative). | `0.05` |
201+
| `--max_tokens` | Maximum number of output tokens the model can generate in a single response | `4096` |
202+
| `--context_window` | Total number of tokens in the model's context window (input + output) | `16385` |
202203
| `--attachment` | Path to a local PDF or .docx file, or a URL to a PDF resource | `None` |
203204
| `-m`, `--mode` | Operation mode for repository processing: `basic`, `auto` (default), or `advanced`. | `auto` |
204205
| `--delete-dir` | Enable deleting the downloaded repository after processing | `disabled` |

docs/index.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,10 @@ documentation, see the [GitHub Action Workflow Generator README](../osa_tool/git
165165
| `--api` | LLM API service provider | `itmo` |
166166
| `--base-url` | URL of the provider compatible with API OpenAI | `https://api.openai.com/v1` |
167167
| `--model` | Specific LLM model to use | `gpt-3.5-turbo` |
168-
| `--top_p` | Nucleus sampling probability | `None` |
169-
| `--temperature` | Sampling temperature to use for the LLM output (0 = deterministic, 1 = creative). | `None` |
170-
| `--max_tokens` | Maximum number of tokens the model can generate in a single response | `None` |
168+
| `--top_p` | Nucleus sampling probability | `0.95` |
169+
| `--temperature` | Sampling temperature to use for the LLM output (0 = deterministic, 1 = creative). | `0.05` |
170+
| `--max_tokens` | Maximum number of output tokens the model can generate in a single response | `4096` |
171+
| `--context_window` | Total number of tokens in the model's context window (input + output) | `16385` |
171172
| `--attachment` | Path to a local PDF or .docx file, or a URL to a PDF resource | `None` |
172173
| `-m`, `--mode` | Operation mode for repository processing: `basic`, `auto` (default), or `advanced`. | `auto` |
173174
| `--delete-dir` | Enable deleting the downloaded repository after processing | `disabled` |

osa_tool/config/settings.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,20 @@
44

55
import os.path
66
from pathlib import Path
7-
from typing import Any, Literal, List
7+
from typing import Any, List, Literal
88

99
import tomli
1010
from pydantic import (
1111
AnyHttpUrl,
1212
BaseModel,
1313
ConfigDict,
1414
Field,
15-
model_validator,
1615
NonNegativeFloat,
1716
PositiveInt,
17+
model_validator,
1818
)
1919

20-
from osa_tool.utils.utils import parse_git_url, build_config_path
20+
from osa_tool.utils.utils import build_config_path, parse_git_url
2121

2222

2323
class GitSettings(BaseModel):
@@ -46,14 +46,14 @@ class ModelSettings(BaseModel):
4646
api: str
4747
rate_limit: PositiveInt
4848
base_url: str
49-
context_window: PositiveInt
5049
encoder: str
5150
host_name: AnyHttpUrl
5251
localhost: AnyHttpUrl
5352
model: str
5453
path: str
5554
temperature: NonNegativeFloat
56-
tokens: PositiveInt
55+
max_tokens: PositiveInt
56+
context_window: PositiveInt
5757
top_p: NonNegativeFloat
5858

5959

osa_tool/config/settings/arguments.yaml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,25 @@ model:
5656
5757
temperature:
5858
aliases: [ "--temperature" ]
59-
type: str
59+
type: float
6060
description: "Sampling temperature to use for the LLM output (0 = deterministic, 1 = creative)."
6161
example: 0.3, 0.9
6262

63-
tokens:
63+
max_tokens:
6464
aliases: [ "--max-tokens" ]
65-
type: str
66-
description: "Maximum number of tokens the model can generate in a single response."
65+
type: int
66+
description: "Maximum number of output tokens the model can generate in a single response."
6767
example: 256, 1024
6868

69+
context_window:
70+
aliases: [ "--context-window" ]
71+
type: int
72+
description: "Total number of tokens in the model's context window (input + output)."
73+
example: 16000, 200000
74+
6975
top_p:
7076
aliases: [ "--top-p" ]
71-
type: str
77+
type: float
7278
description: "Nucleus sampling probability (1.0 = all tokens considered)."
7379
example: 0.8, 0.95
7480

osa_tool/config/settings/config.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ no_pull_request = false
99
api = "itmo"
1010
rate_limit = 10
1111
base_url = "https://api.openai.com/v1"
12-
context_window = 4096
1312
encoder = "cl100k_base"
1413
host_name = "https://api.openai.com/v1"
1514
localhost = "http://localhost:11434/"
1615
model = "gpt-3.5-turbo"
1716
path = "generate"
1817
temperature = 0.05
19-
tokens = 4096
18+
max_tokens = 4096
19+
context_window = 16385
2020
top_p = 0.95
2121

2222
# General CLI-related defaults

osa_tool/models/models.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from uuid import uuid4
55

66
import dotenv
7+
import tiktoken
78
from langchain.schema import SystemMessage
89
from protollm.connectors import create_llm_connector
910

@@ -81,7 +82,7 @@ def __init__(self, config: Settings, prompt: str):
8182
"""
8283
self.job_id = str(uuid4())
8384
self.temperature = config.llm.temperature
84-
self.tokens_limit = config.llm.tokens
85+
self.tokens_limit = config.llm.max_tokens
8586
self.prompt = prompt
8687
self.roles = [
8788
SystemMessage(content="You are a helpful assistant for analyzing open-source repositories."),
@@ -160,7 +161,8 @@ def send_request(self, prompt: str) -> str:
160161
Returns:
161162
str: The response received from the request.
162163
"""
163-
self.initialize_payload(self.config, prompt)
164+
safe_prompt = self._limit_tokens(prompt)
165+
self.initialize_payload(self.config, safe_prompt)
164166
messages = self.payload["messages"]
165167
response = self.client.invoke(messages)
166168
return response.content
@@ -176,7 +178,8 @@ async def async_request(self, prompt: str) -> str:
176178
Returns:
177179
str: The response received from the request.
178180
"""
179-
self.initialize_payload(self.config, prompt)
181+
safe_prompt = self._limit_tokens(prompt)
182+
self.initialize_payload(self.config, safe_prompt)
180183
response = await self.client.ainvoke(self.payload["messages"])
181184
return response.content
182185

@@ -229,6 +232,40 @@ def _configure_api(self, api: str, model_name: str) -> None:
229232

230233
self.client = create_llm_connector(model_url=self._build_model_url(), **self._get_llm_params())
231234

235+
def _limit_tokens(self, text: str, safety_buffer: int = 100, mode: str = "middle-out") -> str:
236+
"""
237+
Limits text to fit within the model's context window.
238+
239+
Calculates: Available Input = Total Context - Max Output - Safety Buffer
240+
"""
241+
model_context_limit = getattr(self.config.llm, "context_window")
242+
max_output_tokens = self.config.llm.max_tokens
243+
encoding_name = self.config.llm.encoder
244+
245+
max_input_tokens = model_context_limit - max_output_tokens - safety_buffer
246+
247+
try:
248+
encoding = tiktoken.get_encoding(encoding_name)
249+
except ValueError:
250+
encoding = tiktoken.get_encoding("cl100k_base")
251+
252+
tokens = encoding.encode(text)
253+
254+
if len(tokens) <= max_input_tokens:
255+
return text
256+
257+
if mode == "start":
258+
truncated_tokens = tokens[:max_input_tokens]
259+
elif mode == "end":
260+
truncated_tokens = tokens[-max_input_tokens:]
261+
elif mode == "middle-out":
262+
half_limit = max_input_tokens // 2
263+
truncated_tokens = tokens[:half_limit] + tokens[-half_limit:]
264+
else:
265+
raise ValueError(f"Unknown mode: {mode}")
266+
267+
return encoding.decode(truncated_tokens)
268+
232269

233270
class ModelHandlerFactory:
234271
"""

osa_tool/run.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,14 @@
3131
from osa_tool.translation.dir_translator import DirectoryTranslator
3232
from osa_tool.translation.readme_translator import ReadmeTranslator
3333
from osa_tool.utils.arguments_parser import build_parser_from_yaml
34-
from osa_tool.utils.logger import setup_logging, logger
34+
from osa_tool.utils.logger import logger, setup_logging
3535
from osa_tool.utils.prompts_builder import PromptLoader
36-
from osa_tool.utils.utils import delete_repository, parse_folder_name, rich_section, osa_project_root
36+
from osa_tool.utils.utils import (
37+
delete_repository,
38+
osa_project_root,
39+
parse_folder_name,
40+
rich_section,
41+
)
3742
from osa_tool.validation.doc_validator import DocValidator
3843
from osa_tool.validation.paper_validator import PaperValidator
3944
from osa_tool.validation.report_generator import (
@@ -78,13 +83,14 @@ def main():
7883
logger.info(f"Output path changed to {output_path}")
7984

8085
# Load configurations and update
81-
config = load_configuration(
86+
config_loader = load_configuration(
8287
repo_url=args.repository,
8388
api=args.api,
8489
base_url=args.base_url,
8590
model_name=args.model,
8691
temperature=args.temperature,
8792
max_tokens=args.max_tokens,
93+
context_window=args.context_window,
8894
top_p=args.top_p,
8995
)
9096

@@ -107,8 +113,8 @@ def main():
107113
git_agent.clone_repository()
108114

109115
# Initialize ModeScheduler
110-
sourcerank = SourceRank(config)
111-
scheduler = ModeScheduler(config, sourcerank, prompts, args, workflow_manager, git_agent.metadata)
116+
sourcerank = SourceRank(config_loader)
117+
scheduler = ModeScheduler(config_loader, sourcerank, prompts, args, workflow_manager, git_agent.metadata)
112118
plan = scheduler.plan
113119

114120
if create_fork:
@@ -118,25 +124,25 @@ def main():
118124
# NOTE: Must run first - switches GitHub branches
119125
if plan.get("report"):
120126
rich_section("Report generation")
121-
analytics = ReportGenerator(config, sourcerank, prompts, git_agent.metadata)
127+
analytics = ReportGenerator(config_loader, sourcerank, prompts, git_agent.metadata)
122128
analytics.build_pdf()
123129
if create_fork:
124130
git_agent.upload_report(analytics.filename, analytics.output_path)
125131

126132
# NOTE: Must run first - switches GitHub branches
127133
if plan.get("validate_doc"):
128134
rich_section("Document validation")
129-
content = DocValidator(config, prompts).validate(plan.get("attachment"))
130-
va_re_gen = ValidationReportGenerator(config, git_agent.metadata, sourcerank)
135+
content = DocValidator(config_loader, prompts).validate(plan.get("attachment"))
136+
va_re_gen = ValidationReportGenerator(config_loader, git_agent.metadata, sourcerank)
131137
va_re_gen.build_pdf("Document", content)
132138
if create_fork:
133139
git_agent.upload_report(va_re_gen.filename, va_re_gen.output_path)
134140

135141
# NOTE: Must run first - switches GitHub branches
136142
if plan.get("validate_paper"):
137143
rich_section("Paper validation")
138-
content = PaperValidator(config, prompts).validate(plan.get("attachment"))
139-
va_re_gen = ValidationReportGenerator(config, git_agent.metadata, sourcerank)
144+
content = PaperValidator(config_loader, prompts).validate(plan.get("attachment"))
145+
va_re_gen = ValidationReportGenerator(config_loader, git_agent.metadata, sourcerank)
140146
va_re_gen.build_pdf("Paper", content)
141147
if create_fork:
142148
git_agent.upload_report(va_re_gen.filename, va_re_gen.output_path)
@@ -149,13 +155,13 @@ def main():
149155
# Auto translating names of directories
150156
if plan.get("translate_dirs"):
151157
rich_section("Directory and file translation")
152-
translation = DirectoryTranslator(config)
158+
translation = DirectoryTranslator(config_loader)
153159
translation.rename_directories_and_files()
154160

155161
# Docstring generation
156162
if plan.get("docstring"):
157163
rich_section("Docstrings generation")
158-
generate_docstrings(config, loop)
164+
generate_docstrings(config_loader, loop)
159165

160166
# License compiling
161167
if license_type := plan.get("ensure_license"):
@@ -165,7 +171,7 @@ def main():
165171
# Generate community documentation
166172
if plan.get("community_docs"):
167173
rich_section("Community docs generation")
168-
generate_documentation(config, git_agent.metadata)
174+
generate_documentation(config_loader, git_agent.metadata)
169175

170176
# Requirements generation
171177
if plan.get("requirements"):
@@ -176,21 +182,21 @@ def main():
176182
if plan.get("readme"):
177183
rich_section("README generation")
178184
readme_agent = ReadmeAgent(
179-
config, prompts, plan.get("attachment"), plan.get("refine_readme"), git_agent.metadata
185+
config_loader, prompts, plan.get("attachment"), plan.get("refine_readme"), git_agent.metadata
180186
)
181187
readme_agent.generate_readme()
182188

183189
# Readme translation
184190
translate_readme = plan.get("translate_readme")
185191
if translate_readme:
186192
rich_section("README translation")
187-
ReadmeTranslator(config, prompts, git_agent.metadata, translate_readme).translate_readme()
193+
ReadmeTranslator(config_loader, prompts, git_agent.metadata, translate_readme).translate_readme()
188194

189195
# About section generation
190196
about_gen = None
191197
if plan.get("about"):
192198
rich_section("About Section generation")
193-
about_gen = AboutGenerator(config, prompts, git_agent)
199+
about_gen = AboutGenerator(config_loader, prompts, git_agent)
194200
about_gen.generate_about_content()
195201
if create_fork:
196202
git_agent.update_about_section(about_gen.get_about_content())
@@ -200,8 +206,8 @@ def main():
200206
# Generate platform-specified CI/CD files
201207
if plan.get("generate_workflows"):
202208
rich_section("Workflows generation")
203-
workflow_manager.update_workflow_config(config, plan)
204-
workflow_manager.generate_workflow(config)
209+
workflow_manager.update_workflow_config(config_loader, plan)
210+
workflow_manager.generate_workflow(config_loader)
205211

206212
# Organize repository by adding 'tests' and 'examples' directories if they aren't exist
207213
if plan.get("organize"):
@@ -349,6 +355,7 @@ def load_configuration(
349355
model_name: str,
350356
temperature: Optional[str] = None,
351357
max_tokens: Optional[str] = None,
358+
context_window: Optional[str] = None,
352359
top_p: Optional[str] = None,
353360
) -> ConfigLoader:
354361
"""
@@ -360,7 +367,8 @@ def load_configuration(
360367
base_url: URL of the provider compatible with API OpenAI
361368
model_name: Specific LLM model to use.
362369
temperature: Sampling temperature for the model.
363-
max_tokens: Maximum number of tokens to generate.
370+
max_tokens: Maximum number of output tokens to generate.
371+
context_window: Total number of tokens in the model's context window (input + output).
364372
top_p: Nucleus sampling value.
365373
366374
Returns:
@@ -380,6 +388,7 @@ def load_configuration(
380388
"model": model_name,
381389
"temperature": temperature,
382390
"max_tokens": max_tokens,
391+
"context_window": context_window,
383392
"top_p": top_p,
384393
}
385394
)

0 commit comments

Comments
 (0)