Skip to content

Commit 94d927d

Browse files
Saurav AgarwalSaurav Agarwal
authored andcommitted
Add python modules and params
1 parent fcef7b6 commit 94d927d

File tree

7 files changed

+296
-25
lines changed

7 files changed

+296
-25
lines changed

params/data_params.toml

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
DataDir = "/root/CoverageControl_ws/data/pure_coverage" # Absolute location
2-
EnvironmentConfig = "./env_params.yaml" # Relative to DataDir
1+
DataDir = "${CoverageControl_ws}/datasets/lpac" # Absolute location
2+
EnvironmentConfig = "${CoverageControl_ws}/datasets/lpac/coverage_control_params.toml" # Absolute location
33

4+
# Only required for data generation using C++
45
# The generator requires a TorchVision JIT transformer model
56
# for resizing robot local maps on the GPU
67
# The python script for generating the model is located at
@@ -16,18 +17,18 @@ EveryNumSteps = 5
1617

1718
# The robots stop moving once the algorithm has converged
1819
# Having some of these converged steps can help in stabilizing robot actions
19-
ConvergedDataRatio = .25
20+
ConvergedDataRatio = 0.02
2021

2122
# Resizing of maps and Sparsification of tensors are triggered every TriggerPostProcessing dataset
2223
# This should be set based on RAM resources available on the system
23-
TriggerPostProcessing = 1000
24+
TriggerPostProcessing = 100
2425

2526
CNNMapSize = 32
2627
SaveAsSparseQ = true
2728
NormalizeQ = true
2829

2930
[DataSetSplit]
30-
TrainRatio = .7
31-
ValRatio = .2
32-
TestRatio = .1
31+
TrainRatio = 0.7
32+
ValRatio = 0.2
33+
TestRatio = 0.1
3334

params/eval.toml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
EvalDir = "${CoverageControl_ws}/datasets/lpac/eval/" # Absolute location
2+
EnvironmentConfig = "${CoverageControl_ws}/datasets/lpac/coverage_control_params.toml" # Absolute location
3+
4+
EnvironmentDataDir = "${CoverageControl_ws}/datasets/lpac/envs/" # Absolute location
5+
NumEnvironments = 2
6+
NumSteps = 600
7+
8+
[[Controllers]]
9+
Name = "lpac"
10+
Type = "Learning"
11+
# ModelFile: "~/CoverageControl_ws/datsets/lpac/models/model_k3_1024.pt"
12+
ModelStateDict = "${CoverageControl_ws}/datasets/lpac/models/model_k3_1024_state_dict.pt"
13+
LearningParams = "${CoverageControl_ws}/datasets/lpac/models/learning_params.toml"
14+
UseCommMap = true
15+
UseCNN = true
16+
CNNMapSize = 32
17+
18+
[[Controllers]]
19+
Name = "DecentralizedCVT" # Creates a subdirectory with this name
20+
Algorithm = "DecentralizedCVT"
21+
Type = "CVT"
22+
23+
[[Controllers]]
24+
Name = "ClairvoyantCVT"
25+
Algorithm = "ClairvoyantCVT"
26+
Type = "CVT"
27+
28+
[[Controllers]]
29+
Name = "CentralizedCVT"
30+
Algorithm = "CentralizedCVT"
31+
Type = "CVT"

params/eval_single.toml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
EvalDir = "${CoverageControl_ws}/datasets/lpac/eval/" # Absolute location
2+
EnvironmentConfig = "${CoverageControl_ws}/datasets/lpac/coverage_control_params.toml" # Absolute location
3+
4+
EnvironmentDataDir = "${CoverageControl_ws}/datasets/lpac/envs/" # Absolute location
5+
FeatureFile = "0.env" # Relative to EnvironmentDataDir
6+
RobotPosFile = "0.pos" # Relative to EnvironmentDataDir
7+
8+
NumSteps = 600
9+
10+
PlotMap = true
11+
GenerateVideo = true # Will generate a video for each controller
12+
13+
[[Controllers]]
14+
Name = "lpac"
15+
Type = "Learning"
16+
# ModelFile: "~/CoverageControl_ws/datsets/lpac/models/model_k3_1024.pt"
17+
ModelStateDict = "${CoverageControl_ws}/datasets/lpac/models/model_k3_1024_state_dict.pt"
18+
LearningParams = "${CoverageControl_ws}/datasets/lpac/models/learning_params.toml"
19+
UseCommMap = true
20+
UseCNN = true
21+
CNNMapSize = 32
22+
23+
# [[Controllers]]
24+
# Name = "DecentralizedCVT" # Creates a subdirectory with this name
25+
# Algorithm = "DecentralizedCVT"
26+
# Type = "CVT"
27+
28+
# [[Controllers]]
29+
# Name = "ClairvoyantCVT"
30+
# Algorithm = "ClairvoyantCVT"
31+
# Type = "CVT"
32+
33+
# [[Controllers]]
34+
# Name = "CentralizedCVT"
35+
# Algorithm = "CentralizedCVT"
36+
# Type = "CVT"

