Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ jobs:
# unmarked tests: adapter, execution engine, etc.
- id: others
display-name: Others
pytest-mark: 'not store and not agentops and not weave and not llmproxy and not utils'
pytest-mark: 'not store and not agentops and not weave and not llmproxy and not utils and not gepa'
env:
- python-version: '3.10'
setup-script: 'legacy'
Expand Down
9 changes: 8 additions & 1 deletion agentlightning/algorithm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@

if TYPE_CHECKING:
from .apo import APO as APOType
from .gepa import GEPA as GEPAType
from .verl import VERL as VERLType

__all__ = ["Algorithm", "algo", "FastAlgorithm", "Baseline", "APO", "VERL"]
__all__ = ["Algorithm", "algo", "FastAlgorithm", "Baseline", "APO", "VERL", "GEPA"]

# Shortcuts for usages like algo.APO(...)

Expand All @@ -27,3 +28,9 @@ def VERL(*args: Any, **kwargs: Any) -> VERLType:
from .verl import VERL as VERLImplementation

return VERLImplementation(*args, **kwargs)


def GEPA(*args: Any, **kwargs: Any) -> GEPAType:
from .gepa import GEPA as GEPAImplementation

return GEPAImplementation(*args, **kwargs)
6 changes: 6 additions & 0 deletions agentlightning/algorithm/gepa/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

from .config import GEPAConfig
from .interface import GEPA

__all__ = ["GEPA", "GEPAConfig"]
139 changes: 139 additions & 0 deletions agentlightning/algorithm/gepa/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright (c) Microsoft. All rights reserved.

"""Logging callback for GEPA optimization within Agent Lightning."""

from __future__ import annotations

import logging
from typing import Any, Dict

logger = logging.getLogger(__name__)


class LightningGEPACallback:
"""GEPA callback that logs optimization progress via the standard logger.

Implements the hooks from GEPA's ``GEPACallback`` protocol, forwarding
each event to Python's logging system. All methods are intentionally
defensive—exceptions are caught and logged so that a callback failure
never aborts the optimization run.

When W&B experiment tracking is enabled (via ``use_wandb``), this callback
also logs per-iteration progress metrics (pareto front aggregate, best
candidate score, number of candidates) so that optimization progress is
visible across every iteration—not only when a new candidate is accepted.
"""

def __init__(self, *, use_wandb: bool = False) -> None:
self._use_wandb = use_wandb

def on_optimization_start(self, event: Dict[str, Any]) -> None:
logger.info("GEPA optimization started")

def on_optimization_end(self, event: Dict[str, Any]) -> None:
logger.info("GEPA optimization ended")

def on_iteration_start(self, event: Dict[str, Any]) -> None:
iteration = event.get("iteration", "?")
logger.info("GEPA iteration %s started", iteration)

def on_iteration_end(self, event: Dict[str, Any]) -> None:
iteration = event.get("iteration", "?")
logger.info("GEPA iteration %s ended", iteration)
if self._use_wandb:
self._log_iteration_metrics(event)

def on_candidate_selected(self, event: Dict[str, Any]) -> None:
logger.debug("GEPA candidate selected: %s", event)

def on_candidate_accepted(self, event: Dict[str, Any]) -> None:
logger.info("GEPA candidate accepted: %s", event)

def on_candidate_rejected(self, event: Dict[str, Any]) -> None:
logger.debug("GEPA candidate rejected: %s", event)

def on_evaluation_start(self, event: Dict[str, Any]) -> None:
logger.debug("GEPA evaluation started")

def on_evaluation_end(self, event: Dict[str, Any]) -> None:
logger.debug("GEPA evaluation ended")

def on_evaluation_skipped(self, event: Dict[str, Any]) -> None:
logger.debug("GEPA evaluation skipped: %s", event)

def on_valset_evaluated(self, event: Dict[str, Any]) -> None:
logger.debug("GEPA valset evaluated: %s", event)

def on_reflective_dataset_built(self, event: Dict[str, Any]) -> None:
logger.debug("GEPA reflective dataset built")

def on_proposal_start(self, event: Dict[str, Any]) -> None:
logger.debug("GEPA proposal started")

def on_proposal_end(self, event: Dict[str, Any]) -> None:
logger.debug("GEPA proposal ended")

def on_merge_attempted(self, event: Dict[str, Any]) -> None:
logger.debug("GEPA merge attempted")

def on_merge_accepted(self, event: Dict[str, Any]) -> None:
logger.info("GEPA merge accepted")

def on_merge_rejected(self, event: Dict[str, Any]) -> None:
logger.debug("GEPA merge rejected")

def on_pareto_front_updated(self, event: Dict[str, Any]) -> None:
logger.debug("GEPA Pareto front updated: %s", event)

def on_state_saved(self, event: Dict[str, Any]) -> None:
logger.debug("GEPA state saved")

def on_budget_updated(self, event: Dict[str, Any]) -> None:
remaining = event.get("metric_calls_remaining", "?")
logger.info("GEPA budget updated, remaining: %s", remaining)

def on_error(self, event: Dict[str, Any]) -> None:
error = event.get("exception", "unknown")
logger.error("GEPA error: %s", error)

