Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5456,6 +5456,71 @@ def unregister_worker_plugin(self, name, nanny=None):
"""
return self.sync(self._unregister_worker_plugin, name=name, nanny=nanny)

def has_plugin(
self, plugin: str | WorkerPlugin | SchedulerPlugin | NannyPlugin | list
) -> bool | dict[str, bool]:
"""Check if plugin(s) are registered

Parameters
----------
plugin : str | plugin object | list
Plugin to check. You can use the plugin object directly or the plugin name. For plugin objects, they must have a 'name' attribute. You can also pass a list of plugin objects or names.

Returns
-------
bool or dict[str, bool]
If name is str: True if plugin is registered, False otherwise
If name is list: dict mapping names to registration status

Examples
--------
>>> logging_plugin = LoggingConfigPlugin() # Has name = "logging-config"
>>> client.register_plugin(logging_plugin)
>>> client.has_plugin(logging_plugin)
True

>>> client.has_plugin('logging-config')
True

>>> client.has_plugin([logging_plugin, 'other-plugin'])
{'logging-config': True, 'other-plugin': False}
"""
if isinstance(plugin, str):
result = self.sync(self._get_plugin_registration_status, names=[plugin])
return result[plugin]

elif isinstance(plugin, (WorkerPlugin, SchedulerPlugin, NannyPlugin)):
plugin_name = getattr(plugin, "name", None)
if plugin_name is None:
Comment on lines +5501 to +5503
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possible simplification: if we detect a single plugin here, we can assign it to

plugin = [plugin]

and then fall through to the list case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see that the return type is different in that case (bool vs. dict[str, bool]). We could add a local unbox variable to handle this, but perhaps not worth it.

raise ValueError(
f"Plugin {funcname(type(plugin))} has no 'name' attribute. "
"Please add a 'name' attribute to your plugin class."
)
result = self.sync(
self._get_plugin_registration_status, names=[plugin_name]
)
return result[plugin_name]

elif isinstance(plugin, list):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe Sequence here (and in the type signature)?

names_to_check = []
for p in plugin:
if isinstance(p, str):
names_to_check.append(p)
else:
plugin_name = getattr(p, "name", None)
if plugin_name is None:
raise ValueError(
f"Plugin {funcname(type(p))} has no 'name' attribute"
)
names_to_check.append(plugin_name)
return self.sync(self._get_plugin_registration_status, names=names_to_check)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have an else case that raises rather than silently returning None.


async def _get_plugin_registration_status(
self, names: list[str]
) -> dict[str, bool]:
"""Async implementation for checking plugin registration"""
return await self.scheduler.get_plugin_registration_status(names=names)

@property
def amm(self):
"""Convenience accessors for the :doc:`active_memory_manager`"""
Expand Down
45 changes: 45 additions & 0 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4039,6 +4039,7 @@ async def post(self) -> None:
"unregister_worker_plugin": self.unregister_worker_plugin,
"register_nanny_plugin": self.register_nanny_plugin,
"unregister_nanny_plugin": self.unregister_nanny_plugin,
"get_plugin_registration_status": self.get_plugin_registration_status,
"adaptive_target": self.adaptive_target,
"workers_to_close": self.workers_to_close,
"subscribe_worker_status": self.subscribe_worker_status,
Expand Down Expand Up @@ -8696,6 +8697,50 @@ async def get_worker_monitor_info(
)
return dict(zip(self.workers, results))

async def get_plugin_registration_status(self, names: list[str]) -> dict[str, bool]:
"""Check if plugins are registered in any plugin registry

Checks all plugin registries (worker, scheduler, nanny) and returns True
if the plugin is found in any of them.

Parameters
----------
names : list[str]
List of plugin names to check

Returns
-------
dict[str, bool]
Dict mapping plugin names to their registration status across all registries
"""
result = {}
for name in names:
# Check if plugin exists in any registry
result[name] = (
name in self.worker_plugins
or name in self.plugins
or name in self.nanny_plugins
)
return result

async def get_worker_plugin_registration_status(
self, names: list[str]
) -> dict[str, bool]:
"""Check if worker plugins are registered"""
return {name: name in self.worker_plugins for name in names}

async def get_scheduler_plugin_registration_status(
self, names: list[str]
) -> dict[str, bool]:
"""Check if scheduler plugins are registered"""
return {name: name in self.plugins for name in names}

async def get_nanny_plugin_registration_status(
self, names: list[str]
) -> dict[str, bool]:
"""Check if nanny plugins are registered"""
return {name: name in self.nanny_plugins for name in names}

###########
# Cleanup #
###########
Expand Down
Loading