diff --git a/deepmd/utils/plugin.py b/deepmd/utils/plugin.py index e1817327a3..a5c5c554ab 100644 --- a/deepmd/utils/plugin.py +++ b/deepmd/utils/plugin.py @@ -38,13 +38,17 @@ def __add__(self, other: "Plugin") -> "Plugin": self.plugins.update(other.plugins) return self - def register(self, key: str) -> Callable[[object], object]: + def register( + self, key: str, alias: list[str] | None = None + ) -> Callable[[object], object]: """Register a plugin. Parameters ---------- key : str - key of the plugin + Primary key of the plugin. + alias : list[str], optional + Alternative keys for the plugin. Returns ------- @@ -54,6 +58,9 @@ def register(self, key: str) -> Callable[[object], object]: def decorator(object: object) -> object: self.plugins[key] = object + if alias: + for alias_key in alias: + self.plugins[alias_key] = object return object return decorator @@ -119,13 +126,17 @@ class PR: __plugins = Plugin() @staticmethod - def register(key: str) -> Callable[[object], object]: + def register( + key: str, alias: list[str] | None = None + ) -> Callable[[object], object]: """Register a descriptor plugin. Parameters ---------- key : str - the key of a descriptor + The primary key of the plugin. + alias : list[str], optional + Alternative keys for the plugin. Returns ------- @@ -137,8 +148,11 @@ def register(key: str) -> Callable[[object], object]: >>> @BaseClass.register("some_class") class SomeClass(BaseClass): pass + >>> @BaseClass.register("some_class", alias=["alias1", "alias2"]) + class SomeClass(BaseClass): + pass """ - return PR.__plugins.register(key) + return PR.__plugins.register(key, alias=alias) @classmethod def get_class_by_type(cls, class_type: str) -> type[object]: