Commit 58d8b8a

[misc] adapt to torch API upgrade and remove legacy import (#6093)
* [amp] fit torch's new api
* [amp] fix api call
* [amp] fix api call
* [misc] fit torch pytree api upgrade
* [misc] remove legacy import
* [misc] fit torch amp api
* [misc] fit torch amp api
1 parent 5ddad48 commit 58d8b8a

File tree

7 files changed (+20, −12 lines)

colossalai/accelerator/cuda_accelerator.py

Lines changed: 1 addition & 1 deletion
@@ -279,4 +279,4 @@ def autocast(
         """
         Return autocast function
         """
-        return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
+        return torch.amp.autocast(device_type="cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
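For context, the device-specific torch.cuda.amp.autocast entry point is deprecated in recent PyTorch releases in favour of the device-agnostic torch.amp namespace, which is what the change above switches to. A minimal sketch of the new call (assumes a CUDA-capable PyTorch 2.x install; the model and tensor here are illustrative):

import torch

model = torch.nn.Linear(16, 16).cuda()
x = torch.randn(4, 16, device="cuda")

# torch.amp.autocast(device_type="cuda", ...) replaces torch.cuda.amp.autocast(...);
# matmul-heavy ops inside the context run in the reduced-precision dtype.
with torch.amp.autocast(device_type="cuda", enabled=True, dtype=torch.float16, cache_enabled=True):
    y = model(x)

print(y.dtype)  # torch.float16 inside the autocast region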

colossalai/kernel/jit/option.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,6 @@
 import torch

 from colossalai.accelerator import get_accelerator
-from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear

 from .bias_dropout_add import bias_dropout_add_fused_train
 from .bias_gelu import bias_gelu_impl
@@ -45,6 +44,7 @@ def warmup_jit_fusion(
     dtype: torch.dtype = torch.float32,
 ):
     """Compile JIT functions before the main training steps"""
+    from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear

     embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
     linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
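The legacy layer import is now deferred into warmup_jit_fusion, so merely importing colossalai.kernel.jit no longer pulls in colossalai.legacy at module load time. A rough sketch of the same pattern in isolation (the import path comes from the diff above; the function name and shapes are illustrative):

# Module level: keep only lightweight imports so importing this file is cheap
# and does not require the deprecated colossalai.legacy package.
import torch


def warmup_embedding(vocab_size: int, hidden_size: int) -> torch.nn.Module:
    # Deferred import: colossalai.legacy is only loaded if this function is called.
    from colossalai.legacy.nn.layer.colossalai_layer import Embedding

    return Embedding(vocab_size, hidden_size)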

colossalai/pipeline/schedule/_utils.py

Lines changed: 8 additions & 2 deletions
@@ -3,8 +3,9 @@

 import torch
 import torch.cuda
+from packaging.version import Version
 from torch.nn import Module
-from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten
+from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, tree_flatten, tree_map, tree_unflatten


 # this register are for torch under version 1.13.1, maybe removed in the future
@@ -16,7 +17,12 @@ def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]"
     return OrderedDict((key, value) for key, value in zip(context, values))


-_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
+if Version(torch.__version__) <= Version("1.13.1"):
+    try:
+        from torch.utils._pytree import register_pytree_node as _register_pytree_node
+    except ImportError:
+        from torch.utils._pytree import _register_pytree_node
+    _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)


 def tree_map_hf(fn: Any, pytree: Any):
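The explicit OrderedDict registration is only needed on torch <= 1.13.1, where the helper may be exposed as either register_pytree_node or the private _register_pytree_node; newer releases already handle OrderedDict as a pytree container, so the registration is skipped there. A small sketch of the resulting behaviour (only standard torch.utils._pytree calls are used; the printout is illustrative):

from collections import OrderedDict

import torch
from packaging.version import Version
from torch.utils._pytree import tree_flatten, tree_unflatten

od = OrderedDict(a=torch.zeros(2), b=torch.ones(2))

# With the guarded registration above (or the built-in support in newer torch),
# an OrderedDict flattens into its values plus a TreeSpec and round-trips cleanly.
leaves, spec = tree_flatten(od)
restored = tree_unflatten(leaves, spec)

print(Version(torch.__version__), len(leaves), type(restored).__name__)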

colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py

Lines changed: 6 additions & 5 deletions
@@ -1,10 +1,5 @@
 import torch.nn

-from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
-    GradMemStats,
-    GradMemTracerHook,
-    ParamMemTracerHook,
-)
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import _cast_float

@@ -27,6 +22,12 @@ class RuntimeMemTracer:

     def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
         super().__init__()
+        from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
+            GradMemStats,
+            GradMemTracerHook,
+            ParamMemTracerHook,
+        )
+
         self.module = module
         self.dtype = dtype
         self._gradstat = GradMemStats()

colossalai/zero/gemini/placement_policy.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,6 @@
 import torch.distributed as dist

 from colossalai.accelerator import get_accelerator
-from colossalai.legacy.utils.memory import colo_device_memory_capacity
 from colossalai.zero.gemini.chunk import Chunk

 from .chunk import Chunk, ChunkManager
@@ -172,6 +171,8 @@ def evict_tensors(
         Returns:
             int: the volume of memory that is evicted
         """
+        from colossalai.legacy.utils.memory import colo_device_memory_capacity
+
        start = time()
         cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
         used_cuda_model_data = self.chunk_manager.total_mem["cuda"]

docs/source/en/features/mixed_precision_training_with_booster.md

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan)
 AMP stands for automatic mixed precision training.
 In Colossal-AI, we have incorporated different implementations of mixed precision training:

-1. torch.cuda.amp
+1. torch.amp
 2. apex.amp
 3. naive amp

docs/source/zh-Hans/features/mixed_precision_training_with_booster.md

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 AMP 代表自动混合精度训练。
 在 Colossal-AI 中, 我们结合了混合精度训练的不同实现:

-1. torch.cuda.amp
+1. torch.amp
 2. apex.amp
 3. naive amp

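Both documents now point at the unified torch.amp namespace rather than torch.cuda.amp. As a reference, a bare-bones fp16 training step with the device-agnostic API might look like the sketch below (torch.amp.GradScaler requires a recent PyTorch, roughly 2.3+; on older versions torch.cuda.amp.GradScaler plays the same role, and the model, optimizer, and data here are placeholders):

import torch

model = torch.nn.Linear(32, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler("cuda")  # older torch: torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(16, 32, device="cuda")
    target = torch.randn(16, 8, device="cuda")

    optimizer.zero_grad(set_to_none=True)
    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), target)

    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()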
