Skip to content

Commit 5975eca

Browse files
lihuoranJinyu-W
andauthored
Fix reset random seed bug (#387)
* update the reset interface of Env and BE * Try to fix reset routes generation seed issue * Refine random related logics. * Minor refinement * Test check * Minor * Remove unused functions so far * Minor Co-authored-by: Jinyu Wang <jinywan@microsoft.com>
1 parent 790ed55 commit 5975eca

File tree

17 files changed

+86
-78
lines changed

17 files changed

+86
-78
lines changed

docs/source/apidoc/maro.utils.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ maro.utils.exception.communication\_exception
4545
:undoc-members:
4646
:show-inheritance:
4747

48-
maro.utils.exception.data\_lib\_exeption
48+
maro.utils.exception.data\_lib\_exception
4949
--------------------------------------------------------------------------------
5050

51-
.. automodule:: maro.utils.exception.data_lib_exeption
51+
.. automodule:: maro.utils.exception.data_lib_exception
5252
:members:
5353
:undoc-members:
5454
:show-inheritance:

maro/data_lib/cim/cim_data_container.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
from math import ceil
77
from typing import Dict, List
88

9+
from maro.simulator.utils import random
10+
911
from .entities import (
1012
CimBaseDataCollection, CimRealDataCollection, CimSyntheticDataCollection, NoisedItem, Order, OrderGenerateMode,
1113
PortSetting, VesselSetting
1214
)
1315
from .port_buffer_tick_wrapper import PortBufferTickWrapper
14-
from .utils import (
15-
apply_noise, buffer_tick_rand, get_buffer_tick_seed, get_order_num_seed, list_sum_normalize, order_num_rand
16-
)
16+
from .utils import BUFFER_TICK_RAND_KEY, ORDER_NUM_RAND_KEY, apply_noise, list_sum_normalize
1717
from .vessel_future_stops_prediction import VesselFutureStopsPrediction
1818
from .vessel_past_stops_wrapper import VesselPastStopsWrapper
1919
from .vessel_reachable_stops_wrapper import VesselReachableStopsWrapper
@@ -60,9 +60,6 @@ def __init__(self, data_collection: CimBaseDataCollection):
6060
self._vessel_plan_wrapper = VesselSailingPlanWrapper(self._data_collection)
6161
self._reachable_stops_wrapper = VesselReachableStopsWrapper(self._data_collection)
6262

63-
# keep the seed so we can reproduce the sequence after reset
64-
self._buffer_tick_seed: int = get_buffer_tick_seed()
65-
6663
# flag to tell if we need to reset seed, we need this flag as outside may set the seed after env.reset
6764
self._is_need_reset_seed = False
6865

@@ -245,7 +242,7 @@ def reset(self):
245242

246243
def _reset_seed(self):
247244
"""Reset internal seed for generate reproduceable data"""
248-
buffer_tick_rand.seed(self._buffer_tick_seed)
245+
random.reset_seed(BUFFER_TICK_RAND_KEY)
249246

250247
@abstractmethod
251248
def get_orders(self, tick: int, total_empty_container: int) -> List[Order]:
@@ -272,9 +269,6 @@ class CimSyntheticDataContainer(CimBaseDataContainer):
272269
def __init__(self, data_collection: CimSyntheticDataCollection):
273270
super().__init__(data_collection)
274271

275-
# keep the seed so we can reproduce the sequence after reset
276-
self._order_num_seed: int = get_order_num_seed()
277-
278272
# TODO: get_events which composed with arrive, departure and order
279273

280274
def get_orders(self, tick: int, total_empty_container: int) -> List[Order]:
@@ -303,7 +297,7 @@ def get_orders(self, tick: int, total_empty_container: int) -> List[Order]:
303297
def _reset_seed(self):
304298
"""Reset internal seed for generate reproduceable data"""
305299
super()._reset_seed()
306-
order_num_rand.seed(self._order_num_seed)
300+
random.reset_seed(ORDER_NUM_RAND_KEY)
307301

308302
def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]:
309303
"""Generate order for specified tick.
@@ -339,7 +333,7 @@ def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]:
339333
for port_idx in range(self.port_number):
340334
source_dist: NoisedItem = self.ports[port_idx].source_proportion
341335

342-
noised_source_order_number = apply_noise(source_dist.base, source_dist.noise, order_num_rand)
336+
noised_source_order_number = apply_noise(source_dist.base, source_dist.noise, random[ORDER_NUM_RAND_KEY])
343337

344338
noised_source_order_dist.append(noised_source_order_number)
345339

@@ -356,7 +350,7 @@ def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]:
356350

357351
# apply noise and normalize
358352
noised_targets_dist = list_sum_normalize(
359-
[apply_noise(target.base, target.noise, order_num_rand) for target in targets_dist])
353+
[apply_noise(target.base, target.noise, random[ORDER_NUM_RAND_KEY]) for target in targets_dist])
360354

361355
# order for current ports
362356
cur_port_order_num = ceil(orders_to_gen * noised_source_order_dist[port_idx])

maro/data_lib/cim/cim_data_container_helpers.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
import urllib.parse
66

77
from maro.cli.data_pipeline.utils import StaticParameter
8-
from maro.simulator.utils import seed
8+
from maro.simulator.utils import random, seed
99

1010
from .cim_data_container import CimBaseDataContainer, CimRealDataContainer, CimSyntheticDataContainer
1111
from .cim_data_generator import CimDataGenerator
1212
from .cim_data_loader import load_from_folder, load_real_data_from_folder
13+
from .utils import DATA_CONTAINER_INIT_SEED_LIMIT, ROUTE_INIT_RAND_KEY
1314

1415

1516
class CimDataContainerWrapper:
@@ -28,22 +29,26 @@ def __init__(self, config_path: str, max_tick: int, topology: str):
2829

2930
self._init_data_container()
3031

31-
def _init_data_container(self):
32+
def _init_data_container(self, topology_seed: int = None):
3233
if not os.path.exists(self._config_path):
3334
raise FileNotFoundError
3435
# Synthetic Data Mode: config.yml must exist.
3536
config_path = os.path.join(self._config_path, "config.yml")
3637
if os.path.exists(config_path):
3738
self._data_cntr = data_from_generator(
38-
config_path=config_path, max_tick=self._max_tick, start_tick=self._start_tick
39+
config_path=config_path, max_tick=self._max_tick, start_tick=self._start_tick,
40+
topology_seed=topology_seed
3941
)
4042
else:
4143
# Real Data Mode: read data from input data files, no need for any config.yml.
4244
self._data_cntr = data_from_files(data_folder=self._config_path)
4345

44-
def reset(self):
46+
def reset(self, keep_seed):
4547
"""Reset data container internal state"""
46-
self._data_cntr.reset()
48+
if not keep_seed:
49+
self._init_data_container(random[ROUTE_INIT_RAND_KEY].randint(0, DATA_CONTAINER_INIT_SEED_LIMIT - 1))
50+
else:
51+
self._data_cntr.reset()
4752

4853
def __getattr__(self, name):
4954
return getattr(self._data_cntr, name)
@@ -68,20 +73,23 @@ def data_from_dumps(dumps_folder: str) -> CimSyntheticDataContainer:
6873
return CimSyntheticDataContainer(data_collection)
6974

7075

71-
def data_from_generator(config_path: str, max_tick: int, start_tick: int = 0) -> CimSyntheticDataContainer:
76+
def data_from_generator(config_path: str, max_tick: int, start_tick: int = 0,
77+
topology_seed: int = None) -> CimSyntheticDataContainer:
7278
"""Collect data from data generator with configurations.
7379
7480
Args:
7581
config_path(str): Path of configuration file (yaml).
7682
max_tick (int): Max tick to generate data.
7783
start_tick(int): Start tick to generate data.
84+
topology_seed(int): Random seed of the business engine. \
85+
'None' means using the seed in the configuration file.
7886
7987
Returns:
8088
CimSyntheticDataContainer: Data container used to provide cim data related interfaces.
8189
"""
8290
edg = CimDataGenerator()
8391

84-
data_collection = edg.gen_data(config_path, start_tick=start_tick, max_tick=max_tick)
92+
data_collection = edg.gen_data(config_path, start_tick=start_tick, max_tick=max_tick, topology_seed=topology_seed)
8593

8694
return CimSyntheticDataContainer(data_collection)
8795

maro/data_lib/cim/cim_data_dump.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def dump_from_config(config_file: str, output_folder: str, max_tick: int):
247247

248248
generator = CimDataGenerator()
249249

250-
data_collection = generator.gen_data(config_file, max_tick=max_tick, start_tick=0)
250+
data_collection = generator.gen_data(config_file, max_tick=max_tick, start_tick=0, topology_seed=None)
251251

252252
dump_util = CimDataDumpUtil(data_collection)
253253

maro/data_lib/cim/cim_data_generator.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77
from yaml import safe_load
88

9-
from maro.simulator.utils import seed
10-
from maro.utils.exception.data_lib_exeption import CimGeneratorInvalidParkingDuration
9+
from maro.simulator.utils import random, seed
10+
from maro.utils.exception.data_lib_exception import CimGeneratorInvalidParkingDuration
1111

1212
from .entities import CimSyntheticDataCollection, OrderGenerateMode, Stop
1313
from .global_order_proportion import GlobalOrderProportion
1414
from .port_parser import PortsParser
1515
from .route_parser import RoutesParser
16-
from .utils import apply_noise, route_init_rand
16+
from .utils import ROUTE_INIT_RAND_KEY, apply_noise
1717
from .vessel_parser import VesselsParser
1818

1919
CIM_GENERATOR_VERSION = 0x000001
@@ -29,13 +29,19 @@ def __init__(self):
2929
self._routes_parser = RoutesParser()
3030
self._global_order_proportion = GlobalOrderProportion()
3131

32-
def gen_data(self, config_file: str, max_tick: int, start_tick: int = 0) -> CimSyntheticDataCollection:
32+
def gen_data(
33+
self, config_file: str, max_tick: int,
34+
start_tick: int = 0,
35+
topology_seed: int = None
36+
) -> CimSyntheticDataCollection:
3337
"""Generate data with specified configurations.
3438
3539
Args:
3640
config_file(str): File of configuration (yaml).
3741
max_tick(int): Max tick to generate.
3842
start_tick(int): Start tick to generate.
43+
topology_seed(int): Random seed of the business engine. \
44+
'None' means using the seed in the configuration file.
3945
4046
Returns:
4147
CimSyntheticDataCollection: Data collection contains all cim data.
@@ -45,7 +51,8 @@ def gen_data(self, config_file: str, max_tick: int, start_tick: int = 0) -> CimS
4551
with open(config_file, "r") as fp:
4652
conf: dict = safe_load(fp)
4753

48-
topology_seed = conf["seed"]
54+
if topology_seed is None:
55+
topology_seed = conf["seed"]
4956

5057
# set seed to generate data
5158
seed(topology_seed)
@@ -146,7 +153,7 @@ def _extend_route(
146153
port_idx = port_mapping[cur_route_point.port_name]
147154

148155
# apply noise to parking duration
149-
parking_duration = ceil(apply_noise(duration, duration_noise, route_init_rand))
156+
parking_duration = ceil(apply_noise(duration, duration_noise, random[ROUTE_INIT_RAND_KEY]))
150157

151158
if parking_duration <= 0:
152159
raise CimGeneratorInvalidParkingDuration()
@@ -165,7 +172,7 @@ def _extend_route(
165172
distance_to_next_port = cur_route_point.distance_to_next_port
166173

167174
# apply noise to speed
168-
noised_speed = apply_noise(speed, speed_noise, route_init_rand)
175+
noised_speed = apply_noise(speed, speed_noise, random[ROUTE_INIT_RAND_KEY])
169176
sailing_duration = ceil(distance_to_next_port / noised_speed)
170177

171178
# next tick

maro/data_lib/cim/global_order_proportion.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
import numpy as np
88

9-
from .utils import apply_noise, clip, order_init_rand
9+
from maro.simulator.utils import random
10+
11+
from .utils import ORDER_INIT_RAND_KEY, apply_noise, clip
1012

1113

1214
class GlobalOrderProportion:
@@ -59,7 +61,7 @@ def parse(self, conf: dict, total_container: int, max_tick: int, start_tick: int
5961
# apply noise if the distribution not zero
6062
if orders != 0:
6163
if noise != 0:
62-
orders = apply_noise(orders, noise, order_init_rand)
64+
orders = apply_noise(orders, noise, random[ORDER_INIT_RAND_KEY])
6365

6466
# clip and gen order
6567
orders = floor(clip(0, 1, orders) * total_container)

maro/data_lib/cim/port_buffer_tick_wrapper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33

44
from math import ceil
55

6+
from maro.simulator.utils import random
7+
68
from .entities import CimBaseDataCollection, NoisedItem, PortSetting
7-
from .utils import apply_noise, buffer_tick_rand
9+
from .utils import BUFFER_TICK_RAND_KEY, apply_noise
810

911

1012
class PortBufferTickWrapper:
@@ -29,4 +31,4 @@ def __getitem__(self, key):
2931

3032
buffer_setting: NoisedItem = self._attribute_func(port)
3133

32-
return ceil(apply_noise(buffer_setting.base, buffer_setting.noise, buffer_tick_rand))
34+
return ceil(apply_noise(buffer_setting.base, buffer_setting.noise, random[BUFFER_TICK_RAND_KEY]))

maro/data_lib/cim/utils.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,16 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4+
from random import Random
45
from typing import List, Union
56

6-
from maro.simulator.utils.sim_random import SimRandom, random
7-
87
# we keep 4 random generator to make the result is reproduceable with same seed(s), no matter if agent passed actions
9-
route_init_rand = random["route_init"]
10-
order_init_rand = random["order_init"]
11-
buffer_tick_rand = random["buffer_time"]
12-
order_num_rand = random["order_number"]
13-
14-
15-
def get_buffer_tick_seed():
16-
return random.get_seed("buffer_time")
17-
8+
ROUTE_INIT_RAND_KEY = "route_init"
9+
ORDER_INIT_RAND_KEY = "order_init"
10+
BUFFER_TICK_RAND_KEY = "buffer_time"
11+
ORDER_NUM_RAND_KEY = "order_number"
1812

19-
def get_order_num_seed():
20-
return random.get_seed("order_number")
13+
DATA_CONTAINER_INIT_SEED_LIMIT = 4096
2114

2215

2316
def clip(min_val: Union[int, float], max_val: Union[int, float], value: Union[int, float]) -> Union[int, float]:
@@ -34,7 +27,7 @@ def clip(min_val: Union[int, float], max_val: Union[int, float], value: Union[in
3427
return max(min_val, min(max_val, value))
3528

3629

37-
def apply_noise(value: Union[int, float], noise: Union[int, float], rand: SimRandom) -> float:
30+
def apply_noise(value: Union[int, float], noise: Union[int, float], rand: Random) -> float:
3831
"""Apply noise with specified random generator
3932
4033
Args:

maro/data_lib/item_meta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from yaml import SafeDumper, SafeLoader, YAMLObject, safe_dump, safe_load
1313

1414
from maro.data_lib.common import dtype_pack_map
15-
from maro.utils.exception.data_lib_exeption import MetaTimestampNotExist
15+
from maro.utils.exception.data_lib_exception import MetaTimestampNotExist
1616

1717

1818
class EntityAttr(YAMLObject):

maro/simulator/core.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,12 @@ def dump(self):
101101
"""
102102
return
103103

104-
def reset(self):
105-
"""Reset environment."""
104+
def reset(self, keep_seed: bool = False):
105+
"""Reset environment.
106+
107+
Args:
108+
keep_seed (bool): Reset the random seed to the generate the same data sequence or not. Defaults to False.
109+
"""
106110
self._tick = self._start_tick
107111

108112
self._simulate_generator.close()
@@ -120,7 +124,7 @@ def reset(self):
120124

121125
self._decision_events.clear()
122126

123-
self._business_engine.reset()
127+
self._business_engine.reset(keep_seed)
124128

125129
@property
126130
def configs(self) -> dict:

0 commit comments

Comments
 (0)