diff --git a/src/power_grid_model_ds/_core/data_source/generator/grid_generators.py b/src/power_grid_model_ds/_core/data_source/generator/grid_generators.py index 2dab4a2..6493325 100644 --- a/src/power_grid_model_ds/_core/data_source/generator/grid_generators.py +++ b/src/power_grid_model_ds/_core/data_source/generator/grid_generators.py @@ -4,7 +4,7 @@ """Generators for the grid""" -from typing import Type +from typing import Generic, Type, TypeVar import numpy as np @@ -18,13 +18,15 @@ # pylint: disable=too-few-public-methods,too-many-arguments,too-many-positional-arguments +T = TypeVar("T", bound=Grid) -class RadialGridGenerator: + +class RadialGridGenerator(Generic[T]): """Generates a random but structurally correct radial grid with the given specifications""" def __init__( self, - grid_class: Type[Grid], + grid_class: Type[T], nr_nodes: int = 100, nr_sources: int = 2, nr_nops: int = 10, @@ -36,7 +38,7 @@ def __init__( self.nr_sources = nr_sources self.nr_nops = nr_nops - def run(self, seed=None, create_10_3_kv_net: bool = False): + def run(self, seed=None, create_10_3_kv_net: bool = False) -> T: """Run the generator to create a random radial grid. if a seed is provided, this will be used to set rng.