Skip to content

Commit 0a6912c

Browse files
Saurav AgarwalSaurav Agarwal
authored andcommitted
Restructure py library
1 parent efc55f9 commit 0a6912c

File tree

8 files changed

+209
-110
lines changed

8 files changed

+209
-110
lines changed

python/coverage_control/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77

88
from __future__ import annotations
99

10+
1011
from ._version import version as __version__
12+
from .core import *
13+
14+
__name__ = "coverage_control"
1115

12-
__all__ = ["__version__"]
16+
__all__ = ["__version__", "core"]
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
2+
from __future__ import annotations
3+
4+
__name__ = "core"
5+
6+
from .._core import Point2, PointVector, PolygonFeature, VoronoiCell, VoronoiCells
7+
from .._core import BivariateNormalDistribution, WorldIDF, RobotModel, CoverageSystem
8+
from .._core import Parameters
9+
10+
from .algorithms import *
11+
12+
__all__ = ["Point2", "PointVector", "PolygonFeature", "VoronoiCell", "VoronoiCells", "BivariateNormalDistribution", "WorldIDF", "RobotModel", "CoverageSystem", "algorithms", "Parameters", "__name__"]
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
from __future__ import annotations
3+
4+
__name__ = "algorithms"
5+
6+
from ..._core import NearOptimalCVT, ClairvoyantCVT, CentralizedCVT, DecentralizedCVT
7+
8+
__all__ = ["NearOptimalCVT", "ClairvoyantCVT", "CentralizedCVT", "DecentralizedCVT"]
Lines changed: 148 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,152 @@
1+
# This file is part of the CoverageControl library
2+
#
3+
# Author: Saurav Agarwal
4+
5+
# Repository: https://github.com/KumarRobotics/CoverageControl
6+
#
7+
# Copyright (c) 2024, Saurav Agarwal
8+
#
9+
# The CoverageControl library is free software: you can redistribute it and/or
10+
# modify it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 3 of the License, or (at your
12+
# option) any later version.
13+
#
14+
# The CoverageControl library is distributed in the hope that it will be
15+
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
17+
# Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License along with
20+
# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.
21+
22+
import os
123
import yaml
224
import tomllib
3-
import os
425
import torch
526
import torch_geometric
6-
import numpy
7-
8-
"""
9-
Function to load a tensor from a file
10-
Checks if the file exists, if not, returns None
11-
Checks if the loaded data is a tensor or is in jit script format
12-
"""
13-
def LoadTensor(path):
14-
# Throw error if path does not exist
15-
if not os.path.exists(path):
16-
raise FileNotFoundError(f"data_loader_utils::LoadTensor: File not found: {path}")
17-
# Load data
18-
data = torch.load(path)
19-
# Extract tensor if data is in jit script format
20-
if isinstance(data, torch.jit.ScriptModule):
21-
tensor = list(data.parameters())[0]
22-
else:
23-
tensor = data
24-
return tensor
25-
26-
def LoadYaml(path):
27-
# Throw error if path does not exist
28-
if not os.path.exists(path):
29-
raise FileNotFoundError(f"data_loader_utils::LoadYaml: File not found: {path}")
30-
# Load data
31-
with open(path, "r") as f:
32-
data = yaml.load(f, Loader=yaml.FullLoader)
33-
return data
34-
35-
def LoadToml(path):
36-
# Throw error if path does not exist
37-
if not os.path.exists(path):
38-
raise FileNotFoundError(f"data_loader_utils::LoadToml: File not found: {path}")
39-
# Load data
40-
with open(path, "rb") as f:
41-
data = tomllib.load(f)
42-
return data
43-
44-
def LoadMaps(path, use_comm_map):
45-
local_maps = LoadTensor(f"{path}/local_maps.pt")
46-
local_maps = local_maps.to_dense().unsqueeze(2)
47-
obstacle_maps = LoadTensor(f"{path}/obstacle_maps.pt")
48-
obstacle_maps = obstacle_maps.to_dense().unsqueeze(2)
49-
50-
if use_comm_map:
51-
comm_maps = LoadTensor(f"{path}/comm_maps.pt")
52-
comm_maps = comm_maps.to_dense()
53-
# comm_maps = (comm_maps * 256 + 256)/512
54-
maps = torch.cat([local_maps, comm_maps, obstacle_maps], 2)
55-
else:
56-
maps = torch.cat([local_maps, obstacle_maps], 2)
57-
return maps
58-
59-
def LoadFeatures(path, output_dim = None):
60-
normalized_coverage_features = LoadTensor(f"{path}/normalized_coverage_features.pt")
61-
coverage_features_mean = LoadTensor(f"{path}/../coverage_features_mean.pt")
62-
coverage_features_std = LoadTensor(f"{path}/../coverage_features_std.pt")
63-
if output_dim is not None:
64-
normalized_coverage_features = normalized_coverage_features[:, :, :output_dim]
65-
return normalized_coverage_features, coverage_features_mean, coverage_features_std
66-
67-
def LoadActions(path):
68-
actions = LoadTensor(f"{path}/normalized_actions.pt")
69-
actions_mean = LoadTensor(f"{path}/../actions_mean.pt")
70-
actions_std = LoadTensor(f"{path}/../actions_std.pt")
71-
return actions, actions_mean, actions_std
72-
73-
def LoadRobotPositions(path):
74-
robot_positions = LoadTensor(f"{path}/robot_positions.pt")
75-
return robot_positions
76-
77-
def LoadEdgeWeights(path):
78-
edge_weights = LoadTensor(f"{path}/edge_weights.pt")
79-
edge_weights.to_dense()
80-
return edge_weights
81-
82-
def ToTorchGeometricData(feature, edge_weights, pos = None):
83-
# senders, receivers = numpy.nonzero(edge_weights)
84-
# weights = edge_weights[senders, receivers]
85-
# edge_index = numpy.stack([senders, receivers])
86-
edge_weights = edge_weights.to_sparse()
87-
edge_weights = edge_weights.coalesce()
88-
edge_index = edge_weights.indices().long()
89-
weights = edge_weights.values().float()
90-
# weights = torch.reciprocal(edge_weights.values().float())
91-
if pos == None:
92-
data = torch_geometric.data.Data(
93-
x=feature,
94-
edge_index=edge_index.clone().detach(),
95-
edge_weight=weights.clone().detach()
96-
)
97-
else:
98-
data = torch_geometric.data.Data(
99-
x=feature,
100-
edge_index=edge_index.clone().detach(),
101-
edge_weight=weights.clone().detach(),
102-
pos=pos.clone().detach()
103-
)
104-
return data
27+
28+
## @ingroup python_api
29+
class DataLoaderUtils:
30+
"""
31+
Class to provide utility functions to load tensors and configuration files
32+
"""
33+
34+
def load_tensor(path):
35+
"""
36+
Function to load a tensor from a file
37+
Can load tensors from jit script format files
38+
39+
Args:
40+
path (str): Path to the file
41+
42+
Returns:
43+
tensor: The loaded tensor
44+
None: If the file does not exist
45+
46+
Raises:
47+
FileNotFoundError: If the file does not exist
48+
"""
49+
# Throw error if path does not exist
50+
if not os.path.exists(path):
51+
raise FileNotFoundError(f"DataLoaderUtils::load_tensor: File not found: {path}")
52+
# Load data
53+
data = torch.load(path)
54+
# Extract tensor if data is in jit script format
55+
if isinstance(data, torch.jit.ScriptModule):
56+
tensor = list(data.parameters())[0]
57+
else:
58+
tensor = data
59+
return tensor
60+
61+
def load_yaml(path):
62+
"""
63+
Function to load a yaml file
64+
65+
Args:
66+
path (str): Path to the file
67+
68+
Returns:
69+
data: The loaded data
70+
Raises:
71+
FileNotFoundError: If the file does not exist
72+
"""
73+
74+
75+
# Throw error if path does not exist
76+
if not os.path.exists(path):
77+
raise FileNotFoundError(f"DataLoaderUtils::load_yaml File not found: {path}")
78+
# Load data
79+
with open(path, "r") as f:
80+
data = yaml.load(f, Loader=yaml.FullLoader)
81+
return data
82+
83+
def LoadToml(path):
84+
# Throw error if path does not exist
85+
if not os.path.exists(path):
86+
raise FileNotFoundError(f"data_loader_utils::LoadToml: File not found: {path}")
87+
# Load data
88+
with open(path, "rb") as f:
89+
data = tomllib.load(f)
90+
return data
91+
92+
def LoadMaps(path, use_comm_map):
93+
local_maps = load_tensor(f"{path}/local_maps.pt")
94+
local_maps = local_maps.to_dense().unsqueeze(2)
95+
obstacle_maps = load_tensor(f"{path}/obstacle_maps.pt")
96+
obstacle_maps = obstacle_maps.to_dense().unsqueeze(2)
97+
98+
if use_comm_map:
99+
comm_maps = load_tensor(f"{path}/comm_maps.pt")
100+
comm_maps = comm_maps.to_dense()
101+
# comm_maps = (comm_maps * 256 + 256)/512
102+
maps = torch.cat([local_maps, comm_maps, obstacle_maps], 2)
103+
else:
104+
maps = torch.cat([local_maps, obstacle_maps], 2)
105+
return maps
106+
107+
def LoadFeatures(path, output_dim = None):
108+
normalized_coverage_features = load_tensor(f"{path}/normalized_coverage_features.pt")
109+
coverage_features_mean = load_tensor(f"{path}/../coverage_features_mean.pt")
110+
coverage_features_std = load_tensor(f"{path}/../coverage_features_std.pt")
111+
if output_dim is not None:
112+
normalized_coverage_features = normalized_coverage_features[:, :, :output_dim]
113+
return normalized_coverage_features, coverage_features_mean, coverage_features_std
114+
115+
def LoadActions(path):
116+
actions = load_tensor(f"{path}/normalized_actions.pt")
117+
actions_mean = load_tensor(f"{path}/../actions_mean.pt")
118+
actions_std = load_tensor(f"{path}/../actions_std.pt")
119+
return actions, actions_mean, actions_std
120+
121+
def LoadRobotPositions(path):
122+
robot_positions = load_tensor(f"{path}/robot_positions.pt")
123+
return robot_positions
124+
125+
def LoadEdgeWeights(path):
126+
edge_weights = load_tensor(f"{path}/edge_weights.pt")
127+
edge_weights.to_dense()
128+
return edge_weights
129+
130+
def ToTorchGeometricData(feature, edge_weights, pos = None):
131+
# senders, receivers = numpy.nonzero(edge_weights)
132+
# weights = edge_weights[senders, receivers]
133+
# edge_index = numpy.stack([senders, receivers])
134+
edge_weights = edge_weights.to_sparse()
135+
edge_weights = edge_weights.coalesce()
136+
edge_index = edge_weights.indices().long()
137+
weights = edge_weights.values().float()
138+
# weights = torch.reciprocal(edge_weights.values().float())
139+
if pos == None:
140+
data = torch_geometric.data.Data(
141+
x=feature,
142+
edge_index=edge_index.clone().detach(),
143+
edge_weight=weights.clone().detach()
144+
)
145+
else:
146+
data = torch_geometric.data.Data(
147+
x=feature,
148+
edge_index=edge_index.clone().detach(),
149+
edge_weight=weights.clone().detach(),
150+
pos=pos.clone().detach()
151+
)
152+
return data

