|
| 1 | +# This file is part of the CoverageControl library |
| 2 | +# |
| 3 | +# Author: Saurav Agarwal |
| 4 | + |
| 5 | +# Repository: https://github.com/KumarRobotics/CoverageControl |
| 6 | +# |
| 7 | +# Copyright (c) 2024, Saurav Agarwal |
| 8 | +# |
| 9 | +# The CoverageControl library is free software: you can redistribute it and/or |
| 10 | +# modify it under the terms of the GNU General Public License as published by |
| 11 | +# the Free Software Foundation, either version 3 of the License, or (at your |
| 12 | +# option) any later version. |
| 13 | +# |
| 14 | +# The CoverageControl library is distributed in the hope that it will be |
| 15 | +# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 16 | +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General |
| 17 | +# Public License for more details. |
| 18 | +# |
| 19 | +# You should have received a copy of the GNU General Public License along with |
| 20 | +# CoverageControl library. If not, see <https://www.gnu.org/licenses/>. |
| 21 | + |
| 22 | +import torch |
| 23 | +import torch_geometric |
| 24 | +from torch_geometric.nn import MLP |
| 25 | + |
| 26 | +from .config_parser import GNNConfigParser |
| 27 | +from .cnn_backbone import CNNBackBone |
| 28 | +from .gnn_backbone import GNNBackBone |
| 29 | + |
| 30 | +__all__ = ["LPAC"] |
| 31 | + |
| 32 | +class LPAC(torch.nn.Module, GNNConfigParser): |
| 33 | + def __init__(self, config): |
| 34 | + super(LPAC, self).__init__() |
| 35 | + self.cnn_config = config['CNN'] |
| 36 | + self.parse(config['GNN']) |
| 37 | + self.cnn_backbone = CNNBackBone(self.cnn_config) |
| 38 | + self.gnn_backbone = GNNBackBone(self.config, self.cnn_backbone.latent_size + 2) |
| 39 | + # --- no pos --- |
| 40 | + # self.gnn_backbone = GNNBackBone(self.config, self.cnn_backbone.latent_size) |
| 41 | + # --- no pos --- |
| 42 | + self.gnn_mlp = MLP([self.latent_size, 32, 32]) |
| 43 | + self.output_linear = torch.nn.Linear(32, self.output_dim) |
| 44 | + # Register buffers to model |
| 45 | + self.register_buffer("actions_mean", torch.zeros(self.output_dim)) |
| 46 | + self.register_buffer("actions_std", torch.ones(self.output_dim)) |
| 47 | + |
| 48 | + def forward(self, data: torch_geometric.data.Data) -> torch.Tensor: |
| 49 | + x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight |
| 50 | + pos = data.pos |
| 51 | + cnn_output = self.cnn_backbone(x.view(-1, x.shape[-3], x.shape[-2], x.shape[-1])) |
| 52 | + |
| 53 | + # --- no pos --- |
| 54 | + # gnn_output = self.gnn_backbone(cnn_output, edge_index) |
| 55 | + # mlp_output = self.gnn_mlp(gnn_output) |
| 56 | + # x = self.output_linear(mlp_output) |
| 57 | + # x = self.output_linear(self.gnn_mlp(self.gnn_backbone(cnn_output, edge_index))) |
| 58 | + # --- no pos --- |
| 59 | + |
| 60 | + gnn_backbone_in = torch.cat([cnn_output, pos], dim=-1) |
| 61 | + # print(gnn_backbone_in) |
| 62 | + # gnn_output = self.gnn_backbone(gnn_backbone_in, edge_index) |
| 63 | + # mid_test = self.gnn_mlp.lins[0](gnn_output) |
| 64 | + # print(f'mid_test sum1: {mid_test.sum()}') |
| 65 | + # mid_test = self.gnn_mlp.norms[0](mid_test) |
| 66 | + # print(f'mid_test sum: {mid_test.sum()}') |
| 67 | + # mlp_output = self.gnn_mlp(self.gnn_backbone(gnn_backbone_in, edge_index) |
| 68 | + # print(f'mlp_output sum: {mlp_output[0]}') |
| 69 | + x = self.output_linear(self.gnn_mlp(self.gnn_backbone(gnn_backbone_in, edge_index))) |
| 70 | + return x |
| 71 | + |
| 72 | + def load_model(self, model_state_dict_path: str) -> None: |
| 73 | + self.load_state_dict(torch.load(model_state_dict_path), strict=False) |
| 74 | + |
| 75 | + def load_cnn_backbone(self, model_path: str) -> None: |
| 76 | + self.load_state_dict(torch.load(model_path).state_dict(), strict=False) |
| 77 | + |
| 78 | + def load_gnn_backbone(self, model_path: str) -> None: |
| 79 | + self.load_state_dict(torch.load(model_path).state_dict(), strict=False) |
0 commit comments