
Commit c44e040

[XPU] fix fleet unittests (#68542)
* [XPU] fix fleet unittests
* [XPU] fix fleet unittests
* refine: use new default parameter
* revert unnecessary modifications.
* revert unnecessary modifications.
* fix cmakelist
* revert unnecessary modifications.
* fix cmakelist for recompute ut.
1 parent f1c54e9 commit c44e040

14 files changed: +111 -28 lines changed
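The recurring change across the Python files below is the device default: instead of hard-coding device="gpu", the sharding wrappers now derive the default from the build. A minimal sketch of that pattern, using only core.is_compiled_with_xpu() as it appears in the diff (the print is illustrative):

from paddle.framework import core

# Build-derived default device, as used in the new default parameters below:
# an XPU build defaults to "xpu", every other build keeps "gpu".
default_device = "xpu" if core.is_compiled_with_xpu() else "gpu"
print(default_device)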

paddle/phi/api/lib/data_transform.cc

Lines changed: 32 additions & 0 deletions
@@ -143,6 +143,33 @@ phi::DenseTensor CastDataType(const phi::GPUContext& dev_ctx,
 }
 #endif
 
+#ifdef PADDLE_WITH_XPU
+phi::DenseTensor CastDataType(const phi::XPUContext& dev_ctx,
+                              const phi::DenseTensor& tensor,
+                              DataType dtype) {
+  switch (tensor.dtype()) {
+    case DataType::FLOAT32:
+      return phi::Cast<float>(dev_ctx, tensor, dtype);
+    case DataType::FLOAT64:
+      return phi::Cast<double>(dev_ctx, tensor, dtype);
+    case DataType::INT32:
+      return phi::Cast<int32_t>(dev_ctx, tensor, dtype);
+    case DataType::INT64:
+      return phi::Cast<int64_t>(dev_ctx, tensor, dtype);
+    case DataType::FLOAT16:
+      return phi::Cast<phi::dtype::float16>(dev_ctx, tensor, dtype);
+    case DataType::BOOL:
+      return phi::Cast<bool>(dev_ctx, tensor, dtype);
+    case DataType::UINT8:
+      return phi::Cast<uint8_t>(dev_ctx, tensor, dtype);
+    default:
+      PADDLE_THROW(common::errors::Unimplemented(
+          "Data type (%s) is not supported when casting data type.",
+          tensor.dtype()));
+  }
+}
+#endif
+
 inline phi::DenseTensor TransDataType(const phi::DenseTensor& tensor,
                                       DataType dtype) {
   auto& pool = phi::DeviceContextPool::Instance();
@@ -161,6 +188,11 @@ inline phi::DenseTensor TransDataType(const phi::DenseTensor& tensor,
     auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(tensor.place()));
     return CastDataType(*dev_ctx, tensor, dtype);
 #endif
+#ifdef PADDLE_WITH_XPU
+  } else if (tensor.place().GetType() == phi::AllocationType::XPU) {
+    auto* dev_ctx = static_cast<phi::XPUContext*>(pool.Get(tensor.place()));
+    return CastDataType(*dev_ctx, tensor, dtype);
+#endif
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
   } else if (tensor.place().GetType() == phi::AllocationType::CUSTOM) {
     phi::DenseTensor out;
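For context, a rough Python-level sketch of the kind of call that can reach the new XPU cast path on an XPU build. Whether a given op actually goes through TransDataType depends on the dtype its kernel requests, so treat this as illustrative rather than a guaranteed trigger:

import paddle

if paddle.is_compiled_with_xpu():
    paddle.device.set_device("xpu:0")  # assumes at least one XPU card is visible
    x = paddle.ones([2, 3], dtype="float32")
    # Dtype conversion on an XPU tensor now has a dedicated CastDataType
    # overload instead of falling through to the unimplemented-type error.
    y = paddle.cast(x, "float16")
    print(y.dtype, y.place)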

python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py

Lines changed: 7 additions & 1 deletion
@@ -74,7 +74,7 @@ def __init__(
         optim,
         group=None,
         offload=False,
-        device="gpu",
+        device="xpu" if core.is_compiled_with_xpu() else "gpu",
         pretrain_sync_models=True,
         dp_group=None,
         **kw,
@@ -590,6 +590,12 @@ def _step(self):
                         )
                         .cast(dtype=param.dtype)
                     )
+                elif self._default_device == "xpu":
+                    param.set_value(
+                        self._master_params[param.name]
+                        .to("xpu:" + str(self.dev_id))
+                        .cast(dtype=param.dtype)
+                    )
                 else:
                     param.set_value(
                         self._master_params[param.name]

python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py

Lines changed: 2 additions & 1 deletion
@@ -31,6 +31,7 @@
 from paddle import nn
 from paddle.distributed import collective
 from paddle.distributed.utils.log_utils import get_logger
+from paddle.framework import core
 
 from .group_sharded_optimizer_stage2 import GroupShardedOptimizerStage2
 from .group_sharded_storage import GradStorage
@@ -66,7 +67,7 @@ def __init__(
         sync_buffers=False,
         buffer_max_size=2**23,  # 8MB
         auto_refresh_trainable=True,
-        device="gpu",
+        device="xpu" if core.is_compiled_with_xpu() else "gpu",
         dp_group=None,
     ):
         super().__init__()

python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py

Lines changed: 7 additions & 3 deletions
@@ -104,7 +104,7 @@ def __init__(
         optimizer,
         group=None,
         sync_buffers=False,
-        device="gpu",
+        device="xpu" if core.is_compiled_with_xpu() else "gpu",
         segment_size=2**20,
         pretrain_sync_models=True,
         offload=False,
@@ -310,7 +310,10 @@ def _clear_gradients(self):
                     paddle.CustomPlace(self._default_device, DEV_ID), True
                 )
             else:
-                tmp_var = param.cuda(DEV_ID)
+                # both GPU and XPU
+                tmp_var = param.to(
+                    self._default_device + ":" + (str)(DEV_ID)
+                )
 
             if (
                 tmp_var.dtype == Type.fp32.value
@@ -1197,7 +1200,8 @@ def _cpu2device(param):
     if DEV in paddle.device.get_all_custom_device_type():
         tmp_p = param.fw_storage._copy_to(paddle.CustomPlace(DEV, DEV_ID), True)
     else:
-        tmp_p = param.fw_storage.cuda(DEV_ID)
+        # both GPU and XPU
+        tmp_p = param.fw_storage.to(DEV + ":" + (str)(DEV_ID))
     if (
         tmp_p.dtype == Type.fp32.value
         and param2dtype[param.name] == Type.fp16.value
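The stage-3 changes replace the CUDA-only Tensor.cuda(DEV_ID) with the device-string form of Tensor.to, which covers both GPU and XPU. A small self-contained illustration of the pattern (device id 0 is assumed for the example; the real code uses the rank's DEV_ID):

import paddle

dev = "xpu" if paddle.is_compiled_with_xpu() else "gpu"
dev_id = 0  # single-card example only

t = paddle.zeros([4], dtype="float32")
t_dev = t.to(dev + ":" + str(dev_id))  # moves to "gpu:0" or "xpu:0" alike
print(t_dev.place)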

python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py

Lines changed: 2 additions & 0 deletions
@@ -167,6 +167,8 @@ def _dygraph_clip(self, params_grads):
             global_norm_var = global_norm_var._copy_to(
                 paddle.CustomPlace(dev_type, dev_id), True
             )
+        elif dev_type == "xpu":
+            global_norm_var = global_norm_var.to(self._device)
         else:
             global_norm_var = global_norm_var.cuda(dev_id)
 
test/collective/fleet/CMakeLists.txt

Lines changed: 15 additions & 4 deletions
@@ -61,8 +61,13 @@ if((WITH_ROCM) AND LOCAL_ALL_PLAT)
     "PADDLE_DIST_UT_PORT=21204;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
   )
 endif()
-if(WITH_NCCL OR WITH_RCCL)
-  if((WITH_GPU OR WITH_ROCM) AND LOCAL_ALL_PLAT)
+if(WITH_NCCL
+   OR WITH_RCCL
+   OR WITH_XPU_BKCL)
+  if((WITH_GPU
+      OR WITH_ROCM
+      OR WITH_XPU)
+     AND LOCAL_ALL_PLAT)
     bash_test_modules(
       test_parallel_dygraph_mp_layers
       START_BASH
@@ -608,13 +613,19 @@ if((WITH_GPU OR WITH_ROCM) AND LOCAL_ALL_PLAT)
   set_tests_properties(test_imperative_auto_mixed_precision_for_eager
                        PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
 endif()
-if((WITH_GPU OR WITH_ROCM) AND LOCAL_ALL_PLAT)
+if((WITH_GPU
+    OR WITH_ROCM
+    OR WITH_XPU)
+   AND LOCAL_ALL_PLAT)
   py_test_modules(
     test_dygraph_recompute_for_eager MODULES test_dygraph_recompute_for_eager
     ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
 endif()
-if((WITH_GPU OR WITH_ROCM) AND LOCAL_ALL_PLAT)
+if((WITH_GPU
+    OR WITH_ROCM
+    OR WITH_XPU)
+   AND LOCAL_ALL_PLAT)
   py_test_modules(
     test_dygraph_recompute MODULES test_dygraph_recompute ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")

test/collective/fleet/dygraph_group_sharded_stage2.py

Lines changed: 12 additions & 4 deletions
@@ -99,7 +99,9 @@ def train_mlp(
     scale_fn_test=False,
 ):
     if sharding_stage != "dp":
-        group = paddle.distributed.new_group([0, 1], backend="nccl")
+        group = paddle.distributed.new_group(
+            [0, 1], backend="bkcl" if paddle.is_compiled_with_xpu() else "nccl"
+        )
     if opt_group:
         optimizer = optimizer_setting(
             model=model, use_pure_fp16=use_pure_fp16, opt_group=opt_group
@@ -149,7 +151,7 @@ def train_mlp(
         )
 
     if sharding_stage == 2:
-        model.to(device="gpu")
+        model.to(device="xpu" if paddle.is_compiled_with_xpu() else "gpu")
 
     for eop in range(epoch):
         model.train()
@@ -210,7 +212,10 @@ def test_dp_stage2():
     )
     for i in range(len(dp_params)):
         np.testing.assert_allclose(
-            dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6
+            dp_params[i].numpy(),
+            stage2_params[i].numpy(),
+            rtol=1e-6,
+            atol=1e-8 if paddle.is_compiled_with_xpu() else 0,
         )
 
     # stage2 accumulate grad
@@ -232,7 +237,10 @@ def test_dp_stage2():
     )
     for i in range(len(dp_params)):
         np.testing.assert_allclose(
-            dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6
+            dp_params[i].numpy(),
+            stage2_params[i].numpy(),
+            rtol=1e-6,
+            atol=1e-8 if paddle.is_compiled_with_xpu() else 0,
        )
 
     # save/load model
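The looser check on XPU adds an absolute tolerance; assert_allclose passes when |actual - desired| <= atol + rtol * |desired|, so a small atol only matters for entries near zero, where a purely relative tolerance is effectively impossible to satisfy. A toy illustration with made-up values:

import numpy as np

a = np.array([1e-9, 1.0])
b = np.array([0.0, 1.0])

# With rtol only this would fail: |1e-9 - 0| is not <= 1e-6 * |0|.
# np.testing.assert_allclose(a, b, rtol=1e-6)

# Passes once atol=1e-8 absorbs the near-zero difference.
np.testing.assert_allclose(a, b, rtol=1e-6, atol=1e-8)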

test/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py

Lines changed: 16 additions & 5 deletions
@@ -98,7 +98,9 @@ def train_mlp(
     test_minimize=False,
 ):
     if sharding_stage != "dp":
-        group = paddle.distributed.new_group([0, 1], backend="nccl")
+        group = paddle.distributed.new_group(
+            [0, 1], backend="bkcl" if paddle.is_compiled_with_xpu() else "nccl"
+        )
     if opt_group:
         optimizer = optimizer_setting(
             model=model, use_pure_fp16=use_pure_fp16, opt_group=opt_group
@@ -140,7 +142,7 @@ def train_mlp(
         )
 
     if sharding_stage == 2:
-        model.to(device="gpu")
+        model.to(device="xpu" if paddle.is_compiled_with_xpu() else "gpu")
 
     for eop in range(epoch):
         model.train()
@@ -166,7 +168,10 @@ def train_mlp(
         optimizer.step()
         optimizer.clear_grad()
 
-    paddle.device.cuda.synchronize()
+    if paddle.is_compiled_with_xpu():
+        paddle.device.xpu.synchronize()
+    else:
+        paddle.device.cuda.synchronize()
 
     if save_model:
         return model, optimizer
@@ -201,7 +206,10 @@ def test_dp_stage2():
     )
     for i in range(len(dp_params)):
         np.testing.assert_allclose(
-            dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6
+            dp_params[i].numpy(),
+            stage2_params[i].numpy(),
+            rtol=1e-6,
+            atol=1e-8 if paddle.is_compiled_with_xpu() else 0,
         )
 
     # stage2 accumulate grad
@@ -223,7 +231,10 @@ def test_dp_stage2():
     )
     for i in range(len(dp_params)):
         np.testing.assert_allclose(
-            dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6
+            dp_params[i].numpy(),
+            stage2_params[i].numpy(),
+            rtol=1e-6,
+            atol=1e-8 if paddle.is_compiled_with_xpu() else 0,
         )
 
     # save/load model
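The backend-aware synchronization above recurs in these tests; if it were ever factored out, a small helper along these lines would do (a sketch only, not part of the commit, using just the two synchronize calls already shown in the diff):

import paddle

def device_synchronize():
    # Wait for all queued work on the current card, whichever backend was built in.
    if paddle.is_compiled_with_xpu():
        paddle.device.xpu.synchronize()
    else:
        paddle.device.cuda.synchronize()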

test/collective/fleet/dygraph_group_sharded_stage2_offload.py

Lines changed: 4 additions & 1 deletion
@@ -94,7 +94,10 @@ def train_mlp(model, offload=False, test=False):
 
     for dtype in optimizer.param_storages:
         for dst_rank, param_storage in optimizer.param_storages[dtype].items():
-            param_storage.to(device="gpu", dtype=dtype)
+            param_storage.to(
+                device="xpu" if paddle.is_compiled_with_xpu() else "gpu",
+                dtype=dtype,
+            )
 
     return model.parameters()
 
test/collective/fleet/dygraph_group_sharded_stage3.py

Lines changed: 2 additions & 3 deletions
@@ -366,10 +366,9 @@ def test_stage2_stage3():
     )
 
     # bfp16
-    nccl_version = core.nccl_version()
-
     if (
-        nccl_version >= 21000
+        paddle.is_compiled_with_xpu()
+        or core.nccl_version() >= 21000
         and paddle.device.cuda.get_device_properties().major >= 8
     ):
         stage2_params = train_mlp(
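One subtlety in the new condition: Python's and binds tighter than or, so it reads as is_compiled_with_xpu() or (nccl_version() >= 21000 and major >= 8), and or short-circuits, meaning the CUDA-only queries never run on an XPU build. A toy check of that precedence and short-circuit behavior (names are illustrative stand-ins):

def cuda_only_check():
    raise RuntimeError("would query CUDA device properties")

on_xpu = True  # stands in for paddle.is_compiled_with_xpu() on an XPU build
# A or B and C parses as A or (B and C); with A truthy, B and C are skipped.
assert on_xpu or (cuda_only_check() and cuda_only_check())
print("short-circuited, CUDA-only checks never ran")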
