Skip to content

Commit 7f6915e

Browse files
committed
Initial VIN implementation
1 parent d4a4a90 commit 7f6915e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+5996
-32
lines changed
13.5 KB
Binary file not shown.
13.5 KB
Binary file not shown.
13.5 KB
Binary file not shown.

src/algorithms/algorithm_manager.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@
3030
from algorithms.classic.sample_based.rrt_star import RRT_Star
3131
from algorithms.classic.sample_based.rrt_connect import RRT_Connect
3232
from algorithms.classic.graph_based.wavefront import Wavefront
33-
from algorithms.lstm.LSTM_tile_by_tile import OnlineLSTM
34-
from algorithms.lstm.a_star_waypoint import WayPointNavigation
35-
from algorithms.lstm.combined_online_LSTM import CombinedOnlineLSTM
33+
from algorithms.learning.LSTM_tile_by_tile import OnlineLSTM
34+
from algorithms.learning.a_star_waypoint import WayPointNavigation
35+
from algorithms.learning.combined_online_LSTM import CombinedOnlineLSTM
36+
from algorithms.learning.VIN.VIN import VINAlgorithm
3637

3738
if HAS_OMPL:
3839
from algorithms.classic.sample_based.ompl_rrt import OMPL_RRT
@@ -103,7 +104,8 @@ def _static_init_(cls):
103104
"Dijkstra": (Dijkstra, DijkstraTesting, ([], {})),
104105
"Bug1": (Bug1, BasicTesting, ([], {})),
105106
"Bug2": (Bug2, BasicTesting, ([], {})),
106-
"Potential Field": (PotentialField, BasicTesting, ([], {}))
107+
"Potential Field": (PotentialField, BasicTesting, ([], {})),
108+
"VIN": (VINAlgorithm, BasicTesting, ([], {"load_name": "vin_pretrained"}))
107109
}
108110

109111
if HAS_OMPL:

src/algorithms/classic/testing/way_point_navigation_testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44

55
from algorithms.basic_testing import BasicTesting
6-
from algorithms.lstm.combined_online_LSTM import CombinedOnlineLSTM
6+
from algorithms.learning.combined_online_LSTM import CombinedOnlineLSTM
77
from simulator.services.debug import DebugLevel
88

99

src/algorithms/configuration/configuration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from algorithms.algorithm import Algorithm
55
from algorithms.basic_testing import BasicTesting
66
from algorithms.configuration.maps.map import Map
7-
from algorithms.lstm.LSTM_tile_by_tile import BasicLSTMModule
8-
from algorithms.lstm.ML_model import MLModel
7+
from algorithms.learning.LSTM_tile_by_tile import BasicLSTMModule
8+
from algorithms.learning.ML_model import MLModel
99
from simulator.services.debug import DebugLevel
1010

1111
class Configuration:

src/algorithms/lstm/LSTM_CAE_tile_by_tile.py renamed to src/algorithms/learning/LSTM_CAE_tile_by_tile.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
from algorithms.basic_testing import BasicTesting
1717
from algorithms.configuration.maps.map import Map
18-
from algorithms.lstm.LSTM_tile_by_tile import BasicLSTMModule, OnlineLSTM
19-
from algorithms.lstm.ML_model import MLModel, EvaluationResults
20-
from algorithms.lstm.map_processing import MapProcessing
18+
from algorithms.learning.LSTM_tile_by_tile import BasicLSTMModule, OnlineLSTM
19+
from algorithms.learning.ML_model import MLModel, EvaluationResults
20+
from algorithms.learning.map_processing import MapProcessing
2121
from simulator.services.services import Services
2222
from utility.constants import DATA_PATH
2323

src/algorithms/lstm/LSTM_CNN_tile_by_tile_obsolete.py renamed to src/algorithms/learning/LSTM_CNN_tile_by_tile_obsolete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch import nn
1111
1212
from algorithms.basic_testing import BasicTesting
13-
from algorithms.lstm.online_lstm import BasicLSTMModule, OnlineLSTM
13+
from algorithms.learning.online_lstm import BasicLSTMModule, OnlineLSTM
1414
from simulator.services.services import Services
1515
1616

src/algorithms/lstm/LSTM_tile_by_tile.py renamed to src/algorithms/learning/LSTM_tile_by_tile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from algorithms.basic_testing import BasicTesting
1111
from algorithms.configuration.entities.goal import Goal
1212
from algorithms.configuration.maps.map import Map
13-
from algorithms.lstm.ML_model import MLModel, SingleTensorDataset, PackedDataset
14-
from algorithms.lstm.map_processing import MapProcessing
13+
from algorithms.learning.ML_model import MLModel, SingleTensorDataset, PackedDataset
14+
from algorithms.learning.map_processing import MapProcessing
1515
from simulator.services.services import Services
1616
from simulator.views.map.display.entities_map_display import EntitiesMapDisplay
1717
from simulator.views.map.display.online_lstm_map_display import OnlineLSTMMapDisplay

src/algorithms/lstm/ML_model.py renamed to src/algorithms/learning/ML_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence, pack_sequence, PackedSequence
1313
from torch.utils import data
1414
from torch.utils.data import DataLoader, TensorDataset, Dataset, Subset
15-
from algorithms.lstm.map_processing import MapProcessing
15+
from algorithms.learning.map_processing import MapProcessing
1616
from simulator.services.debug import DebugLevel
1717
from simulator.services.services import Services
1818

@@ -154,7 +154,7 @@ class PackedDataset(Dataset):
154154
lengths: torch.Tensor
155155

156156
def __init__(self, seq: List[torch.Tensor]) -> None:
157-
from algorithms.lstm.LSTM_tile_by_tile import BasicLSTMModule
157+
from algorithms.learning.LSTM_tile_by_tile import BasicLSTMModule
158158

159159
ls = list(map(lambda el: el.shape[0], seq))
160160
self.perm = BasicLSTMModule.get_sort_by_lengths_indices(ls)

0 commit comments

Comments
 (0)