66# SPDX-License-Identifier: MPL-2.0
77# This file is part of Grid2Op, Grid2Op a testbed platform to model sequential decision making in power systems.
88
9- from typing import Optional , Type
9+ from typing import Optional , Type , Any
1010import numpy as np
1111from grid2op .Space import GridObjects
1212import grid2op .Backend
13+ from grid2op .typing_variables import CLS_AS_DICT_TYPING
1314from grid2op .Exceptions import Grid2OpException
1415
1516
@@ -26,7 +27,9 @@ def __init__(self,
2627 init_shunt_q : np .ndarray ,
2728 init_shunt_bus : np .ndarray ):
2829 self ._can_modif = True
29- self ._grid_obj_cls : Type [GridObjects ] = grid_obj_cls
30+ self ._grid_obj_cls : CLS_AS_DICT_TYPING = grid_obj_cls .cls_to_dict ()
31+ self ._n_storage = len (self ._grid_obj_cls ["name_storage" ]) # to avoid typing that over and over again
32+
3033 self ._load_p : np .ndarray = 1. * init_load_p
3134 self ._load_q : np .ndarray = 1. * init_load_q
3235 self ._gen_p : np .ndarray = 1. * init_gen_p
@@ -50,21 +53,21 @@ def update(self,
5053 if not self ._can_modif :
5154 raise Grid2OpException (f"Impossible to modifiy this _EnvPreviousState" )
5255
53- self ._aux_update (topo_vect [self ._grid_obj_cls . load_pos_topo_vect ],
56+ self ._aux_update (topo_vect [self ._grid_obj_cls [ " load_pos_topo_vect" ] ],
5457 self ._load_p ,
5558 load_p ,
5659 self ._load_q ,
5760 load_q )
58- self ._aux_update (topo_vect [self ._grid_obj_cls . gen_pos_topo_vect ],
61+ self ._aux_update (topo_vect [self ._grid_obj_cls [ " gen_pos_topo_vect" ] ],
5962 self ._gen_p ,
6063 gen_p ,
6164 self ._gen_v ,
6265 gen_v )
6366 self ._topo_vect [topo_vect > 0 ] = 1 * topo_vect [topo_vect > 0 ]
6467
6568 # update storage units
66- if self ._grid_obj_cls . n_storage > 0 :
67- self ._aux_update (topo_vect [self ._grid_obj_cls . storage_pos_topo_vect ],
69+ if self ._n_storage > 0 :
70+ self ._aux_update (topo_vect [self ._grid_obj_cls [ " storage_pos_topo_vect" ] ],
6871 self ._storage_p ,
6972 storage_p )
7073
@@ -84,7 +87,7 @@ def update_from_backend(self,
8487 topo_vect = backend .get_topo_vect ()
8588 load_p , load_q , * _ = backend .loads_info ()
8689 gen_p , gen_q , gen_v = backend .generators_info ()
87- if self ._grid_obj_cls . n_storage > 0 :
90+ if self ._n_storage > 0 :
8891 storage_p , * _ = backend .storages_info ()
8992 else :
9093 storage_p = None
@@ -109,7 +112,12 @@ def update_from_other(self,
109112 "_shunt_p" ,
110113 "_shunt_q" ,
111114 "_shunt_bus" ]:
112- getattr (self , attr_nm )[:] = getattr (other , attr_nm )
115+ tmp = getattr (self , attr_nm )
116+ if tmp .size > 1 :
117+ # works only for array of size 2 or more
118+ tmp [:] = getattr (other , attr_nm )
119+ else :
120+ setattr (self , attr_nm , getattr (other , attr_nm ))
113121
114122 def prevent_modification (self ):
115123 for attr_nm in ["_load_p" ,
@@ -121,7 +129,10 @@ def prevent_modification(self):
121129 "_shunt_p" ,
122130 "_shunt_q" ,
123131 "_shunt_bus" ]:
124- getattr (self , attr_nm ).flags .writeable = False
132+ tmp = getattr (self , attr_nm )
133+ if tmp .size > 1 :
134+ # can't set flags on array of size 1 apparently
135+ tmp .flags .writeable = False
125136 self ._can_modif = False
126137
127138 def _aux_update (self ,
0 commit comments