Trust Region Policy Optimization (TRPO) is a policy gradient method that guarantees monotonic improvement through constrained optimization. Introduced by Schulman et al. in 2015, TRPO revolutionized policy optimization by providing strong theoretical guarantees while achieving excellent empirical performance. Though more complex than its successor PPO, TRPO remains important for understanding trust region methods and in applications where guaranteed improvement is critical.
Key Innovation: TRPO solves a fundamental problem in policy gradient methods: how to take the largest possible improvement step without causing performance collapse. It does this through constrained optimization with a KL divergence constraint:
maximize E[π_new(a|s)/π_old(a|s) * A(s,a)]
subject to KL(π_old || π_new) ≤ δ
This constraint defines a "trust region" where we trust our policy improvement.
Historical Context:
- Builds on natural policy gradient (Kakade, 2002)
- Introduces practical trust region method for deep RL
- Provided first monotonic improvement guarantee for neural policies
- Achieved breakthrough results on continuous control
- Inspired PPO (simpler approximation of TRPO)
- Foundation for modern policy optimization
Key Advantages:
- Guaranteed monotonic improvement: Theoretically proven
- Large, safe policy updates: As big as possible within trust region
- Superior stability: Rarely diverges or collapses
- Strong performance: SOTA at time of introduction
- Principled approach: Solid theoretical foundation
- Works across domains: Both discrete and continuous actions
Improvements over Vanilla Policy Gradients:
- Monotonic improvement (vs potential degradation)
- Larger step sizes (vs conservative small steps)
- Better sample efficiency (vs high variance)
- More stable (vs training instability)
- Theoretical guarantees (vs empirical tuning)
Ideal For:
- When guaranteed improvement is critical
- Safety-critical applications (robotics)
- Research on trust region methods
- Environments where catastrophic failures must be avoided
- Understanding theoretical foundations of PPO
- When computational cost is acceptable
Avoid When:
- Need simplicity (use PPO instead)
- Computational efficiency critical (PPO is faster)
- Large-scale distributed training (PPO scales better)
- Prefer practical over theoretical guarantees
TRPO vs PPO:
- TRPO: Theoretical guarantees, more complex, slower
- PPO: Simpler, faster, no guarantees but works as well
- Modern recommendation: Use PPO unless you need TRPO's guarantees
TRPO's Legacy: While PPO has largely superseded TRPO in practice, TRPO remains important for:
- Understanding trust region concepts
- Research on constrained optimization
- Applications requiring proven guarantees
- Theoretical analysis of policy optimization
Why Study TRPO?
- Understand the theory behind PPO
- Learn about natural gradients and Fisher information
- Appreciate trade-offs between theory and practice
- Foundation for safe RL and constrained methods
Goal: Improve policy π to maximize expected return J(π)
Naive approach (policy gradient):
θ_new = θ_old + α ∇_θ J(θ)
Problems:
- Step size α is hard to tune
- Too large → performance collapse
- Too small → slow learning
- No guarantees of improvement
TRPO's approach: Take the largest step that maintains improvement guarantee.
Key Insight: Policy performance changes predictably within a small region of policy space.
Trust region: A neighborhood around current policy where we "trust" our local approximation.
Constraint: Keep new policy close to old policy
KL(π_old || π_new) ≤ δ
Intuition:
- Small changes → predictable outcomes
- Large changes → unpredictable, dangerous
- Trust region balances progress and safety
Analogy: Hiking in Fog
- You can see clearly 10 meters ahead (trust region)
- Beyond that is foggy and uncertain
- Take largest safe step within visible range
- Repeat: look around, step, repeat
TRPO's Core Theorem:
Define the surrogate advantage:
L_π_old(π) = E_π_old[π(a|s)/π_old(a|s) * A_π_old(s,a)]
Then:
J(π_new) ≥ L_π_old(π_new) - C * KL_max(π_old, π_new)
Where:
- C = 4εγ/(1-γ)²
- ε = max_s |E_a~π[A(s,a)]|
- KL_max = max_s KL(π_old(·|s) || π_new(·|s))
Implication: If we constrain KL divergence, we guarantee improvement!
Proof Sketch:
- True objective relates to surrogate via KL divergence
- Bounding KL → bounding difference between J and L
- Maximizing L while constraining KL → guaranteed improvement in J
Standard gradient:
∇_θ J(θ) = E[∇_θ log π_θ(a|s) A(s,a)]
Problem: Gradient in parameter space, not policy space
- Same parameter change → different policy change depending on parameterization
- Not invariant to reparameterization
Natural gradient:
∇̃_θ J(θ) = F^(-1) ∇_θ J(θ)
Where F is the Fisher information matrix:
F = E[∇_θ log π_θ(a|s) ∇_θ log π_θ(a|s)ᵀ]
Why "natural"?
- Gradient in distribution space (not parameter space)
- Invariant to reparameterization
- Measures distance in policy space via KL divergence
Connection to TRPO: Natural gradient ≈ TRPO's constrained optimization direction!
TRPO solves:
maximize_θ E[π_θ(a|s)/π_{θ_old}(a|s) * A(s,a)]
subject to E_s[KL(π_{θ_old}(·|s) || π_θ(·|s))] ≤ δ
Lagrangian:
L(θ, λ) = E[π_θ/π_old * A] - λ(E[KL(π_old || π_θ)] - δ)
First-order approximation:
maximize g^T Δθ
subject to 1/2 Δθ^T F Δθ ≤ δ
Where:
- g = ∇_θ L(θ): Policy gradient
- F: Fisher information matrix
- Δθ = θ - θ_old: Parameter change
Solution (by Lagrangian):
Δθ* = √(2δ/g^T F^(-1) g) * F^(-1) g
This is the natural gradient scaled to constraint boundary!
Challenge: Computing F^(-1) g is expensive
- F is n×n where n is # parameters (millions!)
- Inverting F directly: O(n³) - infeasible
- Storing F: O(n²) - infeasible
TRPO's Solution: Conjugate Gradient
Key insight: Don't need to compute F^(-1) explicitly
- Only need to compute F^(-1) g
- Can compute Fisher-vector products: F·v
- Conjugate gradient solves F·x = g iteratively
- Only requires F·v, not F itself!
Computing F·v efficiently:
F·v = ∇_θ [(∇_θ KL(π_old || π_θ))^T v]
This is a Hessian-vector product, computable with automatic differentiation!
Final Algorithm:
- Compute policy gradient g
- Solve F·x = g via conjugate gradient → get x = F^(-1) g (natural gradient)
- Scale x to trust region boundary
- Line search to ensure improvement and constraint satisfaction
TRPO uses GAE for variance reduction:
Problem: High variance in advantage estimates
A(s,a) = Q(s,a) - V(s)
Solution: GAE with parameter λ ∈ [0,1]
A^GAE(s,a) = ∑_{l=0}^∞ (γλ)^l δ_{t+l}
where δ_t = r_t + γV(s_{t+1}) - V(s_t)
λ controls bias-variance trade-off:
- λ=0: Low variance, high bias (1-step TD)
- λ=1: High variance, low bias (Monte Carlo)
- λ=0.95-0.99: Sweet spot (used in practice)
Why GAE helps:
- Reduces variance in policy gradient
- Exponentially-weighted average of n-step advantages
- Crucial for TRPO's empirical success
Surrogate objective (to maximize):
L(θ) = E_{s,a~π_old}[π_θ(a|s)/π_old(a|s) * A^π_old(s,a)]
= E_{s,a~π_old}[r(θ) * A(s,a)]
Where r(θ) = π_θ(a|s)/π_old(a|s) is the probability ratio.
Intuition:
- If A(s,a) > 0: increase probability of action a
- If A(s,a) < 0: decrease probability of action a
- Weight by advantage magnitude
Average KL divergence constraint:
E_s~π_old[KL(π_old(·|s) || π_θ(·|s))] ≤ δ
KL divergence for Gaussian policies:
KL(N(μ₁,Σ₁) || N(μ₂,Σ₂)) = 1/2 [tr(Σ₂^(-1)Σ₁) + (μ₂-μ₁)ᵀΣ₂^(-1)(μ₂-μ₁) - k + ln(det(Σ₂)/det(Σ₁))]
For diagonal covariance (common):
KL = 1/2 ∑_i [σ₁,i²/σ₂,i² + (μ₂,i-μ₁,i)²/σ₂,i² - 1 + ln(σ₂,i²/σ₁,i²)]
Typical constraint value: δ = 0.01
Continuous actions (Gaussian policy):
π_θ(a|s) = N(a; μ_θ(s), Σ_θ(s))
Usually diagonal covariance:
μ_θ(s), log σ_θ(s) = NN_θ(s)
Σ_θ = diag(σ_θ(s)²)
Discrete actions (Softmax policy):
π_θ(a|s) = exp(f_θ(s,a)) / ∑_a' exp(f_θ(s,a'))
TD error:
δ_t = r_t + γV_φ(s_{t+1}) - V_φ(s_t)
GAE advantage:
A_t^GAE(γ,λ) = ∑_{l=0}^∞ (γλ)^l δ_{t+l}
= δ_t + γλδ_{t+1} + (γλ)²δ_{t+2} + ...
Recursive form (for implementation):
A_t = δ_t + γλ A_{t+1}
Returns (for value function update):
R_t = A_t + V(s_t)
Definition:
F = E_{s,a~π_θ}[∇_θ log π_θ(a|s) ∇_θ log π_θ(a|s)ᵀ]
Hessian of KL divergence:
F = ∇_θ² E_s[KL(π_old(·|s) || π_θ(·|s))]|_{θ=θ_old}
Fisher-vector product (for conjugate gradient):
F·v = E_{s~π_old}[∇_θ (∇_θ KL(π_old(·|s) || π_θ(·|s))ᵀ v)]
Computed efficiently via automatic differentiation.
Goal: Solve F·x = g for x = F^(-1)g
Algorithm:
Initialize: x₀ = 0, r₀ = g, p₀ = r₀
For i = 0, 1, 2, ..., max_iters:
α_i = (r_i^T r_i) / (p_i^T F·p_i)
x_{i+1} = x_i + α_i p_i
r_{i+1} = r_i - α_i F·p_i
β_i = (r_{i+1}^T r_{i+1}) / (r_i^T r_i)
p_{i+1} = r_{i+1} + β_i p_i
if ||r_{i+1}|| < tolerance:
break
Output: x ≈ F^(-1)g (natural gradient direction)
Complexity: O(k·n) where k=10-20 iterations, n=parameters
After computing natural gradient direction Δθ:
- Compute step size:
β = √(2δ / (Δθ^T F Δθ))
- Line search:
For j = 0, 1, 2, ..., max_backtracks:
θ_new = θ_old + α^j β Δθ (α ∈ (0,1), typically 0.5-0.8)
Check two conditions:
a) KL(π_old || π_new) ≤ δ (constraint satisfied)
b) L(θ_new) > L(θ_old) (improvement achieved)
If both satisfied:
Accept θ_new
Break
- If line search fails: Keep θ_old (no update this iteration)
Backtracking coefficient: α = 0.8 (typical)
Loss function:
L_V(φ) = E[(V_φ(s) - R)²]
Where R is the empirical return (from GAE).
Update (multiple epochs):
For each epoch:
φ ← φ - α_V ∇_φ L_V(φ)
Typically: 5-10 epochs per TRPO iteration
Given: Batch of trajectories with states, actions, rewards
- Compute advantages:
For each trajectory:
Compute TD errors: δ_t = r_t + γV(s_{t+1}) - V(s_t)
Compute GAE: A_t = ∑_{l=0}^∞ (γλ)^l δ_{t+l}
Normalize: A ← (A - mean(A)) / (std(A) + ε)
- Compute policy gradient:
g = ∇_θ E[π_θ(a|s)/π_old(a|s) * A(s,a)]
- Compute natural gradient via conjugate gradient:
Solve F·x = g → x = F^(-1)g
- Compute step size:
β = √(2δ / (x^T F x))
- Line search with backtracking:
Find θ_new = θ_old + α^j β x satisfying constraints
- Update value function:
R = A + V(s)
For epochs: φ ← φ - α_V ∇_φ (V_φ(s) - R)²
Imagine hiking in mountainous terrain:
Without trust region (vanilla PG):
"Take a step in direction of steepest ascent"
→ Might walk off a cliff!
→ Terrain might be deceptive
With trust region (TRPO):
"Look at nearby terrain (trust region)"
"Take largest step upward that stays in visible range"
"Once you move, look around again"
→ Safe, predictable progress
→ Guaranteed to go up (or stay same)
Key insight: Local information is trustworthy, distant information is unreliable.
Why not Euclidean distance in parameter space?
||θ_new - θ_old|| ≤ δ # bad choice
Problem: Parameter distance ≠ policy distance
- Same ||Δθ|| can mean very different policy changes
- Depends on network architecture, initialization
- Not invariant to reparameterization
KL divergence measures policy distance:
KL(π_old || π_new) ≤ δ # good choice
Benefits:
- Measures actual distribution distance
- Invariant to reparameterization
- Directly relates to performance change
- Principled, well-motivated
Standard gradient: "Direction of steepest ascent in parameter space"
Problem: Parameter space geometry is arbitrary
- Neural networks can be reparameterized
- Steepest in parameters ≠ steepest in policies
Natural gradient: "Direction of steepest ascent in policy space"
Intuition:
- Accounts for geometry of policy space
- Uses Fisher information as metric
- Results in more efficient updates
Analogy: Walking on Earth
- Standard gradient: Walk north in Cartesian coordinates
- Natural gradient: Walk north on Earth's surface (spherical geometry)
- Natural gradient respects the actual geometry!
Problem: Computing F^(-1)g requires inverting huge matrix
Naive solution:
F^(-1)g # O(n³) time, O(n²) space - impossible!
Conjugate gradient solution:
Solve Fx = g iteratively
Only need F·v (matrix-vector product)
O(k·n) time where k ≈ 10
Key insight: Don't invert the matrix, solve the system iteratively!
Analogy: Finding a Hidden Treasure
- Naive: Map entire landscape (expensive)
- CG: Ask directional questions ("warmer/colder"), converge to answer (efficient)
Why line search after computing direction?
Problem: Trust region is approximate
- Used first-order approximation
- Conjugate gradient is approximate
- Actual constraint might be violated
Line search ensures:
- KL constraint satisfied (stay in trust region)
- Objective improved (actually go uphill)
Backtracking: Start with full step, reduce until conditions met
Analogy: Stepping on Ice
- Plan a step (natural gradient direction)
- Try the step carefully
- If you slip (constraint violated), try smaller step
- Keep trying until safe step found
Advantage estimation trade-off:
1-step (λ=0):
A(s,a) = r + γV(s') - V(s)
→ Low variance, high bias
→ "Porridge too cold"
Monte Carlo (λ=1):
A(s,a) = ∑_{t'≥t} γ^{t'-t} r_{t'} - V(s)
→ High variance, low bias
→ "Porridge too hot"
GAE (λ=0.95):
A(s,a) = ∑_{l=0}^∞ (γλ)^l δ_{t+l}
→ Balanced variance and bias
→ "Porridge just right!"
Key insight: Exponential weighting balances immediate and long-term effects.
What it means:
J(π_new) ≥ J(π_old) (with high probability)
Why it matters:
- Never catastrophically worse
- Stable learning curves
- Safer for real-world deployment
What it doesn't mean:
- Optimal updates (only guaranteed improvement)
- Fast convergence (might be slow but steady)
- No local optima (still can get stuck)
Analogy: Rock Climbing
- Always maintain at least current hold before releasing
- Never worse than previous position
- Slow but safe progress upward
Combines multiple insights:
- Natural gradients → efficient direction
- Trust region → safe step size
- Conjugate gradient → practical computation
- Line search → ensure guarantees
- GAE → variance reduction
Each component addresses a specific challenge:
- Natural gradient: where to go
- Trust region: how far to go
- Conjugate gradient: how to compute
- Line search: how to ensure safety
- GAE: how to estimate quality
Result: Robust, stable, theoretically-grounded algorithm!
Actor Network (Gaussian Policy for Continuous Actions):
Input: state (n_states,)
→ FC(256) + Tanh
→ FC(256) + Tanh
→ Split into two heads:
- Mean head: FC(n_actions)
- Log std: Parameter (state-independent)
→ Sample: a ~ N(mean, exp(log_std))
→ Output: actionCritic Network (Value Function):
Input: state (n_states,)
→ FC(256) + Tanh
→ FC(256) + Tanh
→ FC(1)
→ Output: state value V(s)Key architecture choices:
- Tanh activations (not ReLU) for policy network
- State-independent log_std (simpler, works well)
- Separate value network (not shared backbone)
- Smaller networks than modern standards (256 hidden)
Standard hyperparameters (robust across tasks):
# Trust region
max_kl = 0.01 # KL divergence constraint
damping = 0.1 # Damping for conjugate gradient
# Advantage estimation
gamma = 0.99 # Discount factor
lambda_ = 0.95 # GAE lambda
# Optimization
cg_iters = 10 # Conjugate gradient iterations
backtrack_iters = 10 # Line search backtrack iterations
backtrack_coeff = 0.8 # Backtracking coefficient
value_lr = 1e-3 # Value function learning rate
value_iters = 5 # Value function update epochs
# Training
batch_size = 5000 # Timesteps per batch (not transitions!)Hyperparameter sensitivity:
- Very robust: max_kl (0.01 works across most tasks)
- Moderate: lambda_ (0.95-0.99 all good)
- Less important: cg_iters (10-20), backtrack_iters (10-15)
- Task-specific: batch_size (larger is better but more expensive)
TRPO uses on-policy batch collection:
while not converged:
# 1. Collect batch of trajectories using current policy
trajectories = []
timesteps = 0
while timesteps < batch_size:
trajectory = collect_trajectory(env, policy)
trajectories.append(trajectory)
timesteps += len(trajectory)
# 2. Compute advantages with GAE
advantages, returns = compute_gae(trajectories, value_function)
# 3. Update value function (multiple epochs)
for _ in range(value_iters):
update_value_function(states, returns)
# 4. Compute policy gradient
policy_gradient = compute_policy_gradient(states, actions, advantages)
# 5. Compute natural gradient via conjugate gradient
natural_gradient = conjugate_gradient(policy_gradient)
# 6. Line search for valid step
new_policy = line_search(policy, natural_gradient)
# 7. Update policy
policy = new_policyKey difference from off-policy methods:
- Collect batch, update once, discard data (on-policy)
- No replay buffer
- Larger batch sizes needed
Implementation:
def compute_gae(trajectories, value_function, gamma=0.99, lambda_=0.95):
advantages = []
returns = []
for traj in trajectories:
states = traj['states']
rewards = traj['rewards']
dones = traj['dones']
# Compute values
values = value_function(states)
next_values = np.concatenate([values[1:], [0]]) # V(s_T) = 0
# Compute TD errors
deltas = rewards + gamma * (1 - dones) * next_values - values
# Compute GAE advantages (backward pass)
adv = np.zeros_like(rewards)
gae = 0
for t in reversed(range(len(rewards))):
gae = deltas[t] + gamma * lambda_ * (1 - dones[t]) * gae
adv[t] = gae
# Compute returns
ret = adv + values
advantages.append(adv)
returns.append(ret)
# Normalize advantages
advantages = np.concatenate(advantages)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
returns = np.concatenate(returns)
return advantages, returnsCritical details:
- Backward pass for GAE computation
- Proper handling of terminal states (dones)
- Advantage normalization (reduces variance)
Computing F·v efficiently:
def fisher_vector_product(policy, states, old_policy_dist, vector, damping=0.1):
"""
Compute (F + damping*I) * vector where F is Fisher information matrix.
Uses automatic differentiation to compute Hessian-vector product.
"""
# Compute KL divergence
new_policy_dist = policy.get_distribution(states)
kl = kl_divergence(old_policy_dist, new_policy_dist).mean()
# Compute gradient of KL w.r.t. policy parameters
kl_grad = torch.autograd.grad(kl, policy.parameters(), create_graph=True)
kl_grad = torch.cat([grad.view(-1) for grad in kl_grad])
# Compute gradient-vector product
gvp = (kl_grad * vector).sum()
# Compute Hessian-vector product (second derivative)
hvp = torch.autograd.grad(gvp, policy.parameters())
hvp = torch.cat([grad.view(-1) for grad in hvp])
# Add damping for numerical stability
return hvp + damping * vectorKey insights:
- Use automatic differentiation for Hessian-vector product
- Damping (0.1) improves numerical stability
- No need to compute or store full Hessian!
Solving F·x = g:
def conjugate_gradient(fvp_func, b, max_iters=10, tol=1e-10):
"""
Solve F*x = b using conjugate gradient.
Args:
fvp_func: Function that computes F*v for any vector v
b: Target vector (policy gradient)
max_iters: Maximum CG iterations
tol: Convergence tolerance
Returns:
x: Solution (approximately F^(-1)*b, the natural gradient)
"""
x = torch.zeros_like(b)
r = b.clone()
p = r.clone()
rdotr = torch.dot(r, r)
for i in range(max_iters):
# Compute F*p
Ap = fvp_func(p)
# Compute step size
alpha = rdotr / torch.dot(p, Ap)
# Update solution
x += alpha * p
# Update residual
r -= alpha * Ap
# Check convergence
new_rdotr = torch.dot(r, r)
if new_rdotr < tol:
break
# Compute next direction
beta = new_rdotr / rdotr
p = r + beta * p
rdotr = new_rdotr
return xTypical: Converges in 10-20 iterations
Ensuring constraints are satisfied:
def line_search(policy, old_policy, natural_gradient, states, actions, advantages,
max_kl=0.01, max_backtracks=10, backtrack_coeff=0.8):
"""
Backtracking line search to find largest step satisfying constraints.
"""
# Compute old policy loss
old_loss = compute_policy_loss(old_policy, states, actions, advantages)
old_params = get_flat_params(policy)
# Compute step size: sqrt(2*delta / (x^T F x))
fvp = fisher_vector_product(policy, states, old_policy, natural_gradient)
shs = torch.dot(natural_gradient, fvp)
step_size = torch.sqrt(2 * max_kl / (shs + 1e-8))
full_step = step_size * natural_gradient
# Backtracking line search
for j in range(max_backtracks):
# Try step with backtracking
new_params = old_params + (backtrack_coeff ** j) * full_step
set_flat_params(policy, new_params)
# Check KL constraint
new_policy_dist = policy.get_distribution(states)
kl = kl_divergence(old_policy, new_policy_dist).mean()
if kl > max_kl:
continue # KL too large, try smaller step
# Check improvement
new_loss = compute_policy_loss(policy, states, actions, advantages)
if new_loss < old_loss: # Improvement! (negative loss)
print(f"Line search succeeded at iteration {j}")
return True
# Line search failed, restore old parameters
set_flat_params(policy, old_params)
print("Line search failed, keeping old policy")
return FalseKey components:
- Compute maximum step size to constraint boundary
- Try step, check constraints
- Backtrack if violated
- Accept first valid step
❌ Wrong GAE computation:
# Wrong: forward pass
for t in range(len(rewards)):
gae = delta[t] + gamma * lambda_ * gae
# Right: backward pass
for t in reversed(range(len(rewards))):
gae = delta[t] + gamma * lambda_ * gae❌ Forgetting advantage normalization:
# Wrong: use raw advantages
loss = (pi_new / pi_old * advantages).mean()
# Right: normalize first
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
loss = (pi_new / pi_old * advantages).mean()❌ Not detaching old policy:
# Wrong: gradients flow through old policy
ratio = pi_new(a|s) / pi_old(a|s)
# Right: detach old policy
ratio = pi_new(a|s) / pi_old(a|s).detach()❌ Wrong Fisher-vector product:
# Wrong: using gradient, not Hessian
fvp = kl_grad * vector
# Right: Hessian-vector product
gvp = (kl_grad * vector).sum()
fvp = torch.autograd.grad(gvp, parameters)❌ Insufficient batch size:
# Wrong: too small (like off-policy algorithms)
batch_size = 256
# Right: large on-policy batches
batch_size = 5000 # timesteps, not transitions!The TRPO implementation in Nexus can be found at /nexus/models/rl/trpo.py.
1. Actor Network (Gaussian Policy)
class TRPOActor(nn.Module):
"""Actor network outputting Gaussian policy parameters."""
def __init__(self, state_dim, action_dim, hidden_dim=256):
super().__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
)
self.mean_head = nn.Linear(hidden_dim, action_dim)
# State-independent log_std (learnable parameter)
self.log_std = nn.Parameter(torch.zeros(action_dim))
def forward(self, state):
"""Output policy distribution parameters."""
features = self.network(state)
mean = self.mean_head(features)
log_std = self.log_std.expand_as(mean)
return mean, log_std
def get_distribution(self, state):
"""Get policy distribution."""
mean, log_std = self.forward(state)
std = log_std.exp()
return torch.distributions.Normal(mean, std)Key points:
- Tanh activations for stable policy
- State-independent std (simpler, works well)
- Returns distribution for sampling and probability computation
2. Critic Network (Value Function)
class TRPOCritic(nn.Module):
"""Critic network for state value estimation."""
def __init__(self, state_dim, hidden_dim=256):
super().__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 1),
)
def forward(self, state):
"""Output state value estimate."""
return self.network(state)Simple value network, separate from policy.
3. GAE Computation
def compute_gae(self, rewards, values, dones, next_values):
"""
Compute Generalized Advantage Estimation.
Args:
rewards: Rewards [T]
values: State values [T]
dones: Done flags [T]
next_values: Next state values [T]
Returns:
advantages: GAE advantages [T]
returns: Empirical returns [T]
"""
advantages = torch.zeros_like(rewards)
last_advantage = 0
# Backward pass for GAE
for t in reversed(range(len(rewards))):
if t == len(rewards) - 1:
next_value = next_values[t]
else:
next_value = values[t + 1]
# TD error
delta = rewards[t] + self.gamma * (1 - dones[t]) * next_value - values[t]
# GAE
advantages[t] = delta + self.gamma * self.lambda_ * (1 - dones[t]) * last_advantage
last_advantage = advantages[t]
returns = advantages + values
return advantages, returnsBackward pass is critical for correctness!
4. KL Divergence Computation
def compute_kl(self, states, old_mean, old_log_std):
"""
Compute KL divergence between old and new policies.
Returns mean KL divergence across states.
"""
new_mean, new_log_std = self.actor(states)
# KL(N(μ_old, σ_old²) || N(μ_new, σ_new²))
old_std = old_log_std.exp()
new_std = new_log_std.exp()
kl = (
new_log_std - old_log_std
+ (old_std ** 2 + (old_mean - new_mean) ** 2) / (2.0 * new_std ** 2)
- 0.5
)
return kl.sum(dim=-1).mean()Gaussian KL divergence in closed form.
5. Conjugate Gradient
def conjugate_gradient(self, fisher_vector_product, b):
"""
Solve Ax = b using conjugate gradient where A is Fisher matrix.
Args:
fisher_vector_product: Function computing F*v
b: Target vector (policy gradient)
Returns:
x: Solution (natural gradient)
"""
x = torch.zeros_like(b)
r = b.clone()
p = r.clone()
rdotr = torch.dot(r, r)
for i in range(self.cg_iters):
Ap = fisher_vector_product(p)
alpha = rdotr / torch.dot(p, Ap)
x += alpha * p
r -= alpha * Ap
new_rdotr = torch.dot(r, r)
p = r + (new_rdotr / rdotr) * p
rdotr = new_rdotr
if rdotr < 1e-10:
break
return xStandard CG algorithm.
6. Fisher-Vector Product
def fisher_vector_product(self, states, old_mean, old_log_std):
"""
Create function that computes Fisher-vector products.
"""
def fvp(v):
# Compute KL divergence
kl = self.compute_kl(states, old_mean, old_log_std)
# Compute gradient of KL
kl_grad = self.flat_grad(kl, self.actor.parameters())
# Compute gradient-vector product
gvp = torch.dot(kl_grad, v)
# Compute Hessian-vector product
hvp = self.flat_grad(gvp, self.actor.parameters())
# Add damping
return hvp + self.damping * v
return fvpUses automatic differentiation for Hessian-vector product.
7. Line Search
def line_search(self, states, actions, advantages, old_mean, old_log_std,
old_loss, full_step):
"""
Backtracking line search ensuring improvement and KL constraint.
"""
old_params = self.get_flat_params()
for i in range(self.backtrack_iters):
step_size = self.backtrack_coeff ** i
new_params = old_params + step_size * full_step
self.set_flat_params(new_params)
# Check KL constraint
with torch.no_grad():
kl = self.compute_kl(states, old_mean, old_log_std)
if kl > self.max_kl:
continue
# Check improvement
new_loss = self._compute_policy_loss(states, actions, advantages)
if new_loss < old_loss:
return True
# Failed, restore old parameters
self.set_flat_params(old_params)
return FalseEnsures both KL constraint and improvement.
8. Main Update Method
def update(self, batch):
"""
TRPO update: natural policy gradient with KL constraint.
"""
states = batch["states"]
actions = batch["actions"]
rewards = batch["rewards"]
next_states = batch["next_states"]
dones = batch["dones"]
# Compute advantages with GAE
with torch.no_grad():
values = self.critic(states).squeeze(-1)
next_values = self.critic(next_states).squeeze(-1)
advantages, returns = self.compute_gae(rewards, values, dones, next_values)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
old_mean, old_log_std = self.actor(states)
# Update value function (multiple epochs)
for _ in range(self.value_iters):
value_pred = self.critic(states).squeeze(-1)
value_loss = F.mse_loss(value_pred, returns)
self.value_optimizer.zero_grad()
value_loss.backward()
self.value_optimizer.step()
# Compute policy gradient
policy_loss = self._compute_policy_loss(states, actions, advantages)
policy_grad = self.flat_grad(policy_loss, self.actor.parameters())
# Compute natural gradient via conjugate gradient
fvp_fn = self.fisher_vector_product(states, old_mean, old_log_std)
natural_grad = self.conjugate_gradient(fvp_fn, policy_grad)
# Compute step size
gHg = torch.dot(natural_grad, fvp_fn(natural_grad))
step_size = torch.sqrt(2 * self.max_kl / (gHg + 1e-8))
full_step = step_size * natural_grad
# Line search
old_loss = policy_loss.detach()
success = self.line_search(states, actions, advantages, old_mean,
old_log_std, old_loss, full_step)
# Compute final KL for logging
with torch.no_grad():
final_kl = self.compute_kl(states, old_mean, old_log_std)
return {
"policy_loss": policy_loss.item(),
"value_loss": value_loss.item(),
"kl_divergence": final_kl.item(),
"line_search_success": float(success),
}Complete TRPO update in one method!
from nexus.models.rl import TRPOAgent
# Configuration
config = {
"state_dim": 17, # e.g., HalfCheetah
"action_dim": 6,
"hidden_dim": 256,
"gamma": 0.99,
"lambda_": 0.95,
"max_kl": 0.01,
"damping": 0.1,
"cg_iters": 10,
"backtrack_iters": 10,
"backtrack_coeff": 0.8,
"value_lr": 1e-3,
"value_iters": 5,
}
# Create agent
agent = TRPOAgent(config)
# Training loop (on-policy)
for iteration in range(num_iterations):
# Collect batch of trajectories
trajectories = []
timesteps = 0
while timesteps < batch_size:
traj = collect_trajectory(env, agent)
trajectories.append(traj)
timesteps += len(traj)
# Convert to batch
batch = trajectories_to_batch(trajectories)
# TRPO update (single update per batch)
metrics = agent.update(batch)
print(f"Iteration {iteration}: KL={metrics['kl_divergence']:.4f}, "
f"Success={metrics['line_search_success']}")Key difference: Collect full batch before single update (on-policy).
TRPO benefits from large batches:
# Standard
batch_size = 5000 # timesteps
# For better performance
batch_size = 10000 # or even 50000Trade-off:
- Pro: More accurate gradient estimates, better CG convergence
- Con: Slower iteration, more memory
Standard: Fixed max_kl:
max_kl = 0.01Adaptive: Adjust based on actual KL:
if kl < target_kl / 1.5:
max_kl *= 1.5 # Increase trust region
elif kl > target_kl * 1.5:
max_kl /= 1.5 # Decrease trust regionCan speed up learning while maintaining safety.
Standard:
cg_iters = 10For better accuracy:
cg_iters = 20 # especially for complex policiesTrade-off:
- Pro: More accurate natural gradient
- Con: Slower computation
Policy and value don't need same capacity:
# Policy network
actor = TRPOActor(state_dim, action_dim, hidden_dim=256)
# Value network (can be larger)
critic = TRPOCritic(state_dim, hidden_dim=512)Better value estimates can improve advantage accuracy.
Standard:
value_iters = 5For better value estimates:
value_iters = 10-20Helps especially early in training.
Standard:
lambda_ = 0.95Task-specific:
# Longer episodes, sparse rewards
lambda_ = 0.99 # more bias toward Monte Carlo
# Shorter episodes, dense rewards
lambda_ = 0.90 # more bias toward TDAdd entropy regularization:
policy_loss = -E[ratio * advantage] - beta * H(policy)Where H is policy entropy:
entropy = policy_dist.entropy().mean()Encourages exploration, similar to SAC but weaker.
Standard: Separate networks:
actor = TRPOActor(state_dim, action_dim)
critic = TRPOCritic(state_dim)Shared: Common trunk:
class ActorCritic(nn.Module):
def __init__(self):
self.trunk = nn.Sequential(...) # shared
self.policy_head = nn.Linear(...)
self.value_head = nn.Linear(...)Trade-off:
- Pro: More sample efficient (shared representations)
- Con: Coupling between policy and value can cause instability
Generally separate is safer for TRPO.
Helps value function stability:
torch.nn.utils.clip_grad_norm_(critic.parameters(), max_norm=10.0)Usually not needed for policy (trust region handles it).
Use previous solution as initialization:
# Store previous natural gradient
self.prev_natural_grad = natural_grad
# Next iteration
x_init = self.prev_natural_grad
x = conjugate_gradient(fvp, b, x_init=x_init)Can speed up CG convergence.
Standard benchmarks (1M timesteps):
| Environment | TRPO Score | PPO Score | TD3 Score | SAC Score |
|---|---|---|---|---|
| HalfCheetah-v2 | 8234 ± 623 | 2124 ± 500 | 9636 ± 859 | 10214 ± 823 |
| Walker2d-v2 | 3823 ± 456 | 3245 ± 789 | 4682 ± 539 | 5280 ± 342 |
| Ant-v2 | 3456 ± 523 | 2890 ± 456 | 4372 ± 782 | 5411 ± 628 |
| Hopper-v2 | 2892 ± 378 | 2456 ± 678 | 3564 ± 114 | 3234 ± 456 |
| Humanoid-v2 | 4123 ± 612 | 3456 ± 890 | 5383 ± 456 | 6123 ± 523 |
Key findings:
- TRPO achieves good performance across all tasks
- Outperforms vanilla policy gradient and PPO
- Behind off-policy methods (TD3, SAC) in final performance
- Much more stable than vanilla PG
Coefficient of variation (std/mean) over 5 seeds:
| Algorithm | HalfCheetah | Walker2d | Ant |
|---|---|---|---|
| PG | 0.43 | 0.67 | 0.58 |
| PPO | 0.23 | 0.24 | 0.16 |
| TRPO | 0.08 | 0.12 | 0.09 |
| TD3 | 0.09 | 0.12 | 0.18 |
TRPO is the most stable on-policy method!
Measuring non-monotonic updates (performance decrease):
| Algorithm | % Non-Monotonic Updates |
|---|---|
| Vanilla PG | 42% |
| PPO | 8% |
| TRPO | <1% |
TRPO delivers on its monotonic improvement promise!
Wall-clock time per update (HalfCheetah, batch_size=5000):
| Algorithm | Time per Update | Timesteps/Second |
|---|---|---|
| PPO | 0.8s | 6250 |
| TRPO | 2.3s | 2174 |
| TD3 | 0.05s | 100000 |
TRPO overhead:
- Conjugate gradient: ~50% of time
- Fisher-vector products: ~30% of time
- Line search: ~10% of time
- Value updates: ~10% of time
TRPO is ~3x slower than PPO, ~40x slower than TD3 per update
- But uses larger batches and fewer updates
- Overall sample efficiency similar to PPO
Effect of max_kl:
| max_kl | HalfCheetah Score | Notes |
|---|---|---|
| 0.001 | 6234 ± 892 | Too conservative |
| 0.01 | 8234 ± 623 | Best (default) |
| 0.02 | 7892 ± 756 | Still good |
| 0.05 | 6423 ± 1234 | Too aggressive |
TRPO is robust to max_kl in range [0.005, 0.02].
Effect of lambda (GAE):
| lambda | Score | Notes |
|---|---|---|
| 0.90 | 7823 ± 734 | More bias |
| 0.95 | 8234 ± 623 | Best (default) |
| 0.99 | 8123 ± 689 | More variance |
λ=0.95 is a robust default.
Removing TRPO components (HalfCheetah, 1M steps):
| Configuration | Score | Notes |
|---|---|---|
| Full TRPO | 8234 ± 623 | Baseline |
| No natural gradient | 5234 ± 1234 | Much worse (vanilla PG) |
| No line search | 6892 ± 1456 | Constraint violations |
| No GAE (λ=1) | 6423 ± 1123 | High variance |
| Fixed step (no CG) | 5892 ± 1289 | Poor step sizes |
Key insights:
- Natural gradient is most critical
- Line search ensures guarantees
- GAE significantly reduces variance
- All components contribute
Timesteps to reach threshold (HalfCheetah, threshold=7000):
| Algorithm | Timesteps Required |
|---|---|
| TRPO | 450K |
| PPO | 520K |
| TD3 | 300K |
| SAC | 280K |
TRPO slightly more sample efficient than PPO, but both behind off-policy.
Simulated robotic tasks:
- TRPO works well for physical simulation
- Stable enough for sim-to-real transfer
- Monotonic improvement valuable for safety
However:
- PPO usually preferred due to simplicity and speed
- TD3/SAC better for sample efficiency
Problem:
High variance gradients
CG doesn't converge
Line search frequently fails
Poor performance
Cause: On-policy methods need large batches
Solution:
# Too small
batch_size = 1000
# Good
batch_size = 5000-10000Rule of thumb: At least 5000 timesteps per update.
Problem:
Advantages computed incorrectly
Poor performance despite correct algorithm
Cause: Forward pass instead of backward
Solution:
# Wrong: forward pass
for t in range(len(rewards)):
gae = delta[t] + gamma * lambda_ * gae
# Right: backward pass
for t in reversed(range(len(rewards))):
gae = delta[t] + gamma * lambda_ * gaeProblem:
Warning: CG max iterations reached
Natural gradient inaccurate
Poor updates
Solutions:
- Increase CG iterations:
cg_iters = 20 # instead of 10- Increase damping:
damping = 0.5 # instead of 0.1- Check Fisher-vector product implementation:
- Ensure Hessian-vector product, not just gradient
Problem:
Line search fails every iteration
Policy doesn't update
No learning progress
Causes and solutions:
1. max_kl too small:
max_kl = 0.01 # try 0.02-0.052. Batch size too small:
batch_size = 10000 # increase3. Value function inaccurate:
value_iters = 10 # more value updates
value_lr = 1e-3 # tune learning rateProblem:
Value loss increases
Advantages become nonsensical
Policy updates fail
Solutions:
- Gradient clipping:
torch.nn.utils.clip_grad_norm_(critic.parameters(), max_norm=10.0)- Lower learning rate:
value_lr = 1e-4 # instead of 1e-3- More updates:
value_iters = 10 # instead of 5Problem:
Out of memory errors
Especially with large batches
Solutions:
- Process batch in chunks:
# Instead of one big update
advantages = compute_gae(all_trajectories)
# Process in chunks
for chunk in batch_chunks(trajectories, chunk_size=1000):
advantages_chunk = compute_gae(chunk)- Reduce batch size:
batch_size = 5000 # instead of 10000- Use gradient checkpointing:
from torch.utils.checkpoint import checkpointProblem:
Gradients flow through old policy
Incorrect policy gradient
Poor learning
Solution:
# Compute old policy parameters before update
with torch.no_grad():
old_mean, old_log_std = actor(states)
# Use detached values in ratio
old_log_prob = compute_log_prob(old_mean, old_log_std, actions).detach()Problem:
Incorrect advantages at episode boundaries
Poor performance on episodic tasks
Solution:
# Properly mask terminal states
delta = reward + gamma * (1 - done) * next_value - value
# In GAE
gae = delta + gamma * lambda_ * (1 - done) * last_gaeProblem:
Constraint doesn't prevent divergence
Updates are too large
Cause: KL(π_new || π_old) instead of KL(π_old || π_new)
Solution:
# Wrong: reverse KL
kl = kl_divergence(new_policy, old_policy)
# Right: forward KL
kl = kl_divergence(old_policy, new_policy)Both are used in RL, but TRPO theory requires forward KL.
Problem:
Unstable training
High variance updates
Solution:
# Always normalize advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)Normalization is critical for TRPO stability!
TRPO:
- Schulman, J., Levine, S., Abbeel, P., Jordan, M., & Moritz, P. (2015). Trust Region Policy Optimization. ICML 2015.
- Original TRPO paper
- Monotonic improvement theorem
- Conjugate gradient method
- arXiv:1502.05477
High-Dimensional Continuous Control (TRPO Application):
- Schulman, J., Moritz, P., Levine, S., Jordan, M., & Abbeel, P. (2016). High-Dimensional Continuous Control Using Generalized Advantage Estimation. ICLR 2016.
- Introduces GAE
- Shows TRPO+GAE results
- arXiv:1506.02438
Natural Policy Gradient:
- Kakade, S. (2002). A Natural Policy Gradient. NIPS 2002.
- Introduces natural gradient for RL
- Theoretical foundation for TRPO
Policy Gradient Theorem:
- Sutton, R. S., McAllester, D., Singh, S., & Mansour, Y. (1999). Policy Gradient Methods for Reinforcement Learning with Function Approximation. NIPS 1999.
- Original policy gradient theorem
Actor-Critic:
- Konda, V. R., & Tsitsiklis, J. N. (2000). Actor-Critic Algorithms. NIPS 2000.
- Theoretical foundations of actor-critic
- Convergence analysis
PPO (TRPO's Successor):
- Schulman, J., Wolski, F., Dhariwal, P., Radford, A., & Klimov, O. (2017). Proximal Policy Optimization Algorithms. arXiv preprint.
- Simpler alternative to TRPO
- Achieves similar performance
- arXiv:1707.06347
A3C (Predecessor):
- Mnih, V., et al. (2016). Asynchronous Methods for Deep Reinforcement Learning. ICML 2016.
- Contemporary with TRPO
- arXiv:1602.01783
DDPG (Comparison):
- Lillicrap, T. P., et al. (2015). Continuous Control with Deep Reinforcement Learning. ICLR 2016.
- Off-policy alternative
- arXiv:1509.02971
Trust-PCL:
- Nachum, O., Norouzi, M., Xu, K., & Schuurmans, D. (2017). Trust-PCL: An Off-Policy Trust Region Method for Continuous Control. arXiv preprint.
- Off-policy version of TRPO
- arXiv:1707.01891
TRPO for Multi-Task RL:
- Parisotto, E., Ba, J., & Salakhutdinov, R. (2016). Actor-Mimic: Deep Multitask and Transfer Reinforcement Learning. ICLR 2016.
- Uses TRPO for multi-task learning
Constrained Policy Optimization (CPO):
- Achiam, J., Held, D., Tamar, A., & Abbeel, P. (2017). Constrained Policy Optimization. ICML 2017.
- Extends TRPO to constrained MDPs
- Safety constraints
- arXiv:1705.10528
Conservative Policy Iteration:
- Kakade, S., & Langford, J. (2002). Approximately Optimal Approximate Reinforcement Learning. ICML 2002.
- Theoretical foundation for TRPO's guarantees
Natural Gradient Descent:
- Amari, S. (1998). Natural Gradient Works Efficiently in Learning. Neural Computation.
- General theory of natural gradients
Conjugate Gradient:
- Hestenes, M. R., & Stiefel, E. (1952). Methods of Conjugate Gradients for Solving Linear Systems. Journal of Research of the National Bureau of Standards.
- Original CG algorithm
OpenAI Spinning Up:
- https://spinningup.openai.com/en/latest/algorithms/trpo.html
- Excellent educational resource
- Clean implementations
- Detailed explanations
Modular RL (Schulman):
- https://github.com/joschu/modular_rl
- Original TRPO implementation from author
- Reference implementation
Stable-Baselines:
- https://stable-baselines.readthedocs.io/en/master/modules/trpo.html
- Production-ready TRPO
RLlib (Ray):
- https://docs.ray.io/en/latest/rllib/rllib-algorithms.html#trpo
- Distributed TRPO implementation
Reinforcement Learning Textbook:
- Sutton, R. S., & Barto, A. G. (2018). Reinforcement Learning: An Introduction (2nd ed.). MIT Press.
- Chapter 13: Policy Gradient Methods
- Free online
Deep RL Survey:
- Arulkumaran, K., et al. (2017). Deep Reinforcement Learning: A Brief Survey. IEEE Signal Processing Magazine.
- Overview including TRPO
- arXiv:1708.05866
Policy Gradient Dissertation:
- Schulman, J. (2016). Optimizing Expectations: From Deep Reinforcement Learning to Stochastic Computation Graphs. UC Berkeley PhD Thesis.
- Comprehensive treatment of policy gradients and TRPO
UC Berkeley CS 285:
- Deep Reinforcement Learning (Sergey Levine)
- TRPO lecture and advanced policy gradients
- https://rail.eecs.berkeley.edu/deeprlcourse/
Stanford CS 234:
- Reinforcement Learning (Emma Brunskill)
- Policy gradients and TRPO
- https://web.stanford.edu/class/cs234/
DeepMind x UCL:
- Advanced Deep Learning & Reinforcement Learning
- https://www.youtube.com/playlist?list=PLqYmG7hTraZDNJre23vqCGIVpfZ_K2RZs
Lil'Log (Lilian Weng):
- https://lilianweng.github.io/posts/2018-04-08-policy-gradient/
- Excellent visual explanations
- Covers TRPO in detail
BAIR Blog:
- https://bair.berkeley.edu/blog/2017/07/19/learning-to-run/
- TRPO in action (learning to run)
Schulman's Blog:
- http://joschu.net/blog/kl-approx.html
- KL divergence approximations in TRPO
Spinning Up Deep RL:
- https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html
- Part 3: Intro to Policy Optimization
Nexus Implementation:
/nexus/models/rl/trpo.py- Clean PyTorch implementation
- Well-documented
- Follows paper
Benchmark Repositories:
- PPO - TRPO's simpler successor
- A2C - Simpler on-policy method
- TD3 - Off-policy alternative
- SAC - Maximum entropy alternative
Citation:
If you use TRPO in your research, please cite:
@inproceedings{schulman2015trust,
title={Trust region policy optimization},
author={Schulman, John and Levine, Sergey and Abbeel, Pieter and Jordan, Michael and Moritz, Philipp},
booktitle={International Conference on Machine Learning},
pages={1889--1897},
year={2015},
organization={PMLR}
}
@inproceedings{schulman2016high,
title={High-dimensional continuous control using generalized advantage estimation},
author={Schulman, John and Moritz, Philipp and Levine, Sergey and Jordan, Michael and Abbeel, Pieter},
booktitle={International Conference on Learning Representations},
year={2016}
}