Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions deepmd/utils/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,17 @@ def __add__(self, other: "Plugin") -> "Plugin":
self.plugins.update(other.plugins)
return self

def register(self, key: str) -> Callable[[object], object]:
def register(
self, key: str, alias: list[str] | None = None
) -> Callable[[object], object]:
"""Register a plugin.

Parameters
----------
key : str
key of the plugin
Primary key of the plugin.
alias : list[str], optional
Alternative keys for the plugin.

Returns
-------
Expand All @@ -54,6 +58,9 @@ def register(self, key: str) -> Callable[[object], object]:

def decorator(object: object) -> object:
self.plugins[key] = object
if alias:
for alias_key in alias:
Copy link

Copilot AI Jan 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alias registration logic doesn't validate whether an alias key already exists in the plugins dictionary. If an alias key is already registered (either as a primary key or another alias), it will be silently overwritten. Consider adding validation to check for key conflicts and either raise an error or log a warning to prevent unintended plugin overwrites.

Suggested change
for alias_key in alias:
for alias_key in alias:
if alias_key in self.plugins and self.plugins[alias_key] is not object:
raise ValueError(
f"Alias key {alias_key!r} is already registered for a different plugin"
)

Copilot uses AI. Check for mistakes.
self.plugins[alias_key] = object
Comment on lines 60 to +63
Copy link

Copilot AI Jan 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code doesn't validate whether the primary key already exists in the plugins dictionary before registration. If a plugin is registered with a key that already exists (or matches an existing alias), it will silently overwrite the previous plugin. Consider adding validation to check for key conflicts before registration.

Suggested change
self.plugins[key] = object
if alias:
for alias_key in alias:
self.plugins[alias_key] = object
# Collect all keys (primary and aliases) to be registered
keys_to_register: list[str] = [key]
if alias:
keys_to_register.extend(alias)
# Validate that none of the keys are already registered
for k in keys_to_register:
if k in self.plugins:
raise KeyError(
f"Plugin key '{k}' is already registered and cannot be overwritten."
)
# Register the plugin under all requested keys
for k in keys_to_register:
self.plugins[k] = object

Copilot uses AI. Check for mistakes.
Comment on lines +62 to +63
Copy link

Copilot AI Jan 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The loop variable 'alias_key' creates potential confusion with the parameter 'alias'. Consider using a more descriptive variable name like 'alias_name' or 'key' to improve code clarity and avoid any naming ambiguity.

Suggested change
for alias_key in alias:
self.plugins[alias_key] = object
for alias_name in alias:
self.plugins[alias_name] = object

Copilot uses AI. Check for mistakes.
return object

return decorator
Expand Down Expand Up @@ -119,13 +126,17 @@ class PR:
__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable[[object], object]:
def register(
key: str, alias: list[str] | None = None
) -> Callable[[object], object]:
"""Register a descriptor plugin.

Parameters
----------
key : str
the key of a descriptor
The primary key of the plugin.
alias : list[str], optional
Alternative keys for the plugin.

Returns
-------
Expand All @@ -137,8 +148,11 @@ def register(key: str) -> Callable[[object], object]:
>>> @BaseClass.register("some_class")
class SomeClass(BaseClass):
pass
>>> @BaseClass.register("some_class", alias=["alias1", "alias2"])
class SomeClass(BaseClass):
pass
"""
return PR.__plugins.register(key)
return PR.__plugins.register(key, alias=alias)

@classmethod
def get_class_by_type(cls, class_type: str) -> type[object]:
Expand Down
Loading