def on_minibatch_sampled(self, event: Dict[str, Any]) -> None:
logger.debug("GEPA minibatch sampled: %s", event)

def _log_iteration_metrics(self, event: Dict[str, Any]) -> None:
"""Log per-iteration progress metrics to W&B.

Extracts running totals from ``GEPAState`` so that metrics like
pareto front aggregate and best candidate score are logged at every
iteration, not only when a new candidate is accepted.
"""
try:
import wandb

if wandb.run is None:
return

state = event.get("state")
if state is None:
return

iteration: int = event.get("iteration", 0)

# Pareto front aggregate (average of per-example best scores)
pareto_scores = list(state.pareto_front_valset.values())
pareto_agg = sum(pareto_scores) / len(pareto_scores) if pareto_scores else 0.0

# Best single-candidate aggregate score on valset
val_scores = state.program_full_scores_val_set
best_candidate_score = max(val_scores) if val_scores else 0.0

metrics: Dict[str, Any] = {
"agl/pareto_front_agg": pareto_agg,
"agl/best_candidate_valset_score": best_candidate_score,
"agl/num_candidates": len(state.program_candidates),
"agl/total_metric_calls": state.total_num_evals,
"agl/proposal_accepted": event.get("proposal_accepted", False),
}

wandb.log(metrics, step=iteration)
except Exception as e:
logger.debug("Failed to log iteration metrics to W&B: %s", e)
84 changes: 84 additions & 0 deletions agentlightning/algorithm/gepa/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) Microsoft. All rights reserved.

"""Configuration for the GEPA algorithm integration."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, Literal, Optional


@dataclass
class GEPAConfig:
"""Configuration for the GEPA evolutionary prompt optimizer.

Groups GEPA optimizer knobs, rollout execution parameters, and
reflection LLM configuration into a single configuration object.

Args:
max_metric_calls: Maximum number of evaluation calls before stopping.
Maps to GEPA's budget constraint. ``None`` means unlimited.
candidate_selection_strategy: Strategy for selecting candidates.
``"pareto"`` tracks per-instance best, ``"current_best"`` keeps the
single highest scorer, ``"epsilon_greedy"`` adds exploration.
frontier_type: Pareto frontier granularity. ``"instance"`` tracks
per-example bests; ``"aggregate"`` tracks only the overall best.
reflection_minibatch_size: Number of examples sampled for each
reflection step. ``None`` lets GEPA choose automatically.
module_selector: Strategy for choosing which component to update.
``"round_robin"`` cycles through components sequentially.
seed: Random seed for reproducibility.
use_merge: Whether to attempt merging top candidates.
max_merge_invocations: Maximum number of merge attempts when
``use_merge`` is enabled.
skip_perfect_score: Skip further evaluation when a candidate
achieves perfect score.
perfect_score: Value considered a perfect score.
display_progress_bar: Show a progress bar during optimization.
raise_on_exception: Raise exceptions from GEPA instead of logging.
rollout_batch_timeout: Maximum seconds to wait for a rollout batch
to complete before scoring incomplete rollouts as 0.0.
rollout_poll_interval: Seconds between polling the store for
rollout completion.
reflection_model: Model identifier passed to ``litellm.completion``
for GEPA's reflection/proposal calls. When ``None``, GEPA uses
its own default model.
"""

# GEPA optimizer knobs
max_metric_calls: Optional[int] = None
candidate_selection_strategy: Literal["pareto", "current_best", "epsilon_greedy"] = "pareto"
frontier_type: Literal["instance", "aggregate"] = "instance"
reflection_minibatch_size: Optional[int] = None
module_selector: str = "round_robin"
seed: int = 0
use_merge: bool = False
max_merge_invocations: int = 5
skip_perfect_score: bool = True
perfect_score: float = 1.0
display_progress_bar: bool = False
raise_on_exception: bool = True

# Rollout execution parameters
rollout_batch_timeout: float = 3600.0
rollout_poll_interval: float = 2.0

# Reflection LLM configuration
reflection_model: Optional[str] = None
reflection_model_kwargs: Dict[str, Any] = field(
default_factory=lambda: {}
) # pyright: ignore[reportUnknownVariableType]
"""Extra keyword arguments forwarded to ``litellm.completion`` for reflection
calls. Useful for passing authentication parameters such as
``azure_ad_token_provider`` for Azure Entra ID."""

# Experiment tracking
use_wandb: bool = False
"""Enable Weights & Biases experiment tracking during optimization."""
wandb_api_key: Optional[str] = None
"""W&B API key. When ``None``, relies on ``WANDB_API_KEY`` env var or prior ``wandb login``."""
wandb_init_kwargs: Dict[str, Any] = field(default_factory=lambda: {}) # pyright: ignore[reportUnknownVariableType]
"""Extra keyword arguments forwarded to ``wandb.init()`` (e.g. ``project``, ``name``, ``tags``)."""

# Extra kwargs forwarded to gepa.optimize()
extra_kwargs: Dict[str, Any] = field(default_factory=lambda: {}) # pyright: ignore[reportUnknownVariableType]
Loading
Loading