diff --git a/torch_geometric/graphgym/train.py b/torch_geometric/graphgym/train.py index d391e9de21a8..6060bbd5d6c1 100644 --- a/torch_geometric/graphgym/train.py +++ b/torch_geometric/graphgym/train.py @@ -62,6 +62,15 @@ def train( callbacks.append(ckpt_cbk) trainer_config = trainer_config or {} + + # Allow custom callbacks to be passed via trainer_config + if 'callbacks' in trainer_config: + custom_callbacks = trainer_config.pop('callbacks') + if isinstance(custom_callbacks, list): + callbacks.extend(custom_callbacks) + else: + callbacks.append(custom_callbacks) + trainer = pl.Trainer( **trainer_config, enable_checkpointing=cfg.train.enable_ckpt,