@@ -109,12 +109,19 @@ def update_component_data(component: str, input_data: np.ndarray, update_data: n
109109 mask = ~ np .isnan (update_data [field ])
110110 else :
111111 mask = np .not_equal (update_data [field ], nan )
112+
112113 if mask .ndim == 2 :
113- mask = np .any (mask , axis = 1 )
114- data = update_data [["id" , field ]][mask ]
115- idx = np .where (input_data ["id" ] == np .reshape (data ["id" ], (- 1 , 1 )))
116- if isinstance (idx , tuple ):
117- input_data [field ][idx [1 ]] = data [field ]
114+ for phase in range (mask .shape [1 ]):
115+ # find indexers of to-be-updated object
116+ sub_mask = mask [:, phase ]
117+ idx = get_indexer (input_data ["id" ], update_data ["id" ][sub_mask ])
118+ # update
119+ input_data [field ][idx , phase ] = update_data [field ][sub_mask , phase ]
120+ else :
121+ # find indexers of to-be-updated object
122+ idx = get_indexer (input_data ["id" ], update_data ["id" ][mask ])
123+ # update
124+ input_data [field ][idx ] = update_data [field ][mask ]
118125
119126
120127def errors_to_string (
@@ -158,3 +165,25 @@ def nan_type(component: str, field: str, data_type="input"):
158165 It silently returns float('nan') if data_type/component/field can't be found.
159166 """
160167 return power_grid_meta_data .get (data_type , {}).get (component , {}).get ("nans" , {}).get (field , float ("nan" ))
168+
169+
def get_indexer(input_ids: np.ndarray, update_ids: np.ndarray) -> np.ndarray:
    """
    Find the position of each update id within the input ids.

    This is needed to update values in the dataset by id lookup.
    Internally the input ids are sorted indirectly (argsort) and every update id
    is located with a binary search. Because a binary search returns an insertion
    point rather than a match, each hit is verified afterwards, so an id that is
    absent from the input is reported instead of being mapped to a neighbour.

    Args:
        input_ids: array of ids in the input dataset
        update_ids: array of ids in the update dataset

    Returns:
        np.ndarray: array of positions of the ids from update dataset in the input dataset
            the following holds
            input_ids[result] == update_ids

    Raises:
        IndexError: if one or more update ids do not occur in input_ids
    """
    input_ids = np.asarray(input_ids)
    update_ids = np.asarray(update_ids)

    permutation_sort = np.argsort(input_ids)  # complexity O(N_input * logN_input)
    positions = np.searchsorted(
        input_ids, update_ids, sorter=permutation_sort
    )  # complexity O(N_update * logN_input)

    if input_ids.size > 0:
        # searchsorted yields len(input_ids) for ids above the maximum; clamp so the
        # lookup stays in bounds and let the equality check below flag those ids.
        indexer = permutation_sort[np.minimum(positions, input_ids.size - 1)]
        missing = input_ids[indexer] != update_ids
    else:
        # nothing can be found in an empty input; every requested id is missing
        indexer = positions
        missing = np.ones(np.shape(update_ids), dtype=bool)

    if np.any(missing):
        raise IndexError(f"ids not found in input dataset: {update_ids[missing].tolist()}")
    return indexer
0 commit comments