Skip to content

Commit b87496a

Browse files
authored
[hotfix] fix auto policy of test_sharded_optim_v2 (#2157)
1 parent 16335cb commit b87496a

File tree

4 files changed: +3 additions, −13 deletions

colossalai/gemini/memory_tracer/chunk_memstats_collector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@ def record_model_data_volume(self) -> None:
3333

3434
@property
3535
def cuda_margin_mem(self) -> float:
36-
return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda('cuda')
36+
return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda

colossalai/gemini/memory_tracer/memory_stats.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,6 @@ def max_non_model_data(self, device_type: str) -> float:
107107
else:
108108
raise TypeError
109109

110-
def max_overall_cuda(self, device_type: str) -> float:
111-
if device_type == 'cuda':
112-
return max(self._overall_cuda_list)
113-
elif device_type == 'cpu':
114-
return max(self._overall_cpu_list)
115-
else:
116-
raise TypeError
117-
118110
def clear(self):
119111
self._model_data_cuda_list = []
120112
self._overall_cuda_list = []

colossalai/gemini/memory_tracer/memstats_collector.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,7 @@ def record_model_data_volume(self) -> None:
7979
if self._start_flag and not self.use_outside_memstats:
8080
# The following code work for ZeroInitContext, which is deprecated in v0.1.12
8181
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
82-
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
83-
self._memstats.append_model_data('cuda', cuda_mem)
84-
self._memstats.append_model_data('cpu', cpu_mem)
82+
self._memstats.record_max_cuda_model_data(cuda_mem)
8583

8684
def sample_overall_data(self) -> None:
8785
"""

tests/test_zero/test_sharded_optim_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
6464
zero_model = ShardedModelV2(
6565
zero_model,
6666
shard_strategy,
67-
tensor_placement_policy='cpu' if cpu_offload else 'cuda',
67+
tensor_placement_policy='cpu' if cpu_offload else 'auto',
6868
reuse_fp16_shard=use_cpuadam,
6969
)
7070

0 commit comments

Comments (0)