Skip to content

Commit 610db6f

Browse files
Binyang2014chhwang
andauthored
Fix test script (#655)
Fix: #654. Address correctness_test.py crash issue Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
1 parent b8f61cb commit 610db6f

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

test/torch/correctness_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,10 @@ def _init_dist():
6262
if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
6363
raise RuntimeError("Distributed environment variables not set. Run with torchrun.")
6464
backend = "nccl"
65-
dist.init_process_group(backend=backend)
65+
rank = int(os.environ["RANK"])
66+
world_size = int(os.environ["WORLD_SIZE"])
6667
local_rank = int(os.environ.get("LOCAL_RANK", os.environ["RANK"]))
68+
dist.init_process_group(backend=backend, rank=rank, world_size=world_size, device_id=local_rank)
6769
torch.cuda.set_device(local_rank)
6870

6971

@@ -174,6 +176,10 @@ def main():
174176
run_reducescatter_test(args.num_elems, args.iters, dtype, rtol, atol)
175177
else:
176178
raise ValueError("Unknown collective")
179+
dist.barrier()
180+
if dist.get_rank() == 0:
181+
print(f"{args.collective} test passed for dtype={dtype} num_elems={args.num_elems} iters={args.iters}")
182+
dist.destroy_process_group()
177183

178184

179185
if __name__ == "__main__":

0 commit comments

Comments
 (0)