```diff
@@ -77,10 +77,11 @@
 # create a sharding plan based on the given world_size.
 dp_size = _world_size // tp_size
 
+device_type = torch.accelerator.current_accelerator().type
 # Create a device mesh with 2 dimensions.
 # First dim is the data parallel dimension
 # Second dim is the tensor parallel dimension.
-device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
+device_mesh = init_device_mesh(device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
 
 rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
 tp_mesh = device_mesh["tp"]
```
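The new `device_type` lookup assumes an accelerator backend is present. On a CPU-only build, `torch.accelerator.current_accelerator()` can return `None`, in which case the `.type` access fails; a guarded variant (a minimal sketch, not part of this change) could look like:

```python
import torch

# Sketch: fall back to "cpu" when no accelerator backend is available,
# since current_accelerator() may return None in that case.
acc = torch.accelerator.current_accelerator()
device_type = acc.type if acc is not None else "cpu"
```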
```diff
@@ -92,10 +93,10 @@
 # to mimic the behavior of the dataloader.
 dp_rank = dp_mesh.get_local_rank()
 
-# create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids.
+# create model and move it to the accelerator - init_device_mesh has already mapped device ids.
 simple_llama2_config = ModelArgs(dim=256, n_layers=2, n_heads=16, vocab_size=32000)
 
-model = Transformer.from_model_args(simple_llama2_config).to("cuda")
+model = Transformer.from_model_args(simple_llama2_config).to(device_type)
 
 # init model weights
 model.init_weights()
```
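Passing only the device type to `.to()` works because `init_device_mesh` has already mapped each rank to its device, so `device_type` resolves to that rank's current device. Spelling the placement out per rank (a hypothetical sketch using the `LOCAL_RANK` environment variable that `torchrun` sets; not part of this diff) would look like:

```python
import os
import torch

# Sketch: explicit per-rank placement instead of relying on the current device.
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # provided by torchrun
device = torch.device(device_type, local_rank)
model = Transformer.from_model_args(simple_llama2_config).to(device)
```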
```diff
@@ -170,7 +171,7 @@
 for i in range(num_iterations):
     # seeding with dp_rank to ensure identical inputs for TP groups
     torch.manual_seed(i + dp_rank)
-    inp = torch.randint(32000, (8, 256), device="cuda")
+    inp = torch.randint(32000, (8, 256), device=device_type)
 
     output = sharded_model(inp)
     output.sum().backward()
```
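The hunk above only shows the forward and backward passes. A fuller iteration (a sketch that assumes an `optimizer`, e.g. `torch.optim.AdamW` over `sharded_model.parameters()`, is created earlier in the script; it is not part of this diff) might read:

```python
for i in range(num_iterations):
    # seed with dp_rank so every rank within a TP group sees the same batch
    torch.manual_seed(i + dp_rank)
    inp = torch.randint(32000, (8, 256), device=device_type)

    output = sharded_model(inp)
    loss = output.sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    rank_log(_rank, logger, f"iter {i}: loss = {loss.item():.4f}")
```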