Skip to content

Commit b704e6c

Browse files
committed
fix
1 parent 3630f72 commit b704e6c

File tree

4 files changed

+5
-7
lines changed

4 files changed

+5
-7
lines changed

colossalai/pipeline/schedule/interleaved_pp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch.nn import Module, ModuleList
77
from torch.utils._pytree import tree_map
88

9-
from colossalai.accelerator import get_accelerator
9+
from colossalai.accelerator import get_accelerator, BaseAccelerator
1010
from colossalai.interface import OptimizerWrapper
1111
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
1212
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -17,7 +17,7 @@
1717
from .base import PipelineSchedule
1818

1919

20-
def _wait_p2p(wait_handles: List[get_accelerator().Event]) -> None:
20+
def _wait_p2p(wait_handles: List[BaseAccelerator.Event]) -> None:
2121
if wait_handles is not None:
2222
for req in wait_handles:
2323
req.wait()

colossalai/shardformer/layer/normalization.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@
1717
SUPPORT_NPU = False
1818
try:
1919
import torch_npu
20-
2120
SUPPORT_NPU = True
22-
warnings.warn("support npu")
2321
except Exception:
24-
warnings.warn("support gpu")
22+
pass
2523

2624

2725
__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]

examples/language/llama/benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def empty_init():
340340

341341
torch.set_default_dtype(torch.float)
342342
coordinator.print_on_master(
343-
f"Booster init max NPU memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
343+
f"Booster init max device memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
344344
)
345345
coordinator.print_on_master(
346346
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"

extensions/pybind/flash_attention/flash_attention_npu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ...base_extension import _Extension
2+
import math
23

34

45
class FlashAttentionNpuExtension(_Extension):
@@ -27,7 +28,6 @@ def build_jit(self) -> None:
2728
)
2829

2930
def load(self):
30-
import math
3131
from typing import Optional
3232

3333
import torch

0 commit comments

Comments (0)