
Ignore Keyword Arguments Outside of Callback Signature During Fabric.call #20915

@ryan-minato


Outline & Motivation

Currently, Fabric.call forwards all positional and keyword arguments (except hook_name) unchanged to every callback that implements the hook. This implicitly requires all callbacks for a given hook name to have mutually compatible signatures, which hinders code sharing among callback classes: a keyword argument needed by only one callback raises a TypeError in every other callback that neither declares it nor accepts **kwargs.

https://github.com/Lightning-AI/pytorch-lightning/blob/cb1afbe37772c2b3acbbeb062703c0020355de13/src/lightning/fabric/fabric.py#L838C1-L876C1
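A minimal sketch of the constraint, runnable without Lightning installed. call_hook, LossLogger, CheckpointSaver and the hook name on_train_batch_end are all hypothetical; the loop only approximates the linked code, assuming it looks up getattr(callback, hook_name) on each registered callback and passes the arguments through unchanged.

```python
from typing import Any, List


def call_hook(callbacks: List[Any], hook_name: str, *args: Any, **kwargs: Any) -> None:
    """Hypothetical stand-in that forwards every argument unchanged, like Fabric.call today."""
    for callback in callbacks:
        method = getattr(callback, hook_name, None)
        if callable(method):
            method(*args, **kwargs)


class LossLogger:
    def on_train_batch_end(self, loss):
        print(f"loss={loss}")


class CheckpointSaver:
    def on_train_batch_end(self, loss, step):
        print(f"step={step}")


try:
    # LossLogger does not declare `step`, so the call fails before CheckpointSaver runs
    call_hook([LossLogger(), CheckpointSaver()], "on_train_batch_end", loss=0.5, step=10)
except TypeError as err:
    print(f"TypeError: {err}")
```

The only ways around this today are to give every callback the union of all parameters or to add **kwargs everywhere.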

This issue proposes a change to Fabric.call so that it inspects the function signature of each callback and ignores keyword arguments that are not present in that callback's signature.

Pitch

Before invoking each callback hook within Fabric.call, compare the hook method's signature against the provided keyword arguments and filter out any keyword arguments the signature does not declare. If the method accepts **kwargs, all keyword arguments can be forwarded unchanged.

https://github.com/Lightning-AI/pytorch-lightning/blob/cb1afbe37772c2b3acbbeb062703c0020355de13/src/lightning/fabric/fabric.py#L868C1-L870C1

Here's a demo of how this filtering logic could be implemented:

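```python
import inspect

sig = inspect.signature(method)

filtered_kwargs = {}
for name, param in sig.parameters.items():
    # For bound methods, inspect.signature already omits 'self';
    # this check only matters if an unbound function is passed
    if name == "self":
        continue
    # If the method accepts **kwargs, pass all original kwargs directly
    if param.kind == inspect.Parameter.VAR_KEYWORD:
        filtered_kwargs = kwargs
        break
    # Positional-only parameters cannot be supplied by keyword, so never forward them
    if param.kind == inspect.Parameter.POSITIONAL_ONLY:
        continue
    # If the parameter exists in the incoming kwargs, add it to filtered_kwargs
    if name in kwargs:
        filtered_kwargs[name] = kwargs[name]

method(*args, **filtered_kwargs)
```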
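Because the demo above is a fragment of the loop inside Fabric.call, here is a self-contained sketch of the same filtering placed in a hypothetical stand-in loop. filter_kwargs, call_hook and the two callback classes are illustrative names, not Lightning APIs, and the self check is dropped because bound methods never expose self in their signature.

```python
import inspect
from typing import Any, Dict, List


def filter_kwargs(method: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the kwargs that `method` can accept by name (same logic as the demo)."""
    filtered = {}
    for name, param in inspect.signature(method).parameters.items():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            return dict(kwargs)  # **kwargs accepts everything
        if param.kind == inspect.Parameter.POSITIONAL_ONLY:
            continue  # cannot be passed by keyword
        if name in kwargs:
            filtered[name] = kwargs[name]
    return filtered


def call_hook(callbacks: List[Any], hook_name: str, *args: Any, **kwargs: Any) -> None:
    """Hypothetical stand-in for Fabric.call with per-callback kwarg filtering."""
    for callback in callbacks:
        method = getattr(callback, hook_name, None)
        if callable(method):
            method(*args, **filter_kwargs(method, kwargs))


class LossLogger:
    def on_train_batch_end(self, loss):
        print(f"loss={loss}")


class CheckpointSaver:
    def on_train_batch_end(self, loss, step):
        print(f"step={step}")


# Each callback now receives only the keywords it declares
call_hook([LossLogger(), CheckpointSaver()], "on_train_batch_end", loss=0.5, step=10)
```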

This demo illustrates the core logic. Calling inspect.signature is generally cheap: in benchmarks, operations similar to the one shown typically complete within tens of microseconds, although in some particularly complex cases inspect.signature can take over 100 microseconds. If this overhead becomes a concern for frequently called hooks, the result of the signature inspection (for example, the set of keyword names each callback method accepts) could be cached per callback method. The filtered kwargs themselves cannot be cached, because their values change on every call.

Optional Breaking Change
To further improve readability and maintainability, we could prevent Fabric.call from accepting any positional arguments other than hook_name. This would change the method signature to def call(self, hook_name: str, /, **kwargs: Any) -> None, which enforces that every parameter passed to callbacks is an explicitly named keyword argument and makes the intent clearer.
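A sketch of how the proposed signature behaves, using a hypothetical MiniFabric class rather than the real Fabric (kwarg filtering is omitted for brevity). The / marker requires Python 3.8 or later.

```python
from typing import Any, List


class MiniFabric:
    """Hypothetical stand-in illustrating `call(self, hook_name, /, **kwargs)`."""

    def __init__(self, callbacks: List[Any]) -> None:
        self._callbacks = list(callbacks)

    def call(self, hook_name: str, /, **kwargs: Any) -> None:
        for callback in self._callbacks:
            method = getattr(callback, hook_name, None)
            if callable(method):
                method(**kwargs)  # only named arguments reach callbacks


class Recorder:
    def __init__(self) -> None:
        self.seen = []

    def on_epoch_end(self, epoch: int) -> None:
        self.seen.append(epoch)


recorder = Recorder()
fabric = MiniFabric([recorder])
fabric.call("on_epoch_end", epoch=3)  # accepted: everything after hook_name is named
try:
    fabric.call("on_epoch_end", 3)  # rejected: extra positional argument
except TypeError as err:
    print(f"TypeError: {err}")
```

A side effect of making hook_name positional-only is that a keyword argument literally named hook_name no longer clashes with it and can be forwarded to callbacks through **kwargs.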


cc @lantiga @justusschock
