diff --git a/src/power_grid_model_ds/_core/load_flow.py b/src/power_grid_model_ds/_core/load_flow.py index 97b3cc3..7e26a69 100644 --- a/src/power_grid_model_ds/_core/load_flow.py +++ b/src/power_grid_model_ds/_core/load_flow.py @@ -107,7 +107,7 @@ def calculate_power_flow( Returns output of the power flow calculation (also stored in self.output_data) """ - self.model = self.model or self._setup_model() + self.model = self.model or self.setup_model() self.output_data = self.model.calculate_power_flow( calculation_method=calculation_method, update_data=update_data, **kwargs @@ -146,7 +146,7 @@ def update_model(self, update_data: Dict): """ - self.model = self.model or self._setup_model() + self.model = self.model or self.setup_model() self.model.update(update_data=update_data) def update_grid(self) -> None: @@ -166,7 +166,8 @@ def update_grid(self) -> None: def _match_dtypes(first_dtype: np.dtype, second_dtype: np.dtype): return list(set(first_dtype.names).intersection(set(second_dtype.names))) # type: ignore[arg-type] - def _setup_model(self): + def setup_model(self): + """Setup the PowerGridModel with the input data.""" self._input_data = self._input_data or self.create_input_from_grid() self.model = PowerGridModel(self._input_data, system_frequency=self.system_frequency) return self.model diff --git a/tests/integration/loadflow/test_power_grid_model.py b/tests/integration/loadflow/test_power_grid_model.py index 2a9e91d..9529e6d 100644 --- a/tests/integration/loadflow/test_power_grid_model.py +++ b/tests/integration/loadflow/test_power_grid_model.py @@ -176,6 +176,18 @@ def test_batch_run(self): # Results have been calculated for all 10 scenarios assert 10 == len(output["line"]) + def test_setup_model(self): + """Test whether a pgm model can be setup with a custom grid""" + grid_generator = RadialGridGenerator(grid_class=CustomGrid, nr_nodes=5, nr_sources=1, nr_nops=0) + grid = grid_generator.run(seed=0) + + core_interface = PowerGridModelInterface(grid=grid) + assert core_interface.model is None + assert core_interface._input_data is None + core_interface.setup_model() + assert core_interface.model + assert core_interface._input_data + class TestCreateGridFromInputData: def test_create_grid_from_input_data(self, input_data_pgm):