params/learning_params.toml

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,47 @@
1-
DataDir = "/root/CoverageControl_ws/data/pure_coverage/" # Absolute location
2-
3-
GPUs = [4, 5]
1+
DataDir = "${CoverageControl_ws}/datasets/lpac/" # Absolute location
2+
3+
GPUs = [0, 1]
4+
NumWorkers = 4
5+
# Directory to save the model
6+
# If a model is already present, it will be loaded
7+
# Similarly, for the optimizer
8+
[LPACModel]
9+
Dir = "${CoverageControl_ws}/datasets/lpac/models/"
10+
Model = "model.pt"
11+
Optimizer = "optimizer.pt"
12+
13+
[CNNModel]
14+
Dir = "${CoverageControl_ws}/datasets/lpac/models/" # Absolute location
15+
Model = "model.pt"
16+
Optimizer = "optimizer.pt"
417

518
[ModelConfig]
6-
UseCommMaps = True
7-
19+
UseCommMaps = true
820

921
[GNNBackBone]
1022
InputDim = 7
1123
NumHops = 3
12-
NumLayers = 4
13-
LatentSize = 64
24+
NumLayers = 5
25+
LatentSize = 256
1426
OutputDim = 2
1527

16-
[GNNTraining]
17-
LearningRate = 0.001
18-
WeightDecay = 0.0001
28+
[LPACTraining]
29+
LearningRate = 0.0001
30+
WeightDecay = 0.001
1931
BatchSize = 10
20-
NumEpochs = 50
32+
NumEpochs = 15
2133

2234
[CNNBackBone]
2335
InputDim = 4
24-
OutputDim = 7
25-
NumLayers = 2
26-
LatentSize = 8
36+
NumLayers = 3
37+
LatentSize = 32
2738
KernelSize = 3
2839
ImageSize = 32
40+
OutputDim = 7
2941

