88
99import os
1010import multiprocessing as mp
11- from typing import Optional
11+ from typing import Optional , Type
1212import warnings
1313import unittest
1414import importlib
1515import numpy as np
16- from gymnasium .vector import AsyncVectorEnv
16+ from grid2op .Space import GridObjects
17+ from gymnasium .vector import AsyncVectorEnv # type: ignore
1718
1819import grid2op
1920from grid2op ._glop_platform_info import _IS_WINDOWS
@@ -75,7 +76,7 @@ def act(self, observation: BaseObservation, reward: float, done: bool = False) -
7576 this_class_act = getattr (this_module , self ._name_cls_act )
7677 else :
7778 raise RuntimeError (f"class { self ._name_cls_act } not found" )
78- res = super (). act ( observation , reward , done )
79+ res = self . action_space ( )
7980 assert isinstance (res , this_class_act )
8081 return res
8182
@@ -85,7 +86,9 @@ class AutoClassMakeTester(unittest.TestCase):
8586 def test_in_make (self ):
8687 with warnings .catch_warnings ():
8788 warnings .filterwarnings ("ignore" )
88- env = grid2op .make ("l2rpn_case14_sandbox" , test = True , class_in_file = False )
89+ env = grid2op .make ("l2rpn_case14_sandbox" ,
90+ test = True ,
91+ class_in_file = False )
8992 assert env ._read_from_local_dir is None
9093 assert not env .classes_are_in_files ()
9194
@@ -117,7 +120,9 @@ def _aux_make_env(self, env: Optional[Environment]=None):
117120 if env is None :
118121 with warnings .catch_warnings ():
119122 warnings .filterwarnings ("ignore" )
120- env = grid2op .make (self .get_env_name (), test = True , class_in_file = True )
123+ env = grid2op .make (self .get_env_name (),
124+ test = True ,
125+ class_in_file = True )
121126 assert env .classes_are_in_files ()
122127 return env
123128
@@ -176,7 +181,7 @@ def test_all_classes_from_file(self,
176181 for name_cls , name_attr in zip (names_cls , names_attr ):
177182 this_module = importlib .import_module (f"{ module_nm } .{ name_cls } _file" , super_module )
178183 if hasattr (this_module , name_cls ):
179- this_class = getattr (this_module , name_cls )
184+ this_class : Type [ GridObjects ] = getattr (this_module , name_cls )
180185 else :
181186 raise RuntimeError (f"class { name_cls } not found" )
182187 if name_attr is not None :
@@ -382,7 +387,30 @@ def test_all_classes_from_file_runner_1ep(self, env: Optional[Environment]=None)
382387 env_seeds = [0 ],
383388 episode_id = [0 ])
384389
385- def test_all_classes_from_file_runner_2ep_seq (self , env : Optional [Environment ]= None ):
390+ def _aux_test_rewards (self , res , _mix_id ):
391+ if issubclass (AutoClassInFileTester , type (self )):
392+ ref = [645.702087 , 648.907958 ]
393+ elif issubclass (TOEnvAutoClassTester , type (self )):
394+ ref = [645.702087 , 648.907958 ]
395+ elif issubclass (MaskedEnvAutoClassTester , type (self )):
396+ ref = [645.702087 , 648.907958 ]
397+ elif issubclass (MultiMixEnvAutoClassTester , type (self )):
398+ if _mix_id == 0 :
399+ ref = [119.103179 , 112.700851 ]
400+ elif _mix_id == 1 :
401+ ref = [119.113189 , 112.687309 ]
402+ else :
403+ raise RuntimeError ("Unknown mix id" )
404+ elif issubclass (ForEnvAutoClassTester , type (self )):
405+ ref = [53.9044303 , 53.9044303 ]
406+ else :
407+ raise RuntimeError ("Unknown test suite" )
408+ assert abs (res [0 ][2 ] - ref [0 ]) <= 1e-5 , f"{ res [0 ][2 ]} vs { ref [0 ]} "
409+ assert abs (res [1 ][2 ] - ref [1 ]) <= 1e-5 , f"{ res [1 ][2 ]} vs { ref [1 ]} "
410+
411+ def test_all_classes_from_file_runner_2ep_seq (self ,
412+ env : Optional [Environment ]= None ,
413+ _mix_id = None ):
386414 """this test that the runner is able to "run" (one other type of run), but the tests on the classes
387415 are much lighter than in test_all_classes_from_file_env_runner"""
388416 if not self ._do_test_runner ():
@@ -396,14 +424,19 @@ def test_all_classes_from_file_runner_2ep_seq(self, env: Optional[Environment]=N
396424 runner = Runner (** env .get_params_for_runner (),
397425 agentClass = None ,
398426 agentInstance = this_agent )
427+
399428 res = runner .run (nb_episode = 2 ,
400429 max_iter = self .max_iter ,
430+ agent_seeds = [2 , 3 ],
401431 env_seeds = [0 , 0 ],
402432 episode_id = [0 , 1 ])
403433 assert res [0 ][4 ] == self .max_iter
404434 assert res [1 ][4 ] == self .max_iter
405-
406- def test_all_classes_from_file_runner_2ep_par_fork (self , env : Optional [Environment ]= None ):
435+ self ._aux_test_rewards (res , _mix_id )
436+
437+ def test_all_classes_from_file_runner_2ep_par_fork (self ,
438+ env : Optional [Environment ]= None ,
439+ _mix_id = None ):
407440 """this test that the runner is able to "run" (one other type of run), but the tests on the classes
408441 are much lighter than in test_all_classes_from_file_env_runner"""
409442 if not self ._do_test_runner ():
@@ -428,8 +461,11 @@ def test_all_classes_from_file_runner_2ep_par_fork(self, env: Optional[Environme
428461 episode_id = [0 , 1 ])
429462 assert res [0 ][4 ] == self .max_iter
430463 assert res [1 ][4 ] == self .max_iter
464+ self ._aux_test_rewards (res , _mix_id )
431465
432- def test_all_classes_from_file_runner_2ep_par_spawn (self , env : Optional [Environment ]= None ):
466+ def test_all_classes_from_file_runner_2ep_par_spawn (self ,
467+ env : Optional [Environment ]= None ,
468+ _mix_id = None ):
433469 """this test that the runner is able to "run" (one other type of run), but the tests on the classes
434470 are much lighter than in test_all_classes_from_file_env_runner"""
435471 if not self ._do_test_runner ():
@@ -452,6 +488,7 @@ def test_all_classes_from_file_runner_2ep_par_spawn(self, env: Optional[Environm
452488 episode_id = [0 , 1 ])
453489 assert res [0 ][4 ] == self .max_iter
454490 assert res [1 ][4 ] == self .max_iter
491+ self ._aux_test_rewards (res , _mix_id )
455492
456493
457494class MaskedEnvAutoClassTester (AutoClassInFileTester ):
@@ -656,7 +693,8 @@ def test_all_classes_from_file(self,
656693 )
657694 if isinstance (env , MultiMixEnvironment ):
658695 # test each mix of a multi mix
659- for mix in env :
696+ for mix_name in sorted (env .all_names ):
697+ mix = env [mix_name ]
660698 super ().test_all_classes_from_file (mix ,
661699 classes_name = classes_name ,
662700 name_complete_obs_cls = name_complete_obs_cls ,
@@ -675,7 +713,8 @@ def test_all_classes_from_file_env_after_reset(self, env: Optional[Environment]=
675713 super ().test_all_classes_from_file_env_after_reset (env )
676714 if isinstance (env , MultiMixEnvironment ):
677715 # test each mix of a multimix
678- for mix in env :
716+ for mix_name in sorted (env .all_names ):
717+ mix = env [mix_name ]
679718 super ().test_all_classes_from_file_env_after_reset (mix )
680719 finally :
681720 if env_orig is None :
@@ -689,7 +728,8 @@ def test_all_classes_from_file_obsenv(self, env: Optional[Environment]=None):
689728 super ().test_all_classes_from_file_obsenv (env )
690729 if isinstance (env , MultiMixEnvironment ):
691730 # test each mix of a multimix
692- for mix in env :
731+ for mix_name in sorted (env .all_names ):
732+ mix = env [mix_name ]
693733 super ().test_all_classes_from_file_obsenv (mix )
694734 finally :
695735 if env_orig is None :
@@ -703,7 +743,8 @@ def test_all_classes_from_file_env_cpy(self, env: Optional[Environment]=None):
703743 super ().test_all_classes_from_file_env_cpy (env )
704744 if isinstance (env , MultiMixEnvironment ):
705745 # test each mix of a multimix
706- for mix in env :
746+ for mix_name in sorted (env .all_names ):
747+ mix = env [mix_name ]
707748 super ().test_all_classes_from_file_env_cpy (mix )
708749 finally :
709750 if env_orig is None :
@@ -716,7 +757,8 @@ def test_all_classes_from_file_env_runner(self, env: Optional[Environment]=None)
716757 try :
717758 if isinstance (env , MultiMixEnvironment ):
718759 # test each mix of a multimix
719- for mix in env :
760+ for mix_name in sorted (env .all_names ):
761+ mix = env [mix_name ]
720762 super ().test_all_classes_from_file_env_runner (mix )
721763 else :
722764 # runner does not handle multimix
@@ -732,7 +774,8 @@ def test_all_classes_from_file_runner_1ep(self, env: Optional[Environment]=None)
732774 try :
733775 if isinstance (env , MultiMixEnvironment ):
734776 # test each mix of a multimix
735- for mix in env :
777+ for mix_name in sorted (env .all_names ):
778+ mix = env [mix_name ]
736779 super ().test_all_classes_from_file_runner_1ep (mix )
737780 else :
738781 # runner does not handle multimix
@@ -748,8 +791,9 @@ def test_all_classes_from_file_runner_2ep_seq(self, env: Optional[Environment]=N
748791 try :
749792 if isinstance (env , MultiMixEnvironment ):
750793 # test each mix of a multimix
751- for mix in env :
752- super ().test_all_classes_from_file_runner_2ep_seq (mix )
794+ for mix_id , mix_name in enumerate (sorted (env .all_names )):
795+ mix = env [mix_name ]
796+ super ().test_all_classes_from_file_runner_2ep_seq (mix , mix_id )
753797 else :
754798 # runner does not handle multimix
755799 super ().test_all_classes_from_file_runner_2ep_seq (env )
@@ -766,8 +810,9 @@ def test_all_classes_from_file_runner_2ep_par_fork(self, env: Optional[Environme
766810 try :
767811 if isinstance (env , MultiMixEnvironment ):
768812 # test each mix of a multimix
769- for mix in env :
770- super ().test_all_classes_from_file_runner_2ep_par_fork (mix )
813+ for mix_id , mix_name in enumerate (sorted (env .all_names )):
814+ mix = env [mix_name ]
815+ super ().test_all_classes_from_file_runner_2ep_par_fork (mix , mix_id )
771816 else :
772817 # runner does not handle multimix
773818 super ().test_all_classes_from_file_runner_2ep_par_fork (env )
@@ -782,8 +827,9 @@ def test_all_classes_from_file_runner_2ep_par_spawn(self, env: Optional[Environm
782827 try :
783828 if isinstance (env , MultiMixEnvironment ):
784829 # test each mix of a multimix
785- for mix in env :
786- super ().test_all_classes_from_file_runner_2ep_par_spawn (mix )
830+ for mix_id , mix_name in enumerate (sorted (env .all_names )):
831+ mix = env [mix_name ]
832+ super ().test_all_classes_from_file_runner_2ep_par_spawn (mix , mix_id )
787833 else :
788834 # runner does not handle multimix
789835 super ().test_all_classes_from_file_runner_2ep_par_spawn (env )
@@ -798,8 +844,9 @@ def test_forecast_env_basic(self, env: Optional[Environment]=None):
798844 try :
799845 if isinstance (env , MultiMixEnvironment ):
800846 # test each mix of a multimix
801- for mix in env :
802- obs = mix .reset ()
847+ for mix_id , mix_name in enumerate (sorted (env .all_names )):
848+ mix = env [mix_name ]
849+ obs = mix .reset (seed = 0 , options = {"time serie id" : 0 })
803850 for_env = obs .get_forecast_env ()
804851 super ().test_all_classes_from_file (for_env )
805852 finally :
0 commit comments