Skip to content

Commit 4f39794

Browse files
committed
get it working
1 parent 2686682 commit 4f39794

File tree

5 files changed

+20
-19
lines changed

5 files changed

+20
-19
lines changed

codeflash/api/cfapi.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@
33
import json
44
import os
55
import sys
6-
import git
76
from functools import lru_cache
87
from pathlib import Path
9-
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, List
8+
from typing import TYPE_CHECKING, Any, Optional
109

10+
import git
1111
import requests
1212
import sentry_sdk
1313
from pydantic.json import pydantic_encoder
1414

1515
from codeflash.cli_cmds.console import console, logger
1616
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
17-
from codeflash.version import __version__
1817
from codeflash.code_utils.git_utils import get_repo_owner_and_name
18+
from codeflash.version import __version__
1919

2020
if TYPE_CHECKING:
2121
from requests import Response
@@ -194,7 +194,9 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
194194
return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()}
195195

196196

197-
def is_function_being_optimized_again(owner: str, repo: str, pr_number: int, code_contexts: List[Dict[str, str]]) -> Dict:
197+
def is_function_being_optimized_again(
198+
owner: str, repo: str, pr_number: int, code_contexts: list[dict[str, str]]
199+
) -> dict:
198200
"""Check if the function being optimized is being optimized again."""
199201
response = make_cfapi_request(
200202
"/is-already-optimized",
@@ -204,8 +206,9 @@ def is_function_being_optimized_again(owner: str, repo: str, pr_number: int, cod
204206
response.raise_for_status()
205207
return response.json()
206208

207-
def add_code_context_hash( code_context_hash: str):
208-
"""Add code context to the DB cache"""
209+
210+
def add_code_context_hash(code_context_hash: str) -> None:
211+
"""Add code context to the DB cache."""
209212
pr_number = get_pr_number()
210213
if pr_number is None:
211214
return
@@ -215,16 +218,9 @@ def add_code_context_hash( code_context_hash: str):
215218
except git.exc.InvalidGitRepositoryError:
216219
return
217220

218-
219221
if owner and repo and pr_number is not None:
220222
make_cfapi_request(
221223
"/add-code-hash",
222224
"POST",
223-
{
224-
"owner": owner,
225-
"repo": repo,
226-
"pr_number": pr_number,
227-
"code_context_hash": code_context_hash
228-
}
225+
{"owner": owner, "repo": repo, "pr_number": pr_number, "code_hash": code_context_hash},
229226
)
230-

codeflash/context/code_context_extractor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import hashlib
34
import os
45
from collections import defaultdict
56
from itertools import chain
@@ -132,12 +133,15 @@ def get_code_optimization_context(
132133
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
133134
if testgen_context_code_tokens > testgen_token_limit:
134135
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
136+
code_hash_context = hashing_code_context.markdown
137+
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
135138

136139
return CodeOptimizationContext(
137140
testgen_context_code=testgen_context_code,
138141
read_writable_code=final_read_writable_code,
139142
read_only_context_code=read_only_context_code,
140-
hashing_code_context=hashing_code_context.markdown,
143+
hashing_code_context=code_hash_context,
144+
hashing_code_context_hash=code_hash,
141145
helper_functions=helpers_of_fto_list,
142146
preexisting_objects=preexisting_objects,
143147
)

codeflash/discovery/functions_to_optimize.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import ast
4-
import hashlib
54
import os
65
import random
76
import warnings
@@ -446,7 +445,7 @@ def check_optimization_status(function_to_optimize: FunctionToOptimize, code_con
446445

447446
code_contexts = []
448447

449-
func_hash = hashlib.sha256(code_context.hashing_code_context.encode("utf-8")).hexdigest()
448+
func_hash = code_context.hashing_code_context_hash
450449
# Use a unique path identifier that includes function info
451450

452451
code_contexts.append(

codeflash/models/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ class CodeOptimizationContext(BaseModel):
158158
read_writable_code: str = Field(min_length=1)
159159
read_only_context_code: str = ""
160160
hashing_code_context: str = ""
161+
hashing_code_context_hash: str = ""
161162
helper_functions: list[FunctionSource]
162163
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]]
163164

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
156156
if has_any_async_functions(code_context.read_writable_code):
157157
return Failure("Codeflash does not support async functions in the code to optimize.")
158158
if check_optimization_status(self.function_to_optimize, code_context):
159-
return Failure("This function has already been optimized, skipping.")
159+
return Failure("This function has previously been optimized, skipping.")
160160

161161
code_print(code_context.read_writable_code)
162162
generated_test_paths = [
@@ -377,7 +377,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
377377

378378
# Add function to code context hash if in gh actions
379379

380-
add_code_context_hash(self.function_to_optimize.get_code_context_hash())
380+
add_code_context_hash(code_context.hashing_code_context_hash)
381381

382382
if self.args.override_fixtures:
383383
restore_conftest(original_conftest_content)
@@ -689,6 +689,7 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
689689
read_writable_code=new_code_ctx.read_writable_code,
690690
read_only_context_code=new_code_ctx.read_only_context_code,
691691
hashing_code_context=new_code_ctx.hashing_code_context,
692+
hashing_code_context_hash=new_code_ctx.hashing_code_context_hash,
692693
helper_functions=new_code_ctx.helper_functions, # only functions that are read writable
693694
preexisting_objects=new_code_ctx.preexisting_objects,
694695
)

0 commit comments

Comments
 (0)