diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py
index f63756d98..13a844015 100644
--- a/codeflash/code_utils/code_utils.py
+++ b/codeflash/code_utils/code_utils.py
@@ -10,6 +10,10 @@ from codeflash.cli_cmds.console import logger
 
 
+def encoded_tokens_len(s: str) -> int:
+    """Return the approximate length of `s` in encoded tokens.
+    This is a rough approximation of BPE encoding (https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)."""
+    return len(s) // 2
 
 
 def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
     if not full_qualified_name:
diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py
index ce54bb0e2..bf55c7575 100644
--- a/codeflash/context/code_context_extractor.py
+++ b/codeflash/context/code_context_extractor.py
@@ -7,13 +7,12 @@
 
 import jedi
 import libcst as cst
-import tiktoken
 from jedi.api.classes import Name
 from libcst import CSTNode
 
 from codeflash.cli_cmds.console import logger
 from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
-from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
+from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages, encoded_tokens_len
 from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 from codeflash.models.models import (
@@ -73,8 +72,7 @@ def get_code_optimization_context(
     )
 
     # Handle token limits
-    tokenizer = tiktoken.encoding_for_model("gpt-4o")
-    final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
+    final_read_writable_tokens = encoded_tokens_len(final_read_writable_code)
     if final_read_writable_tokens > optim_token_limit:
         raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
 
@@ -87,7 +85,7 @@ def get_code_optimization_context(
     )
     read_only_context_code = read_only_code_markdown.markdown
 
-    read_only_code_markdown_tokens = len(tokenizer.encode(read_only_context_code))
+    read_only_code_markdown_tokens = encoded_tokens_len(read_only_context_code)
     total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
     if total_tokens > optim_token_limit:
         logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
@@ -96,7 +94,7 @@
             helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True
         )
         read_only_context_code = read_only_code_no_docstring_markdown.markdown
-        read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_context_code))
+        read_only_code_no_docstring_markdown_tokens = encoded_tokens_len(read_only_context_code)
         total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens
         if total_tokens > optim_token_limit:
             logger.debug("Code context has exceeded token limit, removing read-only code")
@@ -111,7 +109,7 @@ def get_code_optimization_context(
         code_context_type=CodeContextType.TESTGEN,
     )
     testgen_context_code = testgen_code_markdown.code
-    testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code))
+    testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
    if testgen_context_code_tokens > testgen_token_limit:
         testgen_code_markdown = extract_code_string_context_from_files(
             helpers_of_fto_dict,
@@ -121,7 +119,7 @@ def get_code_optimization_context(
             code_context_type=CodeContextType.TESTGEN,
         )
         testgen_context_code = testgen_code_markdown.code
-        testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code))
+        testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
         if testgen_context_code_tokens > testgen_token_limit:
             raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
 
diff --git a/pyproject.toml b/pyproject.toml
index 15dc01098..ee6fa9d6d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,7 +73,6 @@ pytest = ">=7.0.0,!=8.3.4"
 gitpython = ">=3.1.31"
 libcst = ">=1.0.1"
 jedi = ">=0.19.1"
-tiktoken = ">=0.7.0"
 timeout-decorator = ">=0.5.0"
 pytest-timeout = ">=2.1.0"
 tomlkit = ">=0.11.7"
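
Note: `encoded_tokens_len` replaces exact tiktoken counts with a characters-divided-by-two heuristic, so the token-limit checks above now operate on approximate counts. Below is a minimal sketch of how one might sanity-check that heuristic against the gpt-4o tokenizer the old code used. It is not part of this diff: it assumes tiktoken is installed separately in a dev environment (this change removes it as a runtime dependency), and the helper `approximation_error` is illustrative only.

```python
# Hypothetical dev-side sanity check, not part of this PR: compare the
# len(s) // 2 heuristic against the real gpt-4o BPE tokenizer.
# Requires `pip install tiktoken` in a development environment only.
import tiktoken

from codeflash.code_utils.code_utils import encoded_tokens_len


def approximation_error(text: str) -> float:
    """Relative error of encoded_tokens_len() vs. the actual gpt-4o token count."""
    actual = len(tiktoken.encoding_for_model("gpt-4o").encode(text))
    approx = encoded_tokens_len(text)  # len(text) // 2
    return abs(approx - actual) / max(actual, 1)


if __name__ == "__main__":
    sample = "def add(a: int, b: int) -> int:\n    return a + b\n"
    print(f"approx={encoded_tokens_len(sample)} error={approximation_error(sample):.1%}")
```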