Refactors process_group_tests.py#103
Conversation
torchft/process_group_test.py
Outdated
| self.assertIs(wrapper.parent, pg) | ||
|
|
||
| works = _test_pg(wrapper) | ||
| works = run_collective(pg=wrapper, collective="allreduce") |
There was a problem hiding this comment.
this seems like a pretty big decrease in coverage?
There was a problem hiding this comment.
Added functionality back
torchft/process_group_test.py
Outdated
| shape: torch.Size = example_tensor.shape | ||
| dtype: torch.dtype = example_tensor.dtype | ||
| coll = getattr(pg, collective) | ||
| args_list = _build_args(pg=pg, collective=collective, example_tensor=example_tensor) |
There was a problem hiding this comment.
What's the intention behind pulling this out? I'm not really convinced that this makes it all that much cleaner
In some ways I think I'd prefer if we got rid of the arg generation and instead flatten this out with direct calls i.e.
if collective == "allreduce":
work = pg.allreduce(...)
work.wait()
...
There was a problem hiding this comment.
I agree, I've removed the arg generation and included it in place for run_collective
| pg = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10)) | ||
| try: | ||
| pg.configure(self.store_addr, 0, 1) |
There was a problem hiding this comment.
This seems really slow -- how fast does this run? Launching the subprocess is pretty slow so would actually prefer to run these all on the same PG
If you want prettier printing we can use subtests?
i.e.
for collective in collectives:
with self.subTest(collective=collective):
...
There was a problem hiding this comment.
Good callout, with parameterized it took ~36s, without it took ~16s. I've removed parameterized.
|
updates in #102 are likely cleaner, so I am going to deprecate this PR |
What does this PR do?
As part of #97, this PR refactors
process_group_test:_test_pgtorun_collectivesand extending it to accept a given list of collectives by name.ProcessGroupTestinto three tests:GlooTest,NCCLTestsandDummyTests:GlooTestlogically tests every test usinggloo,NCCLTestwith NCCL, etc.shutdown()and garbage collection etc. to avoid extraneous messages & warnings likeWhy is this needed?
As part of #102, I noticed that there were some mismatches between which collectives ran on which backends (matrix is here). Therefore this logical grouping of tests by backend allows us to define which collectives should be tested explicitly