Skip to content

Commit 56fcfa2

Browse files
authored
CIM scenario refinement (#400)
* Cim scenario refinement (#394) * CIM refinement * Fix lint error * Fix lint error * Cim test coverage (#395) * Enrich tests * Refactor CimDataGenerator * Refactor CIM parsers * Minor refinement * Fix lint error * Fix lint error * Fix lint error * Minor refactor * Type * Add two test file folders. Make a slight change to CIM BE. * Lint error * Lint error * Remove unnecessary public interfaces of CIM BE * Cim disable auto action type detection (#399) * Haven't been tested * Modify document * Add ActionType checking * Minor * Lint error * Action quantity should be a position number * Modify related docs & notebooks * Minor * Change test file name. Prepare to merge into master. * . * Minor test patch
1 parent 39aaa92 commit 56fcfa2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+4996
-1790
lines changed

.gitignore

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,5 @@ data/
2222
maro_venv/
2323
pyvenv.cfg
2424
htmlcov/
25-
.coverage
26-
27-
.coveragerc
25+
.coverage
26+
.coveragerc

docs/source/examples/multi_agent_dqn_cim.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ in the roll-out loop. In this example,
6969
plan_action = percent * (scope.discharge + early_discharge) - early_discharge
7070
actual_action = round(plan_action) if plan_action > 0 else round(percent * scope.discharge)
7171
else:
72-
actual_action, action_type = 0, None
72+
actual_action, action_type = 0, ActionType.LOAD
7373
7474
return {port: Action(vessel, port, actual_action, action_type)}
7575

docs/source/scenarios/container_inventory_management.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ Once we get a ``DecisionEvent`` from the environment, we should respond with an
572572
* **vessel_idx** (int): The id of the vessel/operation object of the port/agent.
573573
* **port_idx** (int): The id of the port/agent that take this action.
574574
* **action_type** (ActionType): Whether to load or discharge empty containers in this action.
575-
* **quantity** (int): The quantity of empty containers to be loaded/discharged.
575+
* **quantity** (int): The (non-negative) quantity of empty containers to be loaded/discharged.
576576

577577
Example
578578
^^^^^^^

examples/cim/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class CIMTrajectory(Trajectory):
3030
def __init__(
3131
self, env, *, port_attributes, vessel_attributes, action_space, look_back, max_ports_downstream,
3232
reward_time_window, fulfillment_factor, shortage_factor, time_decay,
33-
finite_vessel_space=True, has_early_discharge=True
33+
finite_vessel_space=True, has_early_discharge=True
3434
):
3535
super().__init__(env)
3636
self.port_attributes = port_attributes
@@ -72,7 +72,7 @@ def get_action(self, action_by_agent, event):
7272
plan_action = percent * (scope.discharge + early_discharge) - early_discharge
7373
actual_action = round(plan_action) if plan_action > 0 else round(percent * scope.discharge)
7474
else:
75-
actual_action, action_type = 0, None
75+
actual_action, action_type = 0, ActionType.LOAD
7676

7777
return {port: Action(vessel, port, actual_action, action_type)}
7878

maro/data_lib/cim/cim_data_container.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from .entities import (
1212
CimBaseDataCollection, CimRealDataCollection, CimSyntheticDataCollection, NoisedItem, Order, OrderGenerateMode,
13-
PortSetting, VesselSetting
13+
PortSetting, SyntheticPortSetting, VesselSetting
1414
)
1515
from .port_buffer_tick_wrapper import PortBufferTickWrapper
1616
from .utils import BUFFER_TICK_RAND_KEY, ORDER_NUM_RAND_KEY, apply_noise, list_sum_normalize
@@ -42,7 +42,7 @@ class CimBaseDataContainer(ABC):
4242
Args:
4343
data_collection (CimBaseDataCollection): Corresponding data collection.
4444
"""
45-
def __init__(self, data_collection: CimBaseDataCollection):
45+
def __init__(self, data_collection: CimBaseDataCollection) -> None:
4646
self._data_collection = data_collection
4747

4848
# wrapper for interfaces, to make it easy to use
@@ -151,7 +151,7 @@ def full_return_buffers(self) -> PortBufferTickWrapper:
151151
.. code-block:: python
152152
153153
# Get full return buffer tick of port 0.
154-
buffer_tick = data_cnr.full_return_buffers[0]
154+
buffer_tick = data_cntr.full_return_buffers[0]
155155
"""
156156
return self._full_return_buffer_wrapper
157157

@@ -209,7 +209,7 @@ def reachable_stops(self) -> VesselReachableStopsWrapper:
209209
return self._reachable_stops_wrapper
210210

211211
@property
212-
def vessel_period(self) -> int:
212+
def vessel_period(self) -> List[int]:
213213
"""Wrapper to get vessel's planned sailing period (without noise to complete a whole route).
214214
215215
Examples:
@@ -241,7 +241,7 @@ def reset(self):
241241
self._is_need_reset_seed = True
242242

243243
def _reset_seed(self):
244-
"""Reset internal seed for generate reproduceable data"""
244+
"""Reset internal seed for generate reproduce-able data"""
245245
random.reset_seed(BUFFER_TICK_RAND_KEY)
246246

247247
@abstractmethod
@@ -288,14 +288,14 @@ def get_orders(self, tick: int, total_empty_container: int) -> List[Order]:
288288

289289
self._is_need_reset_seed = False
290290

291-
if tick >= self._data_collection.max_tick:
291+
if tick >= self._data_collection.max_tick: # pragma: no cover
292292
warnings.warn(f"{tick} out of max tick {self._data_collection.max_tick}")
293293
return []
294294

295295
return self._gen_orders(tick, total_empty_container)
296296

297297
def _reset_seed(self):
298-
"""Reset internal seed for generate reproduceable data"""
298+
"""Reset internal seed for generate reproduce-able data"""
299299
super()._reset_seed()
300300
random.reset_seed(ORDER_NUM_RAND_KEY)
301301

@@ -308,6 +308,7 @@ def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]:
308308
"""
309309
# result
310310
order_list: List[Order] = []
311+
assert isinstance(self._data_collection, CimSyntheticDataCollection)
311312
order_proportion = self._data_collection.order_proportion
312313
order_mode = self._data_collection.order_mode
313314
total_containers = self._data_collection.total_containers
@@ -316,7 +317,7 @@ def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]:
316317
orders_to_gen = int(order_proportion[tick])
317318

318319
# if under unfixed mode, we will consider current empty container as factor
319-
if order_mode == OrderGenerateMode.UNFIXED:
320+
if order_mode == OrderGenerateMode.UNFIXED: # pragma: no cover. TODO: remove this mark later
320321
delta = total_containers - total_empty_container
321322

322323
if orders_to_gen <= delta:
@@ -331,7 +332,9 @@ def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]:
331332

