34 changes: 23 additions & 11 deletions examples/hello-kernel-world/optimize.py
@@ -1,28 +1,40 @@
import torch
import torch.nn as nn

import torch.nn.functional as F

class Model(nn.Module):
"""
Model that performs a matrix multiplication, summation, and combined scaling.
Model that performs a matrix multiplication, summation, and combined scaling,
optimized by pre-computing the combined weight vector and using torch.linalg.vecdot.
Assumes torch.compile is applied externally.
"""

def __init__(self, input_size, hidden_size, scaling_factor):
super(Model, self).__init__()
# weight is (hidden_size, input_size)
self.weight = nn.Parameter(torch.randn(hidden_size, input_size))
# Combine the division by 2 and the scaling factor into one operation
self.effective_scaling_factor = scaling_factor / 2.0

# Pre-compute the combined weight vector and scaling factor
with torch.no_grad():
summed_weight_vector = self.weight.sum(dim=0) # (input_size,)
effective_weight_vector = summed_weight_vector * (scaling_factor / 2.0) # (input_size,)
# Store as a buffer
self.register_buffer('effective_weight_vector', effective_weight_vector) # (input_size,)


def forward(self, x):
"""
Args:
x (torch.Tensor): Input tensor of shape (batch_size, input_size).
Returns:
torch.Tensor: Output tensor of shape (batch_size, hidden_size).
torch.Tensor: Output tensor of shape (batch_size, 1).
"""
# Original operations: matmul -> / 2 -> sum -> * scaling_factor
# Optimized operations: matmul -> sum -> * (scaling_factor / 2)
x = torch.matmul(x, self.weight.T)
x = torch.sum(x, dim=1, keepdim=True)
x = x * self.effective_scaling_factor
return x
# Perform the batched dot product using torch.linalg.vecdot
# x is (batch_size, input_size) -> (B, I)
# effective_weight_vector is (input_size,) -> (I)
# torch.linalg.vecdot(A (..., N), B (..., N) or (N)) -> (...)
# In our case: A is (B, I), B is (I). Result is (B,).
output = torch.linalg.vecdot(x, self.effective_weight_vector)

# Unsqueeze to match the required output shape (batch_size, 1)
return output.unsqueeze(1)
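The rewrite above rests on the identity that summing x @ W.T / 2 over the hidden dimension and scaling by s equals the dot product of x with W.sum(dim=0) * (s / 2), which is exactly what the pre-computed buffer captures. A minimal sanity check of that equivalence, using illustrative shapes rather than anything from the repository's tests, might look like:

import torch

batch_size, input_size, hidden_size, scaling_factor = 4, 8, 16, 3.0
x = torch.randn(batch_size, input_size)
weight = torch.randn(hidden_size, input_size)

# Original pipeline: matmul -> / 2 -> sum -> * scaling_factor
original = (torch.matmul(x, weight.T) / 2).sum(dim=1, keepdim=True) * scaling_factor

# Optimized pipeline: fold the sum and the scaling into one weight vector
effective = weight.sum(dim=0) * (scaling_factor / 2.0)      # (input_size,)
optimized = torch.linalg.vecdot(x, effective).unsqueeze(1)  # (batch_size, 1)

print(torch.allclose(original, optimized, atol=1e-5))  # expected: True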
52 changes: 24 additions & 28 deletions weco/api.py
@@ -17,7 +17,7 @@ def handle_api_error(e: requests.exceptions.HTTPError, console: rich.console.Con
# sys.exit(1)


def start_optimization_session(
def start_optimization_run(
console: rich.console.Console,
source_code: str,
evaluation_command: str,
@@ -29,14 +29,14 @@ def start_optimization_session(
search_policy_config: Dict[str, Any],
additional_instructions: str = None,
api_keys: Dict[str, Any] = {},
auth_headers: dict = {}, # Add auth_headers
auth_headers: dict = {},
timeout: int = 800,
) -> Dict[str, Any]:
"""Start the optimization session."""
"""Start the optimization run."""
with console.status("[bold green]Starting Optimization..."):
try:
response = requests.post(
f"{__base_url__}/sessions", # Path is relative to base_url
f"{__base_url__}/runs",
json={
"source_code": source_code,
"additional_instructions": additional_instructions,
@@ -56,24 +56,24 @@ def start_optimization_session(
return response.json()
except requests.exceptions.HTTPError as e:
handle_api_error(e, console)
sys.exit(1) # Exit if starting session fails
sys.exit(1)
except requests.exceptions.RequestException as e:
console.print(f"[bold red]Network Error starting session: {e}[/]")
console.print(f"[bold red]Network Error starting run: {e}[/]")
sys.exit(1)


def evaluate_feedback_then_suggest_next_solution(
session_id: str,
run_id: str,
execution_output: str,
additional_instructions: str = None,
api_keys: Dict[str, Any] = {},
auth_headers: dict = {}, # Add auth_headers
auth_headers: dict = {},
timeout: int = 800,
) -> Dict[str, Any]:
"""Evaluate the feedback and suggest the next solution."""
try:
response = requests.post(
f"{__base_url__}/sessions/{session_id}/suggest", # Path is relative to base_url
f"{__base_url__}/runs/{run_id}/suggest",
json={
"execution_output": execution_output,
"additional_instructions": additional_instructions,
@@ -93,13 +93,13 @@ def evaluate_feedback_then_suggest_next_solution(
raise # Re-raise the exception


def get_optimization_session_status(
session_id: str, include_history: bool = False, auth_headers: dict = {}, timeout: int = 800
def get_optimization_run_status(
run_id: str, include_history: bool = False, auth_headers: dict = {}, timeout: int = 800
) -> Dict[str, Any]:
"""Get the current status of the optimization session."""
"""Get the current status of the optimization run."""
try:
response = requests.get(
f"{__base_url__}/sessions/{session_id}", # Path is relative to base_url
f"{__base_url__}/runs/{run_id}",
params={"include_history": include_history},
headers=auth_headers,
timeout=timeout,
@@ -115,48 +115,44 @@ def get_optimization_session_status(


def send_heartbeat(
session_id: str,
run_id: str,
auth_headers: dict = {},
timeout: int = 10, # Shorter timeout for non-critical heartbeat
timeout: int = 10,
) -> bool:
"""Send a heartbeat signal to the backend."""
try:
response = requests.put(f"{__base_url__}/sessions/{session_id}/heartbeat", headers=auth_headers, timeout=timeout)
response.raise_for_status() # Raises HTTPError for bad responses (4xx or 5xx)
response = requests.put(f"{__base_url__}/runs/{run_id}/heartbeat", headers=auth_headers, timeout=timeout)
response.raise_for_status()
return True
except requests.exceptions.HTTPError as e:
# Log non-critical errors like 409 Conflict (session not running)
if e.response.status_code == 409:
print(f"Heartbeat ignored: Session {session_id} is not running.", file=sys.stderr)
print(f"Heartbeat ignored: Run {run_id} is not running.", file=sys.stderr)
else:
print(f"Heartbeat failed for session {session_id}: HTTP {e.response.status_code}", file=sys.stderr)
# Don't exit, just report failure
print(f"Heartbeat failed for run {run_id}: HTTP {e.response.status_code}", file=sys.stderr)
return False
except requests.exceptions.RequestException as e:
# Network errors are also non-fatal for heartbeats
print(f"Heartbeat network error for session {session_id}: {e}", file=sys.stderr)
print(f"Heartbeat network error for run {run_id}: {e}", file=sys.stderr)
return False


def report_termination(
session_id: str,
run_id: str,
status_update: str,
reason: str,
details: Optional[str] = None,
auth_headers: dict = {},
timeout: int = 30, # Reasonably longer timeout for important termination message
timeout: int = 30,
) -> bool:
"""Report the termination reason to the backend."""
try:
response = requests.post(
f"{__base_url__}/sessions/{session_id}/terminate",
f"{__base_url__}/runs/{run_id}/terminate",
json={"status_update": status_update, "termination_reason": reason, "termination_details": details},
headers=auth_headers,
timeout=timeout,
)
response.raise_for_status()
return True
except requests.exceptions.RequestException as e:
# Log failure, but don't prevent CLI exit
print(f"Warning: Failed to report termination to backend for session {session_id}: {e}", file=sys.stderr)
print(f"Warning: Failed to report termination to backend for run {run_id}: {e}", file=sys.stderr)
return False
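For reference, a hypothetical caller-side sketch of the renamed helpers (the import path, run_id, auth_headers, and the status/reason strings below are placeholders, not taken from the CLI code):

from weco import api  # assumed module path for weco/api.py

run_id = "run_123"  # placeholder; normally returned by start_optimization_run
auth_headers = {"Authorization": "Bearer <token>"}  # placeholder

# Keep the run alive, poll its status, then report why it stopped.
api.send_heartbeat(run_id, auth_headers=auth_headers)
status = api.get_optimization_run_status(run_id, include_history=True, auth_headers=auth_headers)
api.report_termination(run_id, status_update="completed", reason="user_stop", auth_headers=auth_headers)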