|
23 | 23 | "outputs": [], |
24 | 24 | "source": [ |
25 | 25 | "# Install PyTorch-Ignite\n", |
26 | | - "!pip install pytorch-ignite" |
| 26 | + "!pip install -q pytorch-ignite" |
27 | 27 | ] |
28 | 28 | }, |
29 | 29 | { |
|
45 | 45 | "import os\n", |
46 | 46 | "\n", |
47 | 47 | "in_colab = \"COLAB_TPU_ADDR\" in os.environ\n", |
48 | | - "with_torch_launch = \"WORLD_SIZE\" in os.environ\n", |
| 48 | + "with_torchrun = \"WORLD_SIZE\" in os.environ\n", |
49 | 49 | "\n", |
50 | 50 | "if in_colab:\n", |
51 | 51 | " VERSION = !curl -s https://api.github.com/repos/pytorch/xla/releases/latest | grep -Po '\"tag_name\": \"v\\K.*?(?=\")'\n", |
52 | | - " !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-{VERSION[0]}-cp37-cp37m-linux_x86_64.whl" |
| 52 | + " !pip install --upgrade -q cloud-tpu-client==0.10 torch=={VERSION[0]} torchvision https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-{VERSION[0][:-2]}-cp38-cp38-linux_x86_64.whl\n", |
| 53 | + "\n", |
| 54 | + "!pip list | grep torch" |
53 | 55 | ] |
54 | 56 | }, |
55 | 57 | { |
|
88 | 90 | "\n", |
89 | 91 | "\n", |
90 | 92 | "def get_train_test_datasets(path):\n", |
| 93 | + " # - Get train/test datasets\n", |
| 94 | + " if idist.get_rank() > 0:\n", |
| 95 | + " # Ensure that only rank 0 download the dataset\n", |
| 96 | + " idist.barrier()\n", |
| 97 | + "\n", |
91 | 98 | " train_ds = datasets.CIFAR10(root=path, train=True, download=True, transform=train_transform)\n", |
92 | 99 | " test_ds = datasets.CIFAR10(root=path, train=False, download=False, transform=test_transform)\n", |
93 | 100 | "\n", |
| 101 | + " if idist.get_rank() == 0:\n", |
| 102 | + " # Ensure that only rank 0 download the dataset\n", |
| 103 | + " idist.barrier()\n", |
| 104 | + "\n", |
94 | 105 | " return train_ds, test_ds\n", |
95 | 106 | "\n", |
96 | 107 | "\n", |
|
105 | 116 | "\n", |
106 | 117 | "def get_dataflow(config):\n", |
107 | 118 | "\n", |
108 | | - " # - Get train/test datasets\n", |
109 | | - " if idist.get_rank() > 0:\n", |
110 | | - " # Ensure that only rank 0 download the dataset\n", |
111 | | - " idist.barrier()\n", |
112 | | - "\n", |
113 | 119 | " train_dataset, test_dataset = get_train_test_datasets(config.get(\"data_path\", \".\"))\n", |
114 | 120 | "\n", |
115 | | - " if idist.get_rank() == 0:\n", |
116 | | - " # Ensure that only rank 0 download the dataset\n", |
117 | | - " idist.barrier()\n", |
118 | | - "\n", |
119 | 121 | " # Setup data loader also adapted to distributed config: nccl, gloo, xla-tpu\n", |
120 | 122 | " train_loader = idist.auto_dataloader(\n", |
121 | 123 | " train_dataset,\n", |
|
402 | 404 | "# --- Single computation device ---\n", |
403 | 405 | "# $ python main.py\n", |
404 | 406 | "#\n", |
405 | | - "if __name__ == \"__main__\" and not (in_colab or with_torch_launch):\n", |
| 407 | + "if __name__ == \"__main__\" and not (in_colab or with_torchrun):\n", |
406 | 408 | "\n", |
407 | | - " backend = None # or \"nccl\", \"gloo\", \"xla-tpu\" ...\n", |
408 | | - " nproc_per_node = None # or N to spawn N processes\n", |
| 409 | + " backend = None\n", |
| 410 | + " nproc_per_node = None\n", |
409 | 411 | " config = {\n", |
410 | 412 | " \"model\": \"resnet18\",\n", |
411 | 413 | " \"dataset\": \"cifar10\",\n", |
|
416 | 418 | "\n", |
417 | 419 | "\n", |
418 | 420 | "# --- Multiple GPUs ---\n", |
419 | | - "# $ python -m torch.distributed.launch --nproc_per_node=2 --use_env main.py\n", |
| 421 | + "# $ torchrun --nproc_per_node=2 main.py\n", |
420 | 422 | "#\n", |
421 | | - "if __name__ == \"__main__\" and with_torch_launch:\n", |
| 423 | + "if __name__ == \"__main__\" and with_torchrun:\n", |
422 | 424 | "\n", |
423 | | - " backend = \"nccl\" # or \"nccl\", \"gloo\", \"xla-tpu\" ...\n", |
424 | | - " nproc_per_node = None # or N to spawn N processes\n", |
| 425 | + " backend = \"nccl\" # or \"nccl\", \"gloo\"\n", |
| 426 | + " nproc_per_node = None\n", |
425 | 427 | " config = {\n", |
426 | 428 | " \"model\": \"resnet18\",\n", |
427 | 429 | " \"dataset\": \"cifar10\",\n", |
|
435 | 437 | "#\n", |
436 | 438 | "if in_colab:\n", |
437 | 439 | "\n", |
438 | | - " backend = \"xla-tpu\" # or \"nccl\", \"gloo\", \"xla-tpu\" ...\n", |
439 | | - " nproc_per_node = 8 # or N to spawn N processes\n", |
| 440 | + " backend = \"xla-tpu\"\n", |
| 441 | + " nproc_per_node = 8\n", |
440 | 442 | " config = {\n", |
441 | 443 | " \"model\": \"resnet18\",\n", |
442 | 444 | " \"dataset\": \"cifar10\",\n", |
|
465 | 467 | ], |
466 | 468 | "metadata": { |
467 | 469 | "kernelspec": { |
468 | | - "display_name": "Python 3", |
| 470 | + "display_name": "Python 3.10.6 64-bit", |
469 | 471 | "language": "python", |
470 | 472 | "name": "python3" |
471 | 473 | }, |
|
479 | 481 | "name": "python", |
480 | 482 | "nbconvert_exporter": "python", |
481 | 483 | "pygments_lexer": "ipython3", |
482 | | - "version": "3.7.7" |
| 484 | + "version": "3.10.6" |
| 485 | + }, |
| 486 | + "vscode": { |
| 487 | + "interpreter": { |
| 488 | + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" |
| 489 | + } |
483 | 490 | } |
484 | 491 | }, |
485 | 492 | "nbformat": 4, |
|
0 commit comments