332333
# calculate orders distribution for each port as source
333334
for port_idx in range(self.port_number):
334-
source_dist: NoisedItem = self.ports[port_idx].source_proportion
335+
port = self.ports[port_idx]
336+
assert isinstance(port, SyntheticPortSetting)
337+
source_dist: NoisedItem = port.source_proportion
335338

336339
noised_source_order_number = apply_noise(source_dist.base, source_dist.noise, random[ORDER_NUM_RAND_KEY])
337340

@@ -346,7 +349,9 @@ def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]:
346349
if remaining_orders == 0:
347350
break
348351

349-
targets_dist: List[NoisedItem] = self.ports[port_idx].target_proportions
352+
port = self.ports[port_idx]
353+
assert isinstance(port, SyntheticPortSetting)
354+
targets_dist: List[NoisedItem] = port.target_proportions
350355

351356
# apply noise and normalize
352357
noised_targets_dist = list_sum_normalize(
@@ -403,6 +408,7 @@ def __init__(self, data_collection: CimRealDataCollection):
403408
super().__init__(data_collection)
404409

405410
# orders
411+
assert isinstance(self._data_collection, CimRealDataCollection)
406412
self._orders: Dict[int, List[Order]] = self._data_collection.orders
407413

408414
def get_orders(self, tick: int, total_empty_container: int) -> List[Order]:
@@ -422,11 +428,8 @@ def get_orders(self, tick: int, total_empty_container: int) -> List[Order]:
422428

423429
self._is_need_reset_seed = False
424430

425-
if tick >= self._data_collection.max_tick:
431+
if tick >= self._data_collection.max_tick: # pragma: no cover
426432
warnings.warn(f"{tick} out of max tick {self._data_collection.max_tick}")
427433
return []
428434

429-
if tick not in self._orders:
430-
return []
431-
432-
return self._orders[tick]
435+
return self._orders[tick] if tick in self._orders else []

maro/data_lib/cim/cim_data_container_helpers.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,21 @@
33

44
import os
55
import urllib.parse
6+
from typing import Optional
67

78
from maro.cli.data_pipeline.utils import StaticParameter
89
from maro.simulator.utils import random, seed
910

1011
from .cim_data_container import CimBaseDataContainer, CimRealDataContainer, CimSyntheticDataContainer
11-
from .cim_data_generator import CimDataGenerator
12+
from .cim_data_generator import gen_cim_data
1213
from .cim_data_loader import load_from_folder, load_real_data_from_folder
1314
from .utils import DATA_CONTAINER_INIT_SEED_LIMIT, ROUTE_INIT_RAND_KEY
1415

1516

1617
class CimDataContainerWrapper:
1718

1819
def __init__(self, config_path: str, max_tick: int, topology: str):
19-
self._data_cntr: CimBaseDataContainer = None
20+
self._data_cntr: Optional[CimBaseDataContainer] = None
2021
self._max_tick = max_tick
2122
self._config_path = config_path
2223
self._start_tick = 0
@@ -39,11 +40,13 @@ def _init_data_container(self, topology_seed: int = None):
3940
config_path=config_path, max_tick=self._max_tick, start_tick=self._start_tick,
4041
topology_seed=topology_seed
4142
)
43+
elif os.path.exists(os.path.join(self._config_path, "order_proportion.csv")):
44+
self._data_cntr = data_from_dumps(dumps_folder=self._config_path)
4245
else:
4346
# Real Data Mode: read data from input data files, no need for any config.yml.
4447
self._data_cntr = data_from_files(data_folder=self._config_path)
4548

