Skip to content

Commit 7c8d713

Browse files
authored
Update sequence_parallel for predict (#8547)
1 parent 6757ff9 commit 7c8d713

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@
4040
import paddle.nn as nn
4141
from packaging import version
4242
from paddle import framework
43-
from paddle.base import core
43+
44+
try:
45+
from paddle.base import core
46+
except:
47+
core = None
4448
from paddle.distributed import fleet
4549
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
4650
HybridParallelOptimizer,

paddlenlp/transformers/linear_utils.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818

1919
import paddle.distributed.fleet.meta_parallel as mpu
2020
from paddle import nn
21-
from paddle.distributed.fleet.utils import sequence_parallel_utils
21+
22+
try:
23+
from paddle.distributed.fleet.utils import sequence_parallel_utils
24+
except:
25+
sequence_parallel_utils = None
2226

2327
from paddlenlp.transformers.mc2_parallel_linear import (
2428
MC2ColumnSeqParallelLinear,
@@ -29,8 +33,12 @@
2933
Linear = nn.Linear
3034
ColumnParallelLinear = mpu.ColumnParallelLinear
3135
RowParallelLinear = mpu.RowParallelLinear
32-
ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
33-
RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear
36+
try:
37+
ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
38+
RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear
39+
except:
40+
ColumnSequenceParallelLinear = None
41+
RowSequenceParallelLinear = None
3442

3543
if get_env_device() == "npu":
3644
if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:

0 commit comments

Comments
 (0)