|
1 | 1 | from typing import Optional, Union, Any, overload, Literal, Iterable, Iterator |
2 | | -from typing import cast |
| 2 | +from typing import cast, Type, TypeVar |
3 | 3 | import os |
4 | 4 | import json |
5 | 5 | from datetime import date |
|
39 | 39 | logger = logging.getLogger(__name__) |
40 | 40 |
|
41 | 41 |
|
| 42 | +AddonType = TypeVar("AddonType", bound="AddonComponent") |
| 43 | + |
| 44 | + |
42 | 45 | class CAT(AbstractSerialisable): |
43 | 46 | """This is a collection of serialisable model parts. |
44 | 47 | """ |
@@ -839,9 +842,36 @@ def __eq__(self, other: Any) -> bool: |
839 | 842 | # addon (e.g MetaCAT) related stuff |
840 | 843 |
|
841 | 844 | def add_addon(self, addon: AddonComponent) -> None: |
| 845 | + """Add the addon to the model pack an pipe. |
| 846 | +
|
| 847 | + Args: |
| 848 | + addon (AddonComponent): The addon to add. |
| 849 | + """ |
842 | 850 | self.config.components.addons.append(addon.config) |
843 | 851 | self._pipeline.add_addon(addon) |
844 | 852 |
|
| 853 | + def get_addons(self) -> list[AddonComponent]: |
| 854 | + """Get the list of all addons in this model pack. |
| 855 | +
|
| 856 | + Returns: |
| 857 | + list[AddonComponent]: The list of addons present. |
| 858 | + """ |
| 859 | + return list(self._pipeline.iter_addons()) |
| 860 | + |
| 861 | + def get_addons_of_type(self, addon_type: Type[AddonType]) -> list[AddonType]: |
| 862 | + """Get a list of addons of a specific type. |
| 863 | +
|
| 864 | + Args: |
| 865 | + addon_type (Type[AddonType]): The type of addons to look for. |
| 866 | +
|
| 867 | + Returns: |
| 868 | + list[AddonType]: The list of addons of this specific type. |
| 869 | + """ |
| 870 | + return [ |
| 871 | + addon for addon in self.get_addons() |
| 872 | + if isinstance(addon, addon_type) |
| 873 | + ] |
| 874 | + |
845 | 875 |
|
846 | 876 | class OutOfDataException(ValueError): |
847 | 877 | pass |
0 commit comments