3042
[CNNTraining]
31-
LearningRate = 0.001
32-
WeightDecay = 0.0001
43+
LearningRate = 0.0001
44+
WeightDecay = 0.001
3345
BatchSize = 10
34-
NumEpochs = 50
35-
46+
NumEpochs = 15
47+
Momentum = 0.1
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
from __future__ import annotations
3+
4+
__name__ = "algorithms"
5+
6+
from .._core import NearOptimalCVT, ClairvoyantCVT, CentralizedCVT, DecentralizedCVT
7+
8+
__all__ = ["NearOptimalCVT", "ClairvoyantCVT", "CentralizedCVT", "DecentralizedCVT"]
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# This file is part of the CoverageControl library
2+
#
3+
# Author: Saurav Agarwal
4+
5+
# Repository: https://github.com/KumarRobotics/CoverageControl
6+
#
7+
# Copyright (c) 2024, Saurav Agarwal
8+
#
9+
# The CoverageControl library is free software: you can redistribute it and/or
10+
# modify it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 3 of the License, or (at your
12+
# option) any later version.
13+
#
14+
# The CoverageControl library is distributed in the hope that it will be
15+
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
17+
# Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License along with
20+
# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.
21+
22+
import os
23+
import sys
24+
if sys.version_info[1] < 11:
25+
import tomli as tomllib
26+
else:
27+
import tomllib
28+
import yaml
29+
import torch
30+
31+
class IOUtils:
32+
"""
33+
Class provides the following utility functions:
34+
- load_tensor
35+
- load_yaml
36+
- load_toml
37+
"""
38+
39+
@staticmethod
40+
def sanitize_path(path_str: str) -> str:
41+
return os.path.normpath(os.path.expanduser(os.path.expandvars(path_str)))
42+
43+
@staticmethod
44+
def load_tensor(path: str) -> torch.tensor:
45+
"""
46+
Function to load a tensor from a file
47+
Can load tensors from jit script format files
48+
49+
Args:
50+
path (str): Path to the file
51+
52+
Returns:
53+
tensor: The loaded tensor
54+
None: If the file does not exist
55+
56+
Raises:
57+
FileNotFoundError: If the file does not exist
58+
"""
59+
# Throw error if path does not exist
60+
path = IOUtils.sanitize_path(path)
61+
if not os.path.exists(path):
62+
raise FileNotFoundError(f"DataLoaderUtils::load_tensor: File not found: {path}")
63+
# Load data
64+
data = torch.load(path)
65+
# Extract tensor if data is in jit script format
66+
if isinstance(data, torch.jit.ScriptModule):
67+
tensor = list(data.parameters())[0]
68+
else:
69+
tensor = data
70+
return tensor
71+
72+
@staticmethod
73+
def load_yaml(path: str) -> dict:
74+
"""
75+
Function to load a yaml file
76+
77+
Args:
78+
path (str): Path to the file
79+
80+
Returns:
81+
data: The loaded data
82+
83+
Raises:
84+
FileNotFoundError: If the file does not exist
85+
"""
86+
87+
path = IOUtils.sanitize_path(path)
88+
# Throw error if path does not exist
89+
if not os.path.exists(path):
90+
raise FileNotFoundError(f"DataLoaderUtils::load_yaml File not found: {path}")
91+
# Load data
92+
with open(path, "r") as f:
93+
data = yaml.load(f, Loader=yaml.FullLoader)
94+
return data
95+
96+
@staticmethod
97+
def load_toml(path: str) -> dict: # Throw error if path does not exist
98+
path = IOUtils.sanitize_path(path)
99+
if not os.path.exists(path):
100+
raise FileNotFoundError(f"data_loader_utils::LoadToml: File not found: {path}")
101+
# Load data
102+
with open(path, "rb") as f:
103+
data = tomllib.load(f)
104+
return data
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# This file is part of the CoverageControl library
2+
#
3+
# Author: Saurav Agarwal
4+
5+
# Repository: https://github.com/KumarRobotics/CoverageControl
6+
#
7+
# Copyright (c) 2024, Saurav Agarwal
8+
#
9+
# The CoverageControl library is free software: you can redistribute it and/or
10+
# modify it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 3 of the License, or (at your
12+
# option) any later version.
13+
#
14+
# The CoverageControl library is distributed in the hope that it will be
15+
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
17+
# Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License along with
20+
# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.
21+
22+
import torch
23+
import torch_geometric
24+
from torch_geometric.nn import MLP
25+
26+
from .config_parser import GNNConfigParser
27+
from .cnn_backbone import CNNBackBone
28+
from .gnn_backbone import GNNBackBone
29+
30+
__all__ = ["LPAC"]
31+
32+
class LPAC(torch.nn.Module, GNNConfigParser):
33+
def __init__(self, config):
34+
super(LPAC, self).__init__()
35+
self.cnn_config = config['CNN']
36+
self.parse(config['GNN'])
37+
self.cnn_backbone = CNNBackBone(self.cnn_config)
38+
self.gnn_backbone = GNNBackBone(self.config, self.cnn_backbone.latent_size + 2)
39+
# --- no pos ---
40+
# self.gnn_backbone = GNNBackBone(self.config, self.cnn_backbone.latent_size)
41+
# --- no pos ---
42+
self.gnn_mlp = MLP([self.latent_size, 32, 32])
43+
self.output_linear = torch.nn.Linear(32, self.output_dim)
44+
# Register buffers to model
45+
self.register_buffer("actions_mean", torch.zeros(self.output_dim))
46+
self.register_buffer("actions_std", torch.ones(self.output_dim))
47+
48+
def forward(self, data: torch_geometric.data.Data) -> torch.Tensor:
49+
x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
50+
pos = data.pos
51+
cnn_output = self.cnn_backbone(x.view(-1, x.shape[-3], x.shape[-2], x.shape[-1]))
52+
53+
# --- no pos ---
54+
# gnn_output = self.gnn_backbone(cnn_output, edge_index)
55+
# mlp_output = self.gnn_mlp(gnn_output)
56+
# x = self.output_linear(mlp_output)
57+
# x = self.output_linear(self.gnn_mlp(self.gnn_backbone(cnn_output, edge_index)))
58+
# --- no pos ---
59+
60+
gnn_backbone_in = torch.cat([cnn_output, pos], dim=-1)
61+
# print(gnn_backbone_in)
62+
# gnn_output = self.gnn_backbone(gnn_backbone_in, edge_index)
63+
# mid_test = self.gnn_mlp.lins[0](gnn_output)
64+
# print(f'mid_test sum1: {mid_test.sum()}')
65+
# mid_test = self.gnn_mlp.norms[0](mid_test)
66+
# print(f'mid_test sum: {mid_test.sum()}')
67+
# mlp_output = self.gnn_mlp(self.gnn_backbone(gnn_backbone_in, edge_index)
68+
# print(f'mlp_output sum: {mlp_output[0]}')
69+
x = self.output_linear(self.gnn_mlp(self.gnn_backbone(gnn_backbone_in, edge_index)))
70+
return x
71+
72+
def load_model(self, model_state_dict_path: str) -> None:
73+
self.load_state_dict(torch.load(model_state_dict_path), strict=False)
74+
75+
def load_cnn_backbone(self, model_path: str) -> None:
76+
self.load_state_dict(torch.load(model_path).state_dict(), strict=False)
77+
78+
def load_gnn_backbone(self, model_path: str) -> None:
79+
self.load_state_dict(torch.load(model_path).state_dict(), strict=False)

0 commit comments

Comments
 (0)