359 changes: 359 additions & 0 deletions cares_reinforcement_learning/algorithm/policy/SACAE1D.py
@@ -0,0 +1,359 @@
"""
Original Paper: https://arxiv.org/abs/1910.01741
Code based on: https://github.com/denisyarats/pytorch_sac_ae/tree/master

This code runs automatic entropy tuning
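
SAC-AE adapted to 1D/vector observations (e.g. lidar scans): the encoder is shared
between actor and critic (conv weights tied) and regularised by reconstructing the
input vector through a Decoder1D, on top of the standard SAC actor/critic updates.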
"""

import copy
import logging
import os
from typing import Any

import numpy as np
import torch
import torch.nn.functional as F

import cares_reinforcement_learning.util.helpers as hlp
from cares_reinforcement_learning.algorithm.algorithm import VectorAlgorithm
from cares_reinforcement_learning.encoders.losses import AELoss
from cares_reinforcement_learning.encoders.vanilla_autoencoder import Decoder1D
from cares_reinforcement_learning.memory import MemoryBuffer
from cares_reinforcement_learning.networks.SACAE1D import Actor, Critic
from cares_reinforcement_learning.util.configurations import SACAE1DConfig


class SACAE1D(VectorAlgorithm):
def __init__(
self,
actor_network: Actor,
critic_network: Critic,
decoder_network: Decoder1D,
config: SACAE1DConfig,
device: torch.device,
):
super().__init__(policy_type="policy", config=config, device=device)

self.vector_observation = config.vector_observation

# this may be called policy_net in other implementations
self.actor_net = actor_network.to(device)

# this may be called soft_q_net in other implementations
self.critic_net = critic_network.to(device)
self.target_critic_net = copy.deepcopy(self.critic_net).to(device)
        self.target_critic_net.eval()  # always kept in eval mode - keeps batch norm/dropout layers deterministic

# tie the encoder weights
self.actor_net.encoder.copy_conv_weights_from(self.critic_net.encoder)
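        # With tied conv weights the actor and critic operate on the same latent
        # representation; only the critic and autoencoder losses update the encoder.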

self.encoder_tau = config.encoder_tau

self.decoder_net = decoder_network.to(device)
self.decoder_update_freq = config.decoder_update_freq
self.decoder_latent_lambda = config.autoencoder_config.latent_lambda

self.gamma = config.gamma
self.tau = config.tau
self.reward_scale = config.reward_scale

# PER
self.use_per_buffer = config.use_per_buffer
self.per_sampling_strategy = config.per_sampling_strategy
self.per_weight_normalisation = config.per_weight_normalisation
self.per_alpha = config.per_alpha
self.min_priority = config.min_priority

self.learn_counter = 0
self.policy_update_freq = config.policy_update_freq
self.target_update_freq = config.target_update_freq

self.target_entropy = -self.actor_net.num_actions
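        # Standard SAC heuristic: target entropy = -|A| (negative of the action dimension)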

self.actor_net_optimiser = torch.optim.Adam(
self.actor_net.parameters(), lr=config.actor_lr, **config.actor_lr_params
)
self.critic_net_optimiser = torch.optim.Adam(
self.critic_net.parameters(), lr=config.critic_lr, **config.critic_lr_params
)

self.ae_loss_function = AELoss(
latent_lambda=config.autoencoder_config.latent_lambda
)

self.encoder_net_optimiser = torch.optim.Adam(
self.critic_net.encoder.parameters(),
**config.autoencoder_config.encoder_optim_kwargs,
)
self.decoder_net_optimiser = torch.optim.Adam(
self.decoder_net.parameters(),
**config.autoencoder_config.decoder_optim_kwargs,
)

        # Set the initial alpha (temperature) to 1.0, following other baselines.
init_temperature = 1.0
self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
self.log_alpha.requires_grad = True
self.log_alpha_optimizer = torch.optim.Adam(
[self.log_alpha], lr=config.alpha_lr, **config.alpha_lr_params
)

def select_action_from_policy(
self,
state: np.ndarray | dict[str, np.ndarray],
evaluation: bool = False,
) -> np.ndarray:
        # note: when evaluating, we take the deterministic mean (mu) as the action
self.actor_net.eval()
with torch.no_grad():
if isinstance(state, dict):
state_tensor = hlp.lidar_state_dict_to_tensor(state, self.device)
# In this case the state_tensor is a dict
else:
state_tensor = torch.FloatTensor(state).to(self.device)
state_tensor = state_tensor.unsqueeze(0)
# In this case the state_tensor is a tensor
if evaluation:
(_, _, action) = self.actor_net(state_tensor)
else:
(action, _, _) = self.actor_net(state_tensor)
action = action.cpu().data.numpy().flatten()
self.actor_net.train()
return action

@property
def alpha(self) -> torch.Tensor:
return self.log_alpha.exp()

def _update_critic(
self,
states: torch.Tensor,
actions: torch.Tensor,
rewards: torch.Tensor,
next_states: torch.Tensor,
dones: torch.Tensor,
weights: torch.Tensor,
) -> tuple[dict[str, Any], np.ndarray]:
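        # Soft Bellman target, computed without gradients:
        #   y = reward_scale * r + gamma * (1 - done) * (min(Q1', Q2') - alpha * log_pi')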

with torch.no_grad():
with hlp.evaluating(self.actor_net):
next_actions, next_log_pi, _ = self.actor_net(next_states)

target_q_values_one, target_q_values_two = self.target_critic_net(
next_states, next_actions
)
target_q_values = (
torch.minimum(target_q_values_one, target_q_values_two)
- self.alpha * next_log_pi
)

q_target = (
rewards * self.reward_scale + self.gamma * (1 - dones) * target_q_values
)

q_values_one, q_values_two = self.critic_net(states, actions)

td_error_one = (q_values_one - q_target).abs()
td_error_two = (q_values_two - q_target).abs()

critic_loss_one = F.mse_loss(q_values_one, q_target, reduction="none")
critic_loss_one = (critic_loss_one * weights).mean()

critic_loss_two = F.mse_loss(q_values_two, q_target, reduction="none")
critic_loss_two = (critic_loss_two * weights).mean()

critic_loss_total = critic_loss_one + critic_loss_two

self.critic_net_optimiser.zero_grad()
critic_loss_total.backward()
self.critic_net_optimiser.step()

# Update the Priorities - PER only
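        # priority = clamp(max(|td_error_one|, |td_error_two|), min_priority) ** per_alpha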
priorities = (
torch.max(td_error_one, td_error_two)
.clamp(self.min_priority)
.pow(self.per_alpha)
.cpu()
.data.numpy()
.flatten()
)

info = {
"critic_loss_one": critic_loss_one.item(),
"critic_loss_two": critic_loss_two.item(),
"critic_loss_total": critic_loss_total.item(),
}

return info, priorities

def _update_actor_alpha(self, states: torch.Tensor) -> dict[str, Any]:
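        # Actor objective: minimise E[alpha * log_pi - min(Q1, Q2)].
        # detach_encoder=True stops actor gradients from reaching the shared encoder,
        # so only the critic and autoencoder losses shape the latent space (SAC+AE).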
pi, log_pi, _ = self.actor_net(states, detach_encoder=True)

with hlp.evaluating(self.critic_net):
qf1_pi, qf2_pi = self.critic_net(states, pi, detach_encoder=True)

min_qf_pi = torch.minimum(qf1_pi, qf2_pi)
actor_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

self.actor_net_optimiser.zero_grad()
actor_loss.backward()
self.actor_net_optimiser.step()

# Update the temperature (alpha)
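        # J(alpha) = E[-log_alpha * (log_pi + target_entropy)]: alpha increases when
        # the policy entropy is below the target and decreases when it is above.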
alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

self.log_alpha_optimizer.zero_grad()
alpha_loss.backward()
self.log_alpha_optimizer.step()

info = {
"actor_loss": actor_loss.item(),
"alpha_loss": alpha_loss.item(),
}

return info

def _update_autoencoder(self, states: torch.Tensor) -> dict[str, Any]:
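        # Encode with the critic's (shared) encoder and reconstruct with the decoder.
        # AELoss combines the reconstruction error with a latent penalty weighted by
        # latent_lambda (the SAC+AE regulariser) - see AELoss for the exact form.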
latent_samples = self.critic_net.encoder(states)
reconstructed_data = self.decoder_net(latent_samples)

ae_loss = self.ae_loss_function.calculate_loss(
data=states,
reconstructed_data=reconstructed_data,
latent_sample=latent_samples,
)
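        # NOTE: appends the raw autoencoder loss to a local text file (debug logging)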

with open("ae_loss.txt", "a") as f:
f.write(f"{ae_loss.item()}\n")

self.encoder_net_optimiser.zero_grad()
self.decoder_net_optimiser.zero_grad()
ae_loss.backward()
self.encoder_net_optimiser.step()
self.decoder_net_optimiser.step()

info = {
"ae_loss": ae_loss.item(),
}
return info

def train_policy(
self, memory: MemoryBuffer, batch_size: int, training_step: int
) -> dict[str, Any]:
self.learn_counter += 1

if self.use_per_buffer:
experiences = memory.sample_priority(
batch_size,
sampling_stratagy=self.per_sampling_strategy,
weight_normalisation=self.per_weight_normalisation,
)
states, actions, rewards, next_states, dones, indices, weights = experiences
else:
experiences = memory.sample_uniform(batch_size)
states, actions, rewards, next_states, dones, _ = experiences
weights = [1.0] * batch_size

batch_size = len(states)

actions_tensor = torch.FloatTensor(np.asarray(actions)).to(self.device)
rewards_tensor = torch.FloatTensor(np.asarray(rewards)).to(self.device)

if self.vector_observation:
states_tensor = hlp.lidar_states_dict_to_tensor(states, self.device)
next_states_tensor = hlp.lidar_states_dict_to_tensor(
next_states, self.device
)
else:
states_tensor = torch.FloatTensor(np.asarray(states)).to(self.device)
next_states_tensor = torch.FloatTensor(np.asarray(next_states)).to(
self.device
)

dones_tensor = torch.LongTensor(np.asarray(dones)).to(self.device)
weights_tensor = torch.FloatTensor(np.asarray(weights)).to(self.device)

        # Reshape to (batch_size, 1) so these broadcast against the Q-value tensors
rewards_tensor = rewards_tensor.reshape(batch_size, 1)
dones_tensor = dones_tensor.reshape(batch_size, 1)
weights_tensor = weights_tensor.reshape(batch_size, 1)

info: dict[str, Any] = {}

# Update the Critic
critic_info, priorities = self._update_critic(
states_tensor,
actions_tensor,
rewards_tensor,
next_states_tensor,
dones_tensor,
weights_tensor,
)
info |= critic_info

# Update the Actor
if self.learn_counter % self.policy_update_freq == 0:
actor_info = self._update_actor_alpha(states_tensor)
info |= actor_info
info["alpha"] = self.alpha.item()

if self.learn_counter % self.target_update_freq == 0:
# Update the target networks - Soft Update
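            # Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target,
            # with a separate encoder_tau for the shared encoder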
hlp.soft_update_params(
self.critic_net.critic.Q1, self.target_critic_net.critic.Q1, self.tau
)
hlp.soft_update_params(
self.critic_net.critic.Q2, self.target_critic_net.critic.Q2, self.tau
)
hlp.soft_update_params(
self.critic_net.encoder,
self.target_critic_net.encoder,
self.encoder_tau,
)

if self.learn_counter % self.decoder_update_freq == 0:
if self.vector_observation:
ae_info = self._update_autoencoder(states_tensor["lidar"])
else:
ae_info = self._update_autoencoder(states_tensor)
info |= ae_info

# Update the Priorities
if self.use_per_buffer:
memory.update_priorities(indices, priorities)

return info

def save_models(self, filepath: str, filename: str) -> None:
if not os.path.exists(filepath):
os.makedirs(filepath)
torch.save(self.actor_net.state_dict(), f"{filepath}/{filename}_actor.pht")
torch.save(self.critic_net.state_dict(), f"{filepath}/{filename}_critic.pht")
torch.save(self.decoder_net.state_dict(), f"{filepath}/{filename}_decoder.pht")
        logging.info("models have been saved...")

def load_models(self, filepath: str, filename: str) -> None:
if torch.cuda.is_available():
self.actor_net.load_state_dict(
torch.load(f"{filepath}/{filename}_actor.pht")
)
self.critic_net.load_state_dict(
torch.load(f"{filepath}/{filename}_critic.pht")
)
self.decoder_net.load_state_dict(
torch.load(f"{filepath}/{filename}_decoder.pht")
)
        else:
            self.actor_net.load_state_dict(
                torch.load(
                    f"{filepath}/{filename}_actor.pht",
                    map_location=torch.device("cpu"),
                )
            )
            self.critic_net.load_state_dict(
                torch.load(
                    f"{filepath}/{filename}_critic.pht",
                    map_location=torch.device("cpu"),
                )
            )
            self.decoder_net.load_state_dict(
                torch.load(
                    f"{filepath}/{filename}_decoder.pht",
                    map_location=torch.device("cpu"),
                )
            )
        logging.info("models have been loaded...")
21 changes: 19 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/TD3.py
@@ -299,6 +299,23 @@ def save_models(self, filepath: str, filename: str) -> None:
        logging.info("models have been saved...")

def load_models(self, filepath: str, filename: str) -> None:
if torch.cuda.is_available():
self.actor_net.load_state_dict(
torch.load(f"{filepath}/{filename}_actor.pht")
)
self.critic_net.load_state_dict(
torch.load(f"{filepath}/{filename}_critic.pht")
)
else:
self.actor_net.load_state_dict(
torch.load(
f"{filepath}/{filename}_actor.pht", map_location=torch.device("cpu")
)
)
self.critic_net.load_state_dict(
torch.load(
f"{filepath}/{filename}_critic.pht",
map_location=torch.device("cpu"),
)
)
        logging.info("models have been loaded...")