Skip to content

Commit c1f487c

Browse files
committed
fixing broken tests
Signed-off-by: DONNOT Benjamin <[email protected]>
1 parent 04293b3 commit c1f487c

File tree

4 files changed

+39
-26
lines changed

4 files changed

+39
-26
lines changed

grid2op/Environment/_env_prev_state.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
# SPDX-License-Identifier: MPL-2.0
77
# This file is part of Grid2Op, Grid2Op a testbed platform to model sequential decision making in power systems.
88

9-
from typing import Optional, Type
9+
from typing import Optional, Type, Any
1010
import numpy as np
1111
from grid2op.Space import GridObjects
1212
import grid2op.Backend
13+
from grid2op.typing_variables import CLS_AS_DICT_TYPING
1314
from grid2op.Exceptions import Grid2OpException
1415

1516

@@ -26,7 +27,9 @@ def __init__(self,
2627
init_shunt_q : np.ndarray,
2728
init_shunt_bus : np.ndarray):
2829
self._can_modif = True
29-
self._grid_obj_cls : Type[GridObjects] = grid_obj_cls
30+
self._grid_obj_cls : CLS_AS_DICT_TYPING = grid_obj_cls.cls_to_dict()
31+
self._n_storage = len(self._grid_obj_cls["name_storage"]) # to avoid typing that over and over again
32+
3033
self._load_p : np.ndarray = 1. * init_load_p
3134
self._load_q : np.ndarray = 1. * init_load_q
3235
self._gen_p : np.ndarray = 1. * init_gen_p
@@ -50,21 +53,21 @@ def update(self,
5053
if not self._can_modif:
5154
raise Grid2OpException(f"Impossible to modifiy this _EnvPreviousState")
5255

53-
self._aux_update(topo_vect[self._grid_obj_cls.load_pos_topo_vect],
56+
self._aux_update(topo_vect[self._grid_obj_cls["load_pos_topo_vect"]],
5457
self._load_p,
5558
load_p,
5659
self._load_q,
5760
load_q)
58-
self._aux_update(topo_vect[self._grid_obj_cls.gen_pos_topo_vect],
61+
self._aux_update(topo_vect[self._grid_obj_cls["gen_pos_topo_vect"]],
5962
self._gen_p,
6063
gen_p,
6164
self._gen_v,
6265
gen_v)
6366
self._topo_vect[topo_vect > 0] = 1 * topo_vect[topo_vect > 0]
6467

6568
# update storage units
66-
if self._grid_obj_cls.n_storage > 0:
67-
self._aux_update(topo_vect[self._grid_obj_cls.storage_pos_topo_vect],
69+
if self._n_storage > 0:
70+
self._aux_update(topo_vect[self._grid_obj_cls["storage_pos_topo_vect"]],
6871
self._storage_p,
6972
storage_p)
7073

@@ -84,7 +87,7 @@ def update_from_backend(self,
8487
topo_vect = backend.get_topo_vect()
8588
load_p, load_q, *_ = backend.loads_info()
8689
gen_p, gen_q, gen_v = backend.generators_info()
87-
if self._grid_obj_cls.n_storage > 0:
90+
if self._n_storage > 0:
8891
storage_p, *_ = backend.storages_info()
8992
else:
9093
storage_p = None
@@ -109,7 +112,12 @@ def update_from_other(self,
109112
"_shunt_p",
110113
"_shunt_q",
111114
"_shunt_bus"]:
112-
getattr(self, attr_nm)[:] = getattr(other, attr_nm)
115+
tmp = getattr(self, attr_nm)
116+
if tmp.size > 1:
117+
# works only for array of size 2 or more
118+
tmp[:] = getattr(other, attr_nm)
119+
else:
120+
setattr(self, attr_nm, getattr(other, attr_nm))
113121

114122
def prevent_modification(self):
115123
for attr_nm in ["_load_p",
@@ -121,7 +129,10 @@ def prevent_modification(self):
121129
"_shunt_p",
122130
"_shunt_q",
123131
"_shunt_bus"]:
124-
getattr(self, attr_nm).flags.writeable = False
132+
tmp = getattr(self, attr_nm)
133+
if tmp.size > 1:
134+
# can't set flags on array of size 1 apparently
135+
tmp.flags.writeable = False
125136
self._can_modif = False
126137

127138
def _aux_update(self,

grid2op/Space/GridObjects.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4159,7 +4159,7 @@ def _make_cls_dict_extended(cls, res: CLS_AS_DICT_TYPING, as_list=True, copy_=Tr
41594159
cls._CLS_DICT_EXTENDED = res.copy()
41604160

41614161
@classmethod
4162-
def cls_to_dict(cls):
4162+
def cls_to_dict(cls) -> CLS_AS_DICT_TYPING:
41634163
"""
41644164
INTERNAL
41654165

grid2op/tests/test_simulate_disco_load.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_backend_action(self):
5656
obs._obs_env._backend_action_set += self.env.action_space({"set_bus": {"loads_id": [(l_id, -1)]}})
5757
assert obs._obs_env._backend_action_set.current_topo.values[l_pos] == -1
5858
tmp = obs._obs_env._backend_action_set() # do as if the action has been processed
59-
assert not obs._obs_env._backend_action_set.load_p.changed[l_id] # it's disconnected, so marked as unchanged now
59+
assert obs._obs_env._backend_action_set.load_p.changed[l_id] # it is not changed because disconnected
6060
assert np.allclose(obs._obs_env._backend_action_set.load_p.values[l_id], 22.3), f"{obs._obs_env._backend_action_set.load_p.values[l_id]:.2f} vs 22.3"
6161

6262

grid2op/tests/test_ts_handlers.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from grid2op.tests.helper_path_test import *
1414

1515
import grid2op
16+
from grid2op.Agent import BaseAgent
1617
from grid2op.Exceptions import NoForecastAvailable
1718
from grid2op.Chronics import GridStateFromFileWithForecasts, GridStateFromFile, GridStateFromFileWithForecastsWithoutMaintenance, FromHandlers
1819
from grid2op.Chronics.handlers import (CSVHandler,
@@ -34,6 +35,20 @@
3435
# TODO check when there is also redispatching
3536

3637

38+
class TestAgent(BaseAgent):
39+
def __init__(self, action_space, tester):
40+
super().__init__(action_space)
41+
self.tester = tester
42+
def act(self, obs, reward, done=False):
43+
# size of the forecast is always 12 even if it's "after" the size of the episode
44+
self.tester._aux_test_obs(obs, max_it=12)
45+
_ = self.tester.env.step(self.action_space()) # for TestPerfectForecastHandler: self.tester.env should be synch with the runner env...
46+
return self.action_space()
47+
def reset(self, obs):
48+
self.tester.env.reset() # for TestPerfectForecastHandler
49+
return super().reset(obs)
50+
51+
3752
def _load_next_chunk_in_memory_hack(self):
3853
self._nb_call += 1
3954
# i load the next chunk as dataframes
@@ -512,7 +527,7 @@ def tearDown(self) -> None:
512527
return super().tearDown()
513528

514529
def _aux_test_obs(self, obs, max_it=12, tol=1e-5):
515-
assert len(obs._forecasted_inj) == max_it + 1 # 12 + 1
530+
assert len(obs._forecasted_inj) == max_it + 1, f"{len(obs._forecasted_inj)} vs {max_it + 1}"
516531
init_obj = obs._forecasted_inj[0]
517532
for for_h, el in enumerate(obs._forecasted_inj):
518533
for k_ in ["load_p", "load_q"]:
@@ -553,20 +568,7 @@ def test_copy(self):
553568
for k_ in ["load_p", "load_q", "prod_p"]:
554569
assert np.all(el[1]["injection"][k_] == el_cpy[1]["injection"][k_])
555570

556-
def test_runner(self):
557-
from grid2op.Agent import BaseAgent
558-
class TestAgent(BaseAgent):
559-
def __init__(self, action_space, tester):
560-
super().__init__(action_space)
561-
self.tester = tester
562-
def act(self, obs, reward, done=False):
563-
self.tester._aux_test_obs(obs, max_it=5 - obs.current_step)
564-
_ = self.tester.env.step(self.action_space()) # for TestPerfectForecastHandler: self.tester.env should be synch with the runner env...
565-
return self.action_space()
566-
def reset(self, obs):
567-
self.tester.env.reset() # for TestPerfectForecastHandler
568-
return super().reset(obs)
569-
571+
def test_runner(self):
570572
testagent = TestAgent(self.env.action_space, self)
571573
self.env.set_id(0) # for TestPerfectForecastHandler
572574
runner = Runner(**self.env.get_params_for_runner(), agentClass=None, agentInstance=testagent)

0 commit comments

Comments
 (0)