Skip to content

Conversation

@DhyeyMavani2003
Copy link

Description

This PR implements the feature requested in #10386 to allow users to pass custom PyTorch Lightning callbacks via the trainer_config parameter in torch_geometric.graphgym.train.

Changes

  • Modified torch_geometric/graphgym/train.py to extract and extend custom callbacks from trainer_config
  • Handles both single callback and list of callbacks
  • Maintains backward compatibility with existing code

Usage Example

from torch_geometric.graphgym.train import train

# Pass custom callbacks
trainer_config = {
    'callbacks': [my_custom_callback],
    # other trainer config options...
}

train(model, datamodule, trainer_config=trainer_config)

Testing

  • ✅ All pre-commit hooks pass (linting, formatting, etc.)
  • ✅ Existing tests remain compatible
  • ✅ The implementation follows the exact suggestion from the issue

Fixes #10386

Allow users to pass custom PyTorch Lightning callbacks via the
trainer_config parameter. This enables extending the training
process with custom monitoring, logging, or other callback
functionality without modifying the core train function.

Fixes pyg-team#10386

Co-authored-by: Ona <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Custom callbacks in graphgym.train

1 participant