python/coverage_control/nn/data_loaders/data_loaders.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
from torch_geometric.data import Dataset
77

88

9+
## \defgroup python_api_data_loaders Data Loaders
10+
# \ingroup python_api
11+
## \brief Data loaders for training and testing
12+
13+
## \ingroup python_api_data_loaders
914
class LocalMapCNNDataset(Dataset):
1015
"""
1116
Dataset for CNN training
@@ -41,6 +46,7 @@ def LoadData(self):
4146
self.targets, self.targets_mean, self.targets_std = dl_utils.LoadActions(f"{self.data_dir}/{self.stage}")
4247
self.targets = self.targets.view(-1, self.targets.shape[2])
4348

49+
## \ingroup python_api_data_loaders
4450
class LocalMapGNNDataset(Dataset):
4551
"""
4652
Dataset for hybrid CNN-GNN training
@@ -51,7 +57,7 @@ def __init__(self, data_dir, stage):
5157
self.stage = stage
5258

5359
# Coverage maps is of shape (num_samples, num_robots, 2, image_size, image_size)
54-
self.coverage_maps = dl_utils.LoadTensor(f"{data_dir}/{stage}/coverage_maps.pt")
60+
self.coverage_maps = dl_utils.load_tensor(f"{data_dir}/{stage}/coverage_maps.pt")
5561
self.num_robots = self.coverage_maps.shape[1]
5662
self.dataset_size = self.coverage_maps.shape[0]
5763
self.targets, self.targets_mean, self.targets_std = dl_utils.LoadActions(f"{data_dir}/{stage}")

