File tree Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -45,6 +45,7 @@ if [[ "${SHARD_NUMBER:-2}" == "2" ]]; then
45
45
# DTensor tests
46
46
time python test/run_test.py --verbose -i distributed/tensor/test_random_ops
47
47
time python test/run_test.py --verbose -i distributed/tensor/test_dtensor_compile
48
+ time python test/run_test.py --verbose -i distributed/tensor/test_utils.py
48
49
49
50
# DeviceMesh test
50
51
time python test/run_test.py --verbose -i distributed/test_device_mesh
Original file line number Diff line number Diff line change @@ -284,12 +284,12 @@ def compute_global_tensor_shape(
284
284
if isinstance (placements [0 ], Replicate ):
285
285
return shape
286
286
elif isinstance (placements [0 ], Shard ):
287
- local_shape = torch .tensor (list (shape ))
287
+ local_shape = torch .tensor (list (shape ), device = mesh . device_type )
288
288
gathered_shaped_tensors = [
289
289
torch .empty_like (local_shape , device = local_shape .device )
290
290
for _ in range (mesh .size ())
291
291
]
292
- funcol .all_gather_inplace (gathered_shaped_tensors , local_shape )
292
+ funcol .all_gather_inplace (gathered_shaped_tensors , local_shape , mesh )
293
293
sharded_dim_sum = 0
294
294
shard_dim = placements [0 ].dim
295
295
other_dims = [d for d in range (mesh .ndim ) if d != shard_dim ]
You can’t perform that action at this time.
0 commit comments