Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions codeflash/after_aiagents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
if not content:
return 0
if isinstance(content, str):
return len(_TOKEN_SPLIT_RE.split(content.strip()))
tokens = 0
for part in content:
if isinstance(part, str):
tokens += len(_TOKEN_SPLIT_RE.split(part.strip()))
elif isinstance(part, BinaryContent):
tokens += len(part.data)
return tokens


_TOKEN_SPLIT_RE = re.compile(r'[\s",.:]+')
9 changes: 9 additions & 0 deletions codeflash/after_algo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
def postprocess(self, predictions: tuple[np.ndarray, ...], max_detections: int):
bboxes, logits = predictions
batch_size, num_queries, num_classes = logits.shape
logits_sigmoid = self.sigmoid_stable(logits)
for batch_idx in range(batch_size):
logits_flat = logits_sigmoid[batch_idx].reshape(-1)
# Use argpartition for better performance when max_detections is smaller than logits_flat
partition_indices = np.argpartition(-logits_flat, max_detections)[:max_detections]
sorted_indices = partition_indices[np.argsort(-logits_flat[partition_indices])]
3 changes: 3 additions & 0 deletions codeflash/after_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def check_missing_data(data: pd.DataFrame):
"""Check if there is any missing data in the DataFrame"""
return data.isnull().values.any()
3 changes: 3 additions & 0 deletions codeflash/after_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def get_authors(session):
query = session.query(Author).join(Book).distinct(Author.id).order_by(Author.id)
return query.all()
11 changes: 11 additions & 0 deletions codeflash/after_ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import jax.numpy as jnp


def sum_hand(hand):
"""Returns the total points in a hand."""
return jnp.sum(hand) + (10 * usable_ace(hand))


def usable_ace(hand):
"""Checks to se if a hand has a usable ace."""
return jnp.logical_and(jnp.any(hand == 1), jnp.sum(hand) + 10 <= 21)
17 changes: 17 additions & 0 deletions codeflash/after_numerical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
def _equalize_cv(img: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
histogram = cv2.calcHist([img], [0], mask, [256], (0, 256)).ravel()

# Find the first non-zero index with a numpy operation
i = np.flatnonzero(histogram)[0] if np.any(histogram) else 255

total = np.sum(histogram)
if histogram[i] == total:
return np.full_like(img, i)

scale = 255.0 / (total - histogram[i])

# Optimize cumulative sum and scale to generate LUT
cumsum_histogram = np.cumsum(histogram)
lut = np.clip(((cumsum_histogram - cumsum_histogram[i]) * scale)
.round(), 0, 255).astype(np.uint8)
return sz_lut(img, lut, inplace=True)
11 changes: 11 additions & 0 deletions codeflash/after_web_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
async def get_endpoint(session, url):
async with session.get(url) as response:
return await response.text()


async def some_api_call(urls):
async with aiohttp.ClientSession() as session:
tasks = [get_endpoint(session, url) for url in urls]
# Run requests concurrently
results = await asyncio.gather(*tasks)
return results
16 changes: 16 additions & 0 deletions codeflash/before_aiagents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
if not content:
return 0
if isinstance(content, str):
return len(re.split(r'[\s",.:]+', content.strip()))
tokens = 0
for part in content:
if isinstance(part, str):
tokens += len(re.split(r'[\s",.:]+', part.strip()))
if isinstance(part, (AudioUrl, ImageUrl)):
tokens += 0
elif isinstance(part, BinaryContent):
tokens += len(part.data)
else:
tokens += 0
return tokens
7 changes: 7 additions & 0 deletions codeflash/before_algo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
def postprocess(self, predictions: tuple[np.ndarray, ...], max_detections: int):
bboxes, logits = predictions
batch_size, num_queries, num_classes = logits.shape
logits_sigmoid = self.sigmoid_stable(logits)
for batch_idx in range(batch_size):
logits_flat = logits_sigmoid[batch_idx].reshape(-1)
sorted_indices = np.argsort(-logits_flat)[:max_detections]
4 changes: 4 additions & 0 deletions codeflash/before_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def check_missing_data(data: pd.DataFrame):
"""Check if there is any missing data in the DataFrame"""
missing_data = data.isnull().sum().sum() > 0
return missing_data
6 changes: 6 additions & 0 deletions codeflash/before_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def get_authors(session):
books = session.query(Book).all()
_authors = []
for book in books:
_authors.append(book.author)
return sorted(list(set(_authors)), key=lambda x: x.id)
11 changes: 11 additions & 0 deletions codeflash/before_ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import jax.numpy as jnp


def sum_hand(hand):
"""Returns the total points in a hand."""
return sum(hand) + (10 * usable_ace(hand))


def usable_ace(hand):
"""Checks to se if a hand has a usable ace."""
return jnp.logical_and((jnp.count_nonzero(hand == 1) > 0), (sum(hand) + 10 <= 21))
22 changes: 22 additions & 0 deletions codeflash/before_numerical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
def _equalize_cv(img: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
histogram = cv2.calcHist([img], [0], mask, [256], (0, 256)).ravel()
i = 0
for val in histogram:
if val > 0:
break
i += 1
i = min(i, 255)

total = np.sum(histogram)
if histogram[i] == total:
return np.full_like(img, i)

scale = 255.0 / (total - histogram[i])
_sum = 0

lut = np.zeros(256, dtype=np.uint8)

for idx in range(i + 1, len(histogram)):
_sum += histogram[idx]
lut[idx] = clip(round(_sum * scale), np.uint8)
return sz_lut(img, lut, inplace=True)
12 changes: 12 additions & 0 deletions codeflash/before_web_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
async def get_endpoint(session, url):
async with session.get(url) as response:
return await response.text()


async def some_api_call(urls):
async with aiohttp.ClientSession() as session:
results = []
for url in urls:
result = await get_endpoint(session, url)
results.append(result)
return results
Loading