46-
def reset(self, keep_seed):
49+
def reset(self, keep_seed: bool):
4750
"""Reset data container internal state"""
4851
if not keep_seed:
4952
self._init_data_container(random[ROUTE_INIT_RAND_KEY].randint(0, DATA_CONTAINER_INIT_SEED_LIMIT - 1))
@@ -87,9 +90,8 @@ def data_from_generator(config_path: str, max_tick: int, start_tick: int = 0,
8790
Returns:
8891
CimSyntheticDataContainer: Data container used to provide cim data related interfaces.
8992
"""
90-
edg = CimDataGenerator()
91-
92-
data_collection = edg.gen_data(config_path, start_tick=start_tick, max_tick=max_tick, topology_seed=topology_seed)
93+
data_collection = gen_cim_data(
94+
config_path, start_tick=start_tick, max_tick=max_tick, topology_seed=topology_seed)
9395

9496
return CimSyntheticDataContainer(data_collection)
9597

maro/data_lib/cim/cim_data_dump.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,25 @@
88
import numpy as np
99
from yaml import safe_dump
1010

11-
from .cim_data_generator import CimDataGenerator
12-
from .entities import CimSyntheticDataCollection
11+
from .cim_data_generator import gen_cim_data
12+
from .entities import CimSyntheticDataCollection, SyntheticPortSetting
13+
14+
15+
def _dump_csv_file(file_path: str, headers: List[str], line_generator: callable):
16+
"""helper method to dump csv file
17+
18+
Args:
19+
file_path(str): path of output csv file
20+
headers(List[str]): list of header
21+
line_generator(callable): generator function to generate line to write
22+
"""
23+
with open(file_path, "wt+", newline="") as fp:
24+
writer = csv.writer(fp)
25+
26+
writer.writerow(headers)
27+
28+
for line in line_generator():
29+
writer.writerow(line)
1330

1431

1532
class CimDataDumpUtil:
@@ -70,7 +87,7 @@ def stop_generator():
7087
stop.leave_tick
7188
]
7289

73-
self._dump_csv_file(stops_file_path, headers, stop_generator)
90+
_dump_csv_file(stops_file_path, headers, stop_generator)
7491

7592
def _dump_ports(self, output_folder: str):
7693
"""
@@ -86,6 +103,7 @@ def _dump_ports(self, output_folder: str):
86103

87104
def port_generator():
88105
for port in self._data_collection.port_settings:
106+
assert isinstance(port, SyntheticPortSetting)
89107
yield [
90108
port.index,
91109
port.name,
@@ -99,7 +117,7 @@ def port_generator():
99117
port.full_return_buffer.noise
100118
]
101119

102-
self._dump_csv_file(ports_file_path, headers, port_generator)
120+
_dump_csv_file(ports_file_path, headers, port_generator)
103121

104122
def _dump_vessels(self, output_folder: str):
105123
"""
@@ -137,7 +155,7 @@ def vessel_generator():
137155
vessel.empty
138156
]
139157

140-
self._dump_csv_file(vessels_file_path, headers, vessel_generator)
158+
_dump_csv_file(vessels_file_path, headers, vessel_generator)
141159

142160
def _dump_routes(self, output_folder: str, route_idx2name_dict: dict):
143161
"""
@@ -161,7 +179,7 @@ def route_generator():
161179
point.distance_to_next_port
162180
]
163181