python/tests/coverage_simple.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import sys
2-
import CoverageControl # Main library
2+
from coverage_control import _core as cc # Main library
33
# Algorithms available:
44
# ClairvoyantCVT
55
# CentralizedCVT
66
# DecentralizedCVT
77
# NearOptimalCVT
8-
from CoverageControl import ClairvoyantCVT as CoverageAlgorithm
8+
from coverage_control._core import ClairvoyantCVT as CoverageAlgorithm
99

10-
params = CoverageControl.Parameters()
10+
params = cc.Parameters()
1111

1212
# CoverageSystem handles the environment and robots
13-
env = CoverageControl.CoverageSystem(params)
13+
env = cc.CoverageSystem(params)
1414

1515
init_cost = env.GetObjectiveValue()
1616
print("Initial Coverage cost: " + str('{:.2e}'.format(init_cost)))

python/tests/test_coverage.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,35 @@
1+
# This file is part of the CoverageControl library
2+
#
3+
# Author: Saurav Agarwal
4+
5+
# Repository: https://github.com/KumarRobotics/CoverageControl
6+
#
7+
# Copyright (c) 2024, Saurav Agarwal
8+
#
9+
# The CoverageControl library is free software: you can redistribute it and/or
10+
# modify it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 3 of the License, or (at your
12+
# option) any later version.
13+
#
14+
# The CoverageControl library is distributed in the hope that it will be
15+
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
17+
# Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License along with
20+
# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.
21+
122
from __future__ import annotations
223

324
import importlib.metadata
425
import test as m
5-
from testcoverage import _core as cc # Main library
26+
from coverage_control import _core as cc # Main library
627
# Algorithms available:
728
# ClairvoyantCVT
829
# CentralizedCVT
930
# DecentralizedCVT
1031
# NearOptimalCVT
11-
from testcoverage._core import ClairvoyantCVT as CoverageAlgorithm
32+
from coverage_control._core import ClairvoyantCVT as CoverageAlgorithm
1233

1334
def test_coverage_system():
1435
params = cc.Parameters()

python/tests/test_package.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
import importlib.metadata
4-
import testcoverage as m
4+
import coverage_control as m
55

66

77
def test_version():
8-
assert importlib.metadata.version("testcoverage") == m.__version__
8+
assert importlib.metadata.version("coverage_control") == m.__version__

0 commit comments

Comments
 (0)