Commit 5e8f244

Authored by drivanov, pre-commit-ci[bot], akihironitta, and puririshi98
Fix test_train to not rely on 'Sanity Checking' stdout in multi-GPU runs (#10478)
The test `test_train` previously asserted on the presence of the "Sanity Checking" message in stdout. This was brittle because in multi-GPU/DistributedDataParallel runs, **only rank 0 prints this message**, so tests running on other ranks failed.

This PR updates the test to:

- Remove the fragile stdout assertion.
- Assert trainer state (`!trainer.sanity_checking`, `current_epoch >= 0`).
- Use LoggerCallback to verify that both training and validation ran.

This makes the test deterministic and robust across single-GPU, multi-GPU, and CI environments.

[PassingLog.TXT](https://github.com/user-attachments/files/22671419/PassingLog.TXT)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: Rishi Puri <[email protected]>
1 parent 1252027 commit 5e8f244
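The brittleness described above can be illustrated with a short, dependency-free sketch. `maybe_print` and the `RANK` environment variable are hypothetical stand-ins for the rank-gated logging a DDP launcher sets up (e.g. `torchrun` sets `RANK` per process); this is not Lightning's actual implementation, only the pattern that makes stdout assertions fail on non-zero ranks.

```python
import os


def maybe_print(msg: str) -> None:
    """Print only on the rank-0 process, as rank-gated loggers typically do.

    Hypothetical helper: stands in for framework logging that is silent on
    non-zero ranks, so a test that greps stdout on those ranks sees nothing.
    """
    if int(os.environ.get("RANK", "0")) == 0:
        print(msg)


# On rank 0 the message appears; on any other rank, stdout stays empty,
# which is exactly why `assert 'Sanity Checking' in out` failed there.
os.environ["RANK"] = "1"
maybe_print("Sanity Checking")  # silent on rank 1
```

Asserting on trainer *state* instead of captured stdout sidesteps this entirely, since state is identical on every rank.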

File tree

1 file changed: +23 −6 lines changed


test/graphgym/test_graphgym.py

Lines changed: 23 additions & 6 deletions
@@ -16,7 +16,7 @@
     set_run_dir,
 )
 from torch_geometric.graphgym.loader import create_loader
-from torch_geometric.graphgym.logger import LoggerCallback, set_printing
+from torch_geometric.graphgym.logger import set_printing
 from torch_geometric.graphgym.model_builder import create_model
 from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNStackStage
 from torch_geometric.graphgym.models.head import GNNNodeHead
@@ -194,12 +194,29 @@ def test_train(destroy_process_group, tmp_path, capfd):
     loaders = create_loader()
     model = create_model()
     cfg.params = params_count(model)
+
+    # --- minimal logger callback that collects logs ---
+    class LoggerCallback(pl.Callback):
+        def __init__(self):
+            super().__init__()
+            self.logged = []
+
+        def on_train_batch_end(self, trainer, pl_module, outputs, batch,
+                               batch_idx):
+            self.logged.append({"type": "train", "step": trainer.global_step})
+
+        def on_validation_batch_end(self, trainer, pl_module, outputs, batch,
+                                    batch_idx, dataloader_idx=0):
+            self.logged.append({"type": "val", "step": trainer.global_step})
+
     logger = LoggerCallback()
-    trainer = pl.Trainer(max_epochs=1, max_steps=4, callbacks=logger,
-                         log_every_n_steps=1)
+    trainer = pl.Trainer(max_epochs=2, max_steps=4, callbacks=[logger],
+                         log_every_n_steps=1, enable_progress_bar=False)
     train_loader, val_loader = loaders[0], loaders[1]
     trainer.fit(model, train_loader, val_loader)

-    out, err = capfd.readouterr()
-    assert 'Sanity Checking' in out
-    assert 'Epoch 0:' in out
+    assert trainer.current_epoch > 0
+    # ensure both train and val batches were seen
+    types = {entry["type"] for entry in logger.logged}
+    assert "val" in types, "Validation did not run"
+    assert "train" in types, "Training did not run"
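The verification pattern the new test relies on can be shown in isolation. This is a dependency-free sketch: `RecordingCallback` mirrors the test's `LoggerCallback`, while `StubTrainer` is a hypothetical stand-in that fires the same hooks a real `pl.Trainer` would (the actual test subclasses `pl.Callback` and uses Lightning's trainer).

```python
class RecordingCallback:
    """Collects one entry per train/val batch, like the test's LoggerCallback."""

    def __init__(self):
        self.logged = []

    def on_train_batch_end(self, trainer, batch_idx):
        self.logged.append({"type": "train", "step": trainer.global_step})

    def on_validation_batch_end(self, trainer, batch_idx):
        self.logged.append({"type": "val", "step": trainer.global_step})


class StubTrainer:
    """Hypothetical minimal driver that fires the hooks a real Trainer would."""

    def __init__(self, callback, max_steps=4):
        self.callback = callback
        self.max_steps = max_steps
        self.global_step = 0

    def fit(self, train_batches, val_batches):
        # One "epoch": run training batches up to max_steps, then validation.
        for i, _ in enumerate(train_batches):
            if self.global_step >= self.max_steps:
                break
            self.global_step += 1
            self.callback.on_train_batch_end(self, i)
        for i, _ in enumerate(val_batches):
            self.callback.on_validation_batch_end(self, i)


cb = RecordingCallback()
StubTrainer(cb).fit(train_batches=range(3), val_batches=range(2))

# State-based checks, exactly as in the updated test: no stdout parsing,
# so the result is identical on every DDP rank.
types = {entry["type"] for entry in cb.logged}
assert "train" in types, "Training did not run"
assert "val" in types, "Validation did not run"
```

Because the callback records state rather than text, the same assertions hold whether the run is single-GPU, multi-GPU, or on CI with progress bars disabled.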