164-
self._dump_csv_file(routes_file_path, headers, route_generator)
182+
_dump_csv_file(routes_file_path, headers, route_generator)
165183

166184
def _dump_order_proportions(self, output_folder: str, port_idx2name_dict: dict):
167185
"""
@@ -179,6 +197,7 @@ def _dump_order_proportions(self, output_folder: str, port_idx2name_dict: dict):
179197

180198
def order_prop_generator():
181199
for port in ports:
200+
assert isinstance(port, SyntheticPortSetting)
182201
for prop in port.target_proportions:
183202
yield [
184203
port.name,
@@ -189,7 +208,7 @@ def order_prop_generator():
189208
prop.noise
190209
]
191210

192-
self._dump_csv_file(proportion_file_path, headers, order_prop_generator)
211+
_dump_csv_file(proportion_file_path, headers, order_prop_generator)
193212

194213
def _dump_misc(self, output_folder: str):
195214
"""
@@ -213,22 +232,6 @@ def _dump_misc(self, output_folder: str):
213232
with open(misc_file_path, "wt+") as fp:
214233
safe_dump(misc_items, fp)
215234

216-
def _dump_csv_file(self, file_path: str, headers: List[str], line_generator: callable):
217-
"""helper method to dump csv file
218-
219-
Args:
220-
file_path(str): path of output csv file
221-
headers(List[str]): list of header
222-
line_generator(callable): generator function to generate line to write
223-
"""
224-
with open(file_path, "wt+", newline="") as fp:
225-
writer = csv.writer(fp)
226-
227-
writer.writerow(headers)
228-
229-
for line in line_generator():
230-
writer.writerow(line)
231-
232235

233236
def dump_from_config(config_file: str, output_folder: str, max_tick: int):
234237
"""Dump cim data from config, this will call data generator to generate data , and dump it.
@@ -245,9 +248,7 @@ def dump_from_config(config_file: str, output_folder: str, max_tick: int):
245248
assert output_folder is not None and os.path.exists(output_folder), f"Got output folder path: {output_folder}"
246249
assert max_tick is not None and max_tick > 0, f"Got max tick: {max_tick}"
247250

248-
generator = CimDataGenerator()
249-
250-
data_collection = generator.gen_data(config_file, max_tick=max_tick, start_tick=0, topology_seed=None)
251+
data_collection = gen_cim_data(config_file, max_tick=max_tick, start_tick=0, topology_seed=None)
251252

252253
dump_util = CimDataDumpUtil(data_collection)
253254

0 commit comments

Comments
 (0)