
Commit 7376111

XilunWu authored and pytorchmergebot committed

[BE] fix compute_global_tensor_shape test (pytorch#161441)

Fixes pytorch#161154

**Test**
`pytest test/distributed/tensor/test_utils.py -s -k test_compute_global_tensor_shape_1D`

Pull Request resolved: pytorch#161441
Approved by: https://github.com/kwen2501

1 parent 92ab184 commit 7376111

File tree

- .ci/pytorch/multigpu-test.sh
- torch/distributed/tensor/_utils.py

2 files changed: +3 −2 lines changed


.ci/pytorch/multigpu-test.sh

Lines changed: 1 addition & 0 deletions

@@ -45,6 +45,7 @@ if [[ "${SHARD_NUMBER:-2}" == "2" ]]; then
   # DTensor tests
   time python test/run_test.py --verbose -i distributed/tensor/test_random_ops
   time python test/run_test.py --verbose -i distributed/tensor/test_dtensor_compile
+  time python test/run_test.py --verbose -i distributed/tensor/test_utils.py

   # DeviceMesh test
   time python test/run_test.py --verbose -i distributed/test_device_mesh

torch/distributed/tensor/_utils.py

Lines changed: 2 additions & 2 deletions

@@ -284,12 +284,12 @@ def compute_global_tensor_shape(
     if isinstance(placements[0], Replicate):
         return shape
     elif isinstance(placements[0], Shard):
-        local_shape = torch.tensor(list(shape))
+        local_shape = torch.tensor(list(shape), device=mesh.device_type)
         gathered_shaped_tensors = [
             torch.empty_like(local_shape, device=local_shape.device)
             for _ in range(mesh.size())
         ]
-        funcol.all_gather_inplace(gathered_shaped_tensors, local_shape)
+        funcol.all_gather_inplace(gathered_shaped_tensors, local_shape, mesh)
         sharded_dim_sum = 0
         shard_dim = placements[0].dim
         other_dims = [d for d in range(mesh.ndim) if d != shard_dim]
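For context, the two hunks above make the shape all-gather device- and group-aware: the shape tensor is now created on the mesh's device type, and the mesh is passed explicitly so the collective runs over the mesh's process group rather than an implicit default. Below is a minimal sketch of the same pattern, assuming a 1-D DeviceMesh and a torchrun launch; the helper name demo_global_shape is illustrative and not part of this PR.

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import init_device_mesh

def demo_global_shape(local_shape, mesh, shard_dim):
    # Build the shape tensor on the mesh's device type so the gather does
    # not mix host tensors into a device-backed process group.
    shape_t = torch.tensor(list(local_shape), device=mesh.device_type)
    gathered = [torch.empty_like(shape_t) for _ in range(mesh.size())]
    # Pass the mesh explicitly; it selects the process group to gather over.
    funcol.all_gather_inplace(gathered, shape_t, mesh)
    # Shard(dim) concatenates along shard_dim, so the global size on that
    # dim is the sum of the per-rank local sizes.
    global_shape = list(local_shape)
    global_shape[shard_dim] = sum(int(t[shard_dim]) for t in gathered)
    return torch.Size(global_shape)

# e.g. under `torchrun --nproc-per-node 2`:
# mesh = init_device_mesh("cpu", (2,))
# demo_global_shape((4, 8), mesh, shard_dim=0)  # -> torch.Size([8, 8])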
