Skip to content

Commit 1823294

Browse files
committed
feat(plugin): enhance register method to support aliases for plugins
1 parent 8b0c3b0 commit 1823294

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

deepmd/utils/plugin.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,17 @@ def __add__(self, other: "Plugin") -> "Plugin":
3838
self.plugins.update(other.plugins)
3939
return self
4040

41-
def register(self, key: str) -> Callable[[object], object]:
41+
def register(
42+
self, key: str, alias: list[str] | None = None
43+
) -> Callable[[object], object]:
4244
"""Register a plugin.
4345
4446
Parameters
4547
----------
4648
key : str
47-
key of the plugin
49+
Primary key of the plugin.
50+
alias : list[str], optional
51+
Alternative keys for the plugin.
4852
4953
Returns
5054
-------
@@ -54,6 +58,9 @@ def register(self, key: str) -> Callable[[object], object]:
5458

5559
def decorator(object: object) -> object:
5660
self.plugins[key] = object
61+
if alias:
62+
for alias_key in alias:
63+
self.plugins[alias_key] = object
5764
return object
5865

5966
return decorator
@@ -119,13 +126,17 @@ class PR:
119126
__plugins = Plugin()
120127

121128
@staticmethod
122-
def register(key: str) -> Callable[[object], object]:
129+
def register(
130+
key: str, alias: list[str] | None = None
131+
) -> Callable[[object], object]:
123132
"""Register a descriptor plugin.
124133
125134
Parameters
126135
----------
127136
key : str
128-
the key of a descriptor
137+
The primary key of the plugin.
138+
alias : list[str], optional
139+
Alternative keys for the plugin.
129140
130141
Returns
131142
-------
@@ -137,8 +148,11 @@ def register(key: str) -> Callable[[object], object]:
137148
>>> @BaseClass.register("some_class")
138149
class SomeClass(BaseClass):
139150
pass
151+
>>> @BaseClass.register("some_class", alias=["alias1", "alias2"])
152+
class SomeClass(BaseClass):
153+
pass
140154
"""
141-
return PR.__plugins.register(key)
155+
return PR.__plugins.register(key, alias=alias)
142156

143157
@classmethod
144158
def get_class_by_type(cls, class_type: str) -> type[object]:

0 commit comments

Comments
 (0)