Skip to content

Commit d39ba7f

Browse files
committed
try to fix issue on windows for automatic classes
Signed-off-by: DONNOT Benjamin <[email protected]>
1 parent 4872e32 commit d39ba7f

File tree

3 files changed

+75
-32
lines changed

3 files changed

+75
-32
lines changed

.github/workflows/main.yml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ jobs:
197197
path: dist/*.tar.gz
198198

199199
auto_class_in_file:
200-
name: Test ${{ matrix.config.name }} OS can handle automatic class generation
200+
name: Test ${{ matrix.config.name }} OS can handle automatic class generation for python ${{matrix.python.version}}
201201
runs-on: ${{ matrix.config.os }}
202202
strategy:
203203
matrix:
@@ -238,10 +238,7 @@ jobs:
238238

239239
- name: Install Python dependencies
240240
run: |
241-
python -m pip install --upgrade pip
242-
python -m pip install --upgrade wheel
243-
python -m pip install --upgrade setuptools
244-
python -m pip install --upgrade gymnasium "numpy<2"
241+
python -m pip install --upgrade pip wheel setuptools gymnasium
245242
246243
- name: Build wheel
247244
run: python setup.py bdist_wheel
@@ -254,7 +251,7 @@ jobs:
254251
255252
- name: Test the automatic generation of classes in the env folder
256253
run: |
257-
python -m unittest grid2op/tests/automatic_classes.py -f
254+
python -m unittest grid2op/tests/automatic_classes.py -v -f
258255
259256
package:
260257
name: Test install

grid2op/VoltageControler/BaseVoltageController.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
# This file is part of Grid2Op, Grid2Op a testbed platform to model sequential decision making in power systems.
88
from abc import ABC, abstractmethod
99
import numpy as np
10-
import copy
1110

1211
from grid2op.dtypes import dt_int
13-
from grid2op.Action import VoltageOnlyAction, ActionSpace
12+
from grid2op.Action import VoltageOnlyAction
1413
from grid2op.Rules import AlwaysLegal
1514
from grid2op.Space import RandomObject
1615

grid2op/tests/automatic_classes.py

Lines changed: 71 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88

99
import os
1010
import multiprocessing as mp
11-
from typing import Optional
11+
from typing import Optional, Type
1212
import warnings
1313
import unittest
1414
import importlib
1515
import numpy as np
16-
from gymnasium.vector import AsyncVectorEnv
16+
from grid2op.Space import GridObjects
17+
from gymnasium.vector import AsyncVectorEnv # type: ignore
1718

1819
import grid2op
1920
from grid2op._glop_platform_info import _IS_WINDOWS
@@ -75,7 +76,7 @@ def act(self, observation: BaseObservation, reward: float, done: bool = False) -
7576
this_class_act = getattr(this_module, self._name_cls_act)
7677
else:
7778
raise RuntimeError(f"class {self._name_cls_act} not found")
78-
res = super().act(observation, reward, done)
79+
res = self.action_space()
7980
assert isinstance(res, this_class_act)
8081
return res
8182

@@ -85,7 +86,9 @@ class AutoClassMakeTester(unittest.TestCase):
8586
def test_in_make(self):
8687
with warnings.catch_warnings():
8788
warnings.filterwarnings("ignore")
88-
env = grid2op.make("l2rpn_case14_sandbox", test=True, class_in_file=False)
89+
env = grid2op.make("l2rpn_case14_sandbox",
90+
test=True,
91+
class_in_file=False)
8992
assert env._read_from_local_dir is None
9093
assert not env.classes_are_in_files()
9194

@@ -117,7 +120,9 @@ def _aux_make_env(self, env: Optional[Environment]=None):
117120
if env is None:
118121
with warnings.catch_warnings():
119122
warnings.filterwarnings("ignore")
120-
env = grid2op.make(self.get_env_name(), test=True, class_in_file=True)
123+
env = grid2op.make(self.get_env_name(),
124+
test=True,
125+
class_in_file=True)
121126
assert env.classes_are_in_files()
122127
return env
123128

@@ -176,7 +181,7 @@ def test_all_classes_from_file(self,
176181
for name_cls, name_attr in zip(names_cls, names_attr):
177182
this_module = importlib.import_module(f"{module_nm}.{name_cls}_file", super_module)
178183
if hasattr(this_module, name_cls):
179-
this_class = getattr(this_module, name_cls)
184+
this_class : Type[GridObjects]= getattr(this_module, name_cls)
180185
else:
181186
raise RuntimeError(f"class {name_cls} not found")
182187
if name_attr is not None:
@@ -382,7 +387,30 @@ def test_all_classes_from_file_runner_1ep(self, env: Optional[Environment]=None)
382387
env_seeds=[0],
383388
episode_id=[0])
384389

385-
def test_all_classes_from_file_runner_2ep_seq(self, env: Optional[Environment]=None):
390+
def _aux_test_rewards(self, res, _mix_id):
391+
if issubclass(AutoClassInFileTester, type(self)):
392+
ref = [645.702087, 648.907958]
393+
elif issubclass(TOEnvAutoClassTester, type(self)):
394+
ref = [645.702087, 648.907958]
395+
elif issubclass(MaskedEnvAutoClassTester, type(self)):
396+
ref = [645.702087, 648.907958]
397+
elif issubclass(MultiMixEnvAutoClassTester, type(self)):
398+
if _mix_id == 0:
399+
ref = [119.103179, 112.700851]
400+
elif _mix_id == 1:
401+
ref = [119.113189, 112.687309]
402+
else:
403+
raise RuntimeError("Unknown mix id")
404+
elif issubclass(ForEnvAutoClassTester, type(self)):
405+
ref = [53.9044303, 53.9044303]
406+
else:
407+
raise RuntimeError("Unknown test suite")
408+
assert abs(res[0][2] - ref[0]) <= 1e-5, f"{res[0][2]} vs {ref[0]}"
409+
assert abs(res[1][2] - ref[1]) <= 1e-5, f"{res[1][2]} vs {ref[1]}"
410+
411+
def test_all_classes_from_file_runner_2ep_seq(self,
412+
env: Optional[Environment]=None,
413+
_mix_id=None):
386414
"""this test that the runner is able to "run" (one other type of run), but the tests on the classes
387415
are much lighter than in test_all_classes_from_file_env_runner"""
388416
if not self._do_test_runner():
@@ -396,14 +424,19 @@ def test_all_classes_from_file_runner_2ep_seq(self, env: Optional[Environment]=N
396424
runner = Runner(**env.get_params_for_runner(),
397425
agentClass=None,
398426
agentInstance=this_agent)
427+
399428
res = runner.run(nb_episode=2,
400429
max_iter=self.max_iter,
430+
agent_seeds=[2, 3],
401431
env_seeds=[0, 0],
402432
episode_id=[0, 1])
403433
assert res[0][4] == self.max_iter
404434
assert res[1][4] == self.max_iter
405-
406-
def test_all_classes_from_file_runner_2ep_par_fork(self, env: Optional[Environment]=None):
435+
self._aux_test_rewards(res, _mix_id)
436+
437+
def test_all_classes_from_file_runner_2ep_par_fork(self,
438+
env: Optional[Environment]=None,
439+
_mix_id=None):
407440
"""this test that the runner is able to "run" (one other type of run), but the tests on the classes
408441
are much lighter than in test_all_classes_from_file_env_runner"""
409442
if not self._do_test_runner():
@@ -428,8 +461,11 @@ def test_all_classes_from_file_runner_2ep_par_fork(self, env: Optional[Environme
428461
episode_id=[0, 1])
429462
assert res[0][4] == self.max_iter
430463
assert res[1][4] == self.max_iter
464+
self._aux_test_rewards(res, _mix_id)
431465

432-
def test_all_classes_from_file_runner_2ep_par_spawn(self, env: Optional[Environment]=None):
466+
def test_all_classes_from_file_runner_2ep_par_spawn(self,
467+
env: Optional[Environment]=None,
468+
_mix_id=None):
433469
"""this test that the runner is able to "run" (one other type of run), but the tests on the classes
434470
are much lighter than in test_all_classes_from_file_env_runner"""
435471
if not self._do_test_runner():
@@ -452,6 +488,7 @@ def test_all_classes_from_file_runner_2ep_par_spawn(self, env: Optional[Environm
452488
episode_id=[0, 1])
453489
assert res[0][4] == self.max_iter
454490
assert res[1][4] == self.max_iter
491+
self._aux_test_rewards(res, _mix_id)
455492

456493

457494
class MaskedEnvAutoClassTester(AutoClassInFileTester):
@@ -656,7 +693,8 @@ def test_all_classes_from_file(self,
656693
)
657694
if isinstance(env, MultiMixEnvironment):
658695
# test each mix of a multi mix
659-
for mix in env:
696+
for mix_name in sorted(env.all_names):
697+
mix = env[mix_name]
660698
super().test_all_classes_from_file(mix,
661699
classes_name=classes_name,
662700
name_complete_obs_cls=name_complete_obs_cls,
@@ -675,7 +713,8 @@ def test_all_classes_from_file_env_after_reset(self, env: Optional[Environment]=
675713
super().test_all_classes_from_file_env_after_reset(env)
676714
if isinstance(env, MultiMixEnvironment):
677715
# test each mix of a multimix
678-
for mix in env:
716+
for mix_name in sorted(env.all_names):
717+
mix = env[mix_name]
679718
super().test_all_classes_from_file_env_after_reset(mix)
680719
finally:
681720
if env_orig is None:
@@ -689,7 +728,8 @@ def test_all_classes_from_file_obsenv(self, env: Optional[Environment]=None):
689728
super().test_all_classes_from_file_obsenv(env)
690729
if isinstance(env, MultiMixEnvironment):
691730
# test each mix of a multimix
692-
for mix in env:
731+
for mix_name in sorted(env.all_names):
732+
mix = env[mix_name]
693733
super().test_all_classes_from_file_obsenv(mix)
694734
finally:
695735
if env_orig is None:
@@ -703,7 +743,8 @@ def test_all_classes_from_file_env_cpy(self, env: Optional[Environment]=None):
703743
super().test_all_classes_from_file_env_cpy(env)
704744
if isinstance(env, MultiMixEnvironment):
705745
# test each mix of a multimix
706-
for mix in env:
746+
for mix_name in sorted(env.all_names):
747+
mix = env[mix_name]
707748
super().test_all_classes_from_file_env_cpy(mix)
708749
finally:
709750
if env_orig is None:
@@ -716,7 +757,8 @@ def test_all_classes_from_file_env_runner(self, env: Optional[Environment]=None)
716757
try:
717758
if isinstance(env, MultiMixEnvironment):
718759
# test each mix of a multimix
719-
for mix in env:
760+
for mix_name in sorted(env.all_names):
761+
mix = env[mix_name]
720762
super().test_all_classes_from_file_env_runner(mix)
721763
else:
722764
# runner does not handle multimix
@@ -732,7 +774,8 @@ def test_all_classes_from_file_runner_1ep(self, env: Optional[Environment]=None)
732774
try:
733775
if isinstance(env, MultiMixEnvironment):
734776
# test each mix of a multimix
735-
for mix in env:
777+
for mix_name in sorted(env.all_names):
778+
mix = env[mix_name]
736779
super().test_all_classes_from_file_runner_1ep(mix)
737780
else:
738781
# runner does not handle multimix
@@ -748,8 +791,9 @@ def test_all_classes_from_file_runner_2ep_seq(self, env: Optional[Environment]=N
748791
try:
749792
if isinstance(env, MultiMixEnvironment):
750793
# test each mix of a multimix
751-
for mix in env:
752-
super().test_all_classes_from_file_runner_2ep_seq(mix)
794+
for mix_id, mix_name in enumerate(sorted(env.all_names)):
795+
mix = env[mix_name]
796+
super().test_all_classes_from_file_runner_2ep_seq(mix, mix_id)
753797
else:
754798
# runner does not handle multimix
755799
super().test_all_classes_from_file_runner_2ep_seq(env)
@@ -766,8 +810,9 @@ def test_all_classes_from_file_runner_2ep_par_fork(self, env: Optional[Environme
766810
try:
767811
if isinstance(env, MultiMixEnvironment):
768812
# test each mix of a multimix
769-
for mix in env:
770-
super().test_all_classes_from_file_runner_2ep_par_fork(mix)
813+
for mix_id, mix_name in enumerate(sorted(env.all_names)):
814+
mix = env[mix_name]
815+
super().test_all_classes_from_file_runner_2ep_par_fork(mix, mix_id)
771816
else:
772817
# runner does not handle multimix
773818
super().test_all_classes_from_file_runner_2ep_par_fork(env)
@@ -782,8 +827,9 @@ def test_all_classes_from_file_runner_2ep_par_spawn(self, env: Optional[Environm
782827
try:
783828
if isinstance(env, MultiMixEnvironment):
784829
# test each mix of a multimix
785-
for mix in env:
786-
super().test_all_classes_from_file_runner_2ep_par_spawn(mix)
830+
for mix_id, mix_name in enumerate(sorted(env.all_names)):
831+
mix = env[mix_name]
832+
super().test_all_classes_from_file_runner_2ep_par_spawn(mix, mix_id)
787833
else:
788834
# runner does not handle multimix
789835
super().test_all_classes_from_file_runner_2ep_par_spawn(env)
@@ -798,8 +844,9 @@ def test_forecast_env_basic(self, env: Optional[Environment]=None):
798844
try:
799845
if isinstance(env, MultiMixEnvironment):
800846
# test each mix of a multimix
801-
for mix in env:
802-
obs = mix.reset()
847+
for mix_id, mix_name in enumerate(sorted(env.all_names)):
848+
mix = env[mix_name]
849+
obs = mix.reset(seed=0, options={"time serie id": 0})
803850
for_env = obs.get_forecast_env()
804851
super().test_all_classes_from_file(for_env)
805852
finally:

0 commit comments

Comments
 (0)