4 changes: 4 additions & 0 deletions codeflash/code_utils/code_utils.py
@@ -10,6 +10,10 @@
 
 from codeflash.cli_cmds.console import logger
 
+def encoded_tokens_len(s: str) -> int:
+    """Return the approximate number of tokens in a string, assuming roughly
+    two characters per BPE token (https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)."""
+    return len(s) // 2
 
 def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
     if not full_qualified_name:
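Aside: the new helper trades tiktoken's exact token counts for a constant-time estimate. A minimal sketch (not part of the diff, assuming tiktoken is still installed locally and using an illustrative sample string) of how the two compare:

import tiktoken

from codeflash.code_utils.code_utils import encoded_tokens_len

# Illustrative sample; any string works.
text = "def fib(n):\n    return n if n < 2 else fib(n - 1) + fib(n - 2)\n"
approx = encoded_tokens_len(text)  # len(text) // 2, no tokenizer required
exact = len(tiktoken.encoding_for_model("gpt-4o").encode(text))
print(f"approx={approx}, exact={exact}")

Since English text averages closer to four characters per token, len(s) // 2 tends to overestimate the count, which is the safe direction when enforcing a limit.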
14 changes: 6 additions & 8 deletions codeflash/context/code_context_extractor.py
@@ -7,13 +7,12 @@
 
 import jedi
 import libcst as cst
-import tiktoken
 from jedi.api.classes import Name
 from libcst import CSTNode
 
 from codeflash.cli_cmds.console import logger
 from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
-from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
+from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages, encoded_tokens_len
 from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 from codeflash.models.models import (
@@ -73,8 +72,7 @@ def get_code_optimization_context(
     )
 
     # Handle token limits
-    tokenizer = tiktoken.encoding_for_model("gpt-4o")
-    final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
+    final_read_writable_tokens = encoded_tokens_len(final_read_writable_code)
     if final_read_writable_tokens > optim_token_limit:
         raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
 
@@ -87,7 +85,7 @@
     )
     read_only_context_code = read_only_code_markdown.markdown
 
-    read_only_code_markdown_tokens = len(tokenizer.encode(read_only_context_code))
+    read_only_code_markdown_tokens = encoded_tokens_len(read_only_context_code)
     total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
     if total_tokens > optim_token_limit:
         logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
@@ -96,7 +94,7 @@
             helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True
         )
         read_only_context_code = read_only_code_no_docstring_markdown.markdown
-        read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_context_code))
+        read_only_code_no_docstring_markdown_tokens = encoded_tokens_len(read_only_context_code)
         total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens
         if total_tokens > optim_token_limit:
             logger.debug("Code context has exceeded token limit, removing read-only code")
@@ -111,7 +109,7 @@
         code_context_type=CodeContextType.TESTGEN,
     )
     testgen_context_code = testgen_code_markdown.code
-    testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code))
+    testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
     if testgen_context_code_tokens > testgen_token_limit:
         testgen_code_markdown = extract_code_string_context_from_files(
             helpers_of_fto_dict,
@@ -121,7 +119,7 @@
             code_context_type=CodeContextType.TESTGEN,
         )
         testgen_context_code = testgen_code_markdown.code
-        testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code))
+        testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
         if testgen_context_code_tokens > testgen_token_limit:
             raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
 
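Taken together, the call sites above implement a tiered fallback: estimate the token count cheaply, and when a limit is exceeded, rebuild a leaner context (docstrings stripped, then read-only code dropped) before giving up. A condensed sketch of that pattern, where first_context_under_limit and the builders list are hypothetical stand-ins for the inlined extract_code_string_context_from_files calls:

from codeflash.code_utils.code_utils import encoded_tokens_len

def first_context_under_limit(builders, token_limit):
    # builders: zero-argument callables, ordered from richest to leanest context.
    for build in builders:
        context = build()
        if encoded_tokens_len(context) <= token_limit:
            return context
    raise ValueError("Code context has exceeded token limit, cannot proceed")

The diff keeps this logic inlined per call site; the sketch only makes the fallback ordering explicit.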