diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py
index 4f619307f5cd..47ce2bf5d39f 100644
--- a/paddlenlp/peft/lora/lora_model.py
+++ b/paddlenlp/peft/lora/lora_model.py
@@ -50,36 +50,22 @@ def get_lora_layers():
-    try:
-        if get_env_device() == "xpu":
-            # If paddle_xpu is not installed, just use PaddleNLP's native lora layers
-            from paddle_xpu.layers.nn.lora_layers import (
-                XPUColumnParallelLoRALinear as ColumnParallelLoRALinear,
-            )
-            from paddle_xpu.layers.nn.lora_layers import (
-                XPUColumnSequenceParallelLoRALinear as ColumnSequenceParallelLoRALinear,
-            )
-            from paddle_xpu.layers.nn.lora_layers import XPULoRALinear as LoRALinear
-            from paddle_xpu.layers.nn.lora_layers import (
-                XPURowParallelLoRALinear as RowParallelLoRALinear,
-            )
-            from paddle_xpu.layers.nn.lora_layers import (
-                XPURowSequenceParallelLoRALinear as RowSequenceParallelLoRALinear,
-            )
-
-            from .lora_layers import LoRAConv2D
+    if get_env_device() == "xpu":
+        try:
+            import paddle_xpu
+            paddle_xpu.init_lora_layers()
+        except Exception as e:
+            logger.warning("Failed to import LoRALinear from paddle_xpu, using PaddleNLP's native implementation.")
         else:
-            raise ImportError  # Force to use the fallback if not XPU
-    except ImportError:
-        from .lora_layers import (
-            ColumnParallelLoRALinear,
-            ColumnSequenceParallelLoRALinear,
-            LoRAConv2D,
-            LoRALinear,
-            RowParallelLoRALinear,
-            RowSequenceParallelLoRALinear,
-        )
-
+            logger.info("Import paddle_xpu succeeded.")
+    from .lora_layers import (
+        ColumnParallelLoRALinear,
+        ColumnSequenceParallelLoRALinear,
+        LoRAConv2D,
+        LoRALinear,
+        RowParallelLoRALinear,
+        RowSequenceParallelLoRALinear,
+    )
     return {
         "ColumnParallelLoRALinear": ColumnParallelLoRALinear,
         "ColumnSequenceParallelLoRALinear": ColumnSequenceParallelLoRALinear,
diff --git a/paddlenlp/transformers/linear_utils.py b/paddlenlp/transformers/linear_utils.py
index f0d361068421..4a79efbb7482 100644
--- a/paddlenlp/transformers/linear_utils.py
+++ b/paddlenlp/transformers/linear_utils.py
@@ -29,6 +29,15 @@
     MC2RowSeqParallelLinear,
 )
 from paddlenlp.utils.tools import get_env_device
+from paddlenlp.utils.log import logger
+
+if get_env_device() == "xpu":
+    try:
+        import paddle_xpu
+    except Exception as e:
+        logger.warning("Failed to import paddle_xpu, using PaddlePaddle's native implementations.")
+    else:
+        logger.info("Import paddle_xpu succeeded.")
 
 Linear = nn.Linear
 ColumnParallelLinear = mpu.ColumnParallelLinear
@@ -62,23 +71,8 @@ class RowSequenceParallelLinearPass(object):
     ColumnSequenceParallelLinear = MC2ColumnSeqParallelLinear
     RowSequenceParallelLinear = MC2RowSeqParallelLinear
 elif get_env_device() == "xpu":
-    try:
-        from paddle_xpu.layers.nn import ColumnParallelLinear as XPUColumnParallelLinear
-        from paddle_xpu.layers.nn import Linear as XPULinear
-        from paddle_xpu.layers.nn import RowParallelLinear as XPURowParallelLinear
-        from paddle_xpu.layers.nn.sequence_parallel import (
-            XPUColumnSequenceParallelLinear,
-            XPURowSequenceParallelLinear,
-        )
-
-        Linear = XPULinear
-        ColumnParallelLinear = XPUColumnParallelLinear
-        RowParallelLinear = XPURowParallelLinear
-        ColumnSequenceParallelLinear = XPUColumnSequenceParallelLinear
-        RowSequenceParallelLinear = XPURowSequenceParallelLinear
-    except ImportError:
-        # If paddle_xpu is not installed, just use Paddle's native Linear implementations
-        pass
+    import importlib
+    importlib.reload(nn)
 else:
     # By default, use Paddle's native Linear implementations
     pass
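
Both files converge on the same optional-dependency probe: attempt to import `paddle_xpu` once, log the outcome, and leave the native Paddle/PaddleNLP implementations in place when the package is absent. The layer substitution itself moves into paddle_xpu (`init_lora_layers()` for the LoRA layers, with `importlib.reload(nn)` re-executing the `nn` module so `linear_utils` picks up whatever paddle_xpu patched in). The sketch below restates just the probe pattern in isolation; `load_optional_backend` and the `__main__` demo are illustrative names of mine, not part of PaddleNLP or paddle_xpu.

```python
import importlib
import logging

logger = logging.getLogger(__name__)


def load_optional_backend(package: str = "paddle_xpu"):
    """Try to import an optional accelerator package.

    Returns the imported module on success, or None when the caller
    should fall back to the native implementation. Mirrors the
    try/except/else shape used in the diff above.
    """
    try:
        backend = importlib.import_module(package)
    except Exception:
        # Broad except on purpose, matching the diff: a missing package
        # raises ImportError, but a broken install can raise almost
        # anything at import time.
        logger.warning("Failed to import %s, using the native implementation.", package)
        return None
    else:
        logger.info("Import %s succeeded.", package)
        return backend


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    backend = load_optional_backend()
    print("vendor backend" if backend is not None else "native backend")
```

One design consequence worth noting: catching `Exception` rather than `ImportError` means a genuinely broken `paddle_xpu` install also degrades to the native path with only a warning, which appears to be the intent of the change.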