Skip to content

Commit bc9683f

Browse files
authored
Merge pull request #291 from Tencent/develop
update readme adding superPod multi-node results.
2 parents bd40a88 + a5a8e6a commit bc9683f

File tree

11 files changed

+106
-30
lines changed

11 files changed

+106
-30
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@ We also evaluated PatrickStar v0.4.3 on a single node of A100 SuperPod. It is ab
2525

2626
Detail benchmark results on WeChat AI data center as well as NVIDIA SuperPod are posted on this [Google Doc](https://docs.google.com/spreadsheets/d/136CWc_jA_2zC4h1r-6dzD4PrOvp6aw6uCDchEyQv6sE/edit?usp=sharing).
2727

28+
29+
Scale PatrickStar to multiple machines (nodes) on SuperPod.
30+
We succeeded in training a GPT3-175B model on 32 GPUs. As far as we know, it is the first work
31+
to run GPT3 on such a small GPU cluster.
32+
Microsoft used 10,000 V100 GPUs to pretrain GPT3.
33+
Now you can finetune it, or even pretrain a model of your own, on 32 A100 GPUs — amazing!
34+
35+
![alt perf](./doc/m_node_superpod.png "performance testing result on multiple Node of SuperNode")
36+
37+
2838
We've also trained the [CLUE-GPT2](https://huggingface.co/uer/gpt2-chinese-cluecorpussmall) model with PatrickStar, the loss and accuracy curve is shown below:
2939

3040
![CLUE-GPT2](./doc/clue-gpt2-loss-n-acc.png)

doc/m_node_superpod.png

90.3 KB
Loading

doc/one_node_perf_a100.png

-2.69 KB
Loading

examples/benchmark/process_logs.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
import os
3131
import sys
32+
import numpy as np
33+
from scipy.stats import t
3234

3335

3436
def is_run_this_file(path, file, res_dict, file_dict):
@@ -48,6 +50,8 @@ def is_run_this_file(path, file, res_dict, file_dict):
4850

4951
f = open(path + "/" + file)
5052
is_run = True
53+
54+
perf_list = np.array([])
5155
if not os.path.isdir(file):
5256
fn_list = file.split(".")[1].split("_")
5357
for i in range(len(fn_list)):
@@ -62,17 +66,31 @@ def is_run_this_file(path, file, res_dict, file_dict):
6266
if "Tflops" in line and "WARM" not in line:
6367
sline = line.split()
6468
perf = float(sline[-2])
65-
if key not in res_dict:
66-
res_dict[key] = perf
67-
file_dict[key] = file
68-
else:
69-
if res_dict[key] < perf:
70-
res_dict[key] = perf
71-
file_dict[key] = file
69+
70+
perf_list = np.append(perf_list, perf)
71+
7272
is_run = False
7373
if "RuntimeError" in line:
7474
return False
7575

76+
if len(perf_list) == 0:
77+
return False
78+
79+
# calculate CI of perf_list
80+
perf_list = perf_list[1:-1]
81+
m = perf_list.mean()
82+
s = perf_list.std()
83+
dof = len(perf_list) - 1
84+
confidence = 0.95
85+
t_crit = np.abs(t.ppf((1 - confidence) / 2, dof))
86+
ic_perf = (
87+
-s * t_crit / np.sqrt(len(perf_list)),
88+
+s * t_crit / np.sqrt(len(perf_list)),
89+
)
90+
91+
res_dict[key] = (*ic_perf, m)
92+
file_dict[key] = file
93+
7694
return is_run
7795

7896

examples/model_builder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,16 @@ def model_config(model_name):
176176
SEQ_LEN = 1024
177177
NUM_LAYER = 96
178178
NUM_HEAD = 96
179+
elif model_name == "GPT_220B":
180+
HIDDEN_DIM = 12288
181+
SEQ_LEN = 1024
182+
NUM_LAYER = 120
183+
NUM_HEAD = 96
184+
elif model_name == "GPT_250B":
185+
HIDDEN_DIM = 12288
186+
SEQ_LEN = 1024
187+
NUM_LAYER = 137
188+
NUM_HEAD = 96
179189
elif model_name == "GPT_310B":
180190
HIDDEN_DIM = 16384
181191
SEQ_LEN = 1024

examples/pretrain_bert_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def test_transformer_model_helper(
237237
is_ckp=use_ckp,
238238
is_fp16=use_fp16,
239239
dist_plan=dist_plan,
240-
num_steps=5,
240+
num_steps=20,
241241
)
242242
print("*" * 20 + " LOSS " + "*" * 20)
243243
print(f"{loss_list}")

examples/run_transformers.sh

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,25 @@ export MEM_PROF=${MEM_PROF:-0}
2828
# asyn memory monitor for mem sampler
2929
export AMM=${AMM:-1}
3030
# mem saving comm
31-
export MSC=${MSC:-0}
31+
export MSC=${MSC:-1}
3232
# mem caching comm
3333
export CACHE=${CACHE:-1}
3434
# async move
3535
export ASYNC_MOVE=${ASYNC_MOVE:-0}
3636
# linear tiling comm
3737
export TILING=${TILING:-0}
38+
# hybrid adam
39+
export HYB=${HYB:-1}
40+
3841
export LOCAL_WORLD_SIZE=${LOCAL_WORLD_SIZE:-1}
3942
export CS_SEARCH=${CS_SEARCH:-0}
4043

44+
export NNODES=${NNODES:-1}
45+
export NODE_RANK=${NODE_RANK:-0}
46+
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
47+
export MASTER_PORT=${MASTER_PORT:-"12345"}
48+
export SUFFIX=${SUFFIX:-""}
49+
4150
if [[ ${TILING} == 1 ]]; then
4251
TILING_FLAG="--with_tiling_linear"
4352
else
@@ -109,13 +118,20 @@ else
109118
fi
110119

111120
let CHUNK_SIZE=${CS}*1024*1024
112-
export HYBRID_ADAM_FLAG="--use_hybrid_adam"
121+
122+
if [[ ${HYB} == 1 ]]; then
123+
export HYBRID_ADAM_FLAG="--use_hybrid_adam"
124+
else
125+
export HYBRID_ADAM_FLAG=""
126+
fi
127+
128+
113129

114130
LOG_DIR="./logs_${MODEL_NAME}"
115131
mkdir -p ${LOG_DIR}
116132

117133
GIT_VER=`git rev-parse --short=5 HEAD`
118-
LOG_FILE="log.${MODEL_NAME}_gpu_${GPU_NUM}_cs_${CS}_bs_${BS}_cpueb_${CPU_EBD}_offload_${ACT_OFFLOAD}_SP_${SP}_AMM_${AMM}_MSC_${MSC}_CACHE_${CACHE}_TILING_${TILING}_${GIT_VER}"
134+
LOG_FILE="log.${MODEL_NAME}_gpu_${GPU_NUM}_cs_${CS}_bs_${BS}_cpueb_${CPU_EBD}_hyb_${HYB}_offload_${ACT_OFFLOAD}_SP_${SP}_AMM_${AMM}_MSC_${MSC}_CACHE_${CACHE}_TILING_${TILING}_${GIT_VER}_node_${NNODES}_${SUFFIX}"
119135

120136
is_run_flag=`python ./benchmark/is_run_this_file.py --path "${LOG_DIR}" --file "${LOG_FILE}"`
121137
echo is_run_flag $is_run_flag
@@ -183,6 +199,7 @@ python -m torch.distributed.launch --nproc_per_node=1 \
183199
done
184200
else
185201
env OMP_NUM_THREADS=${TNUM} timeout -s SIGKILL 30m python -m torch.distributed.launch --nproc_per_node=${GPU_NUM} \
202+
--nnodes=${NNODES} --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
186203
pretrain_bert_demo.py \
187204
--default_chunk_size=${CHUNK_SIZE} \
188205
${cmd_opts} \

patrickstar/core/chunk_list.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
from patrickstar.core.const import ChunkType
3535
from patrickstar.core.memtracer import RuntimeMemTracer
3636
from patrickstar.profiler import profiler
37-
from patrickstar.utils import logger, get_rank, get_world_size
37+
from patrickstar.utils import logger, get_rank, get_world_size, log_dist
38+
import logging
3839
import patrickstar.utils.global_timer as global_timer
3940
from .chunk_data import Chunk
4041
from .comm import CommInfo
@@ -216,23 +217,26 @@ def prepare_device(self, target_device: torch.device, need_bytes: int):
216217
target_device.type
217218
)
218219

219-
logger.debug(
220+
log_dist(
220221
f"prepare_target: device {target_device} need_bytes {need_bytes / 1e6} MB, "
221222
f"ava_chunk_mem_size {ava_chunk_mem_size / 1e6} MB, "
222-
f"remaining_chunk_mem_size {remaining_chunk_mem_size / 1e6} MB."
223+
f"remaining_chunk_mem_size {remaining_chunk_mem_size / 1e6} MB.",
224+
level=logging.DEBUG,
223225
)
224226

225227
# TODO(jiaruifang) Situation where there is no space.
226228
# This condition is not good enough, we need to check if botn CPU and GPU
227229
# don't have enough space.
228230
if ava_chunk_mem_size < need_bytes:
229-
logger.warning(
230-
f"{target_device} has not enough space for {need_bytes} elements"
231+
log_dist(
232+
f"{target_device} has not enough space for {need_bytes} elements",
233+
level=logging.WARNING,
231234
)
232-
logger.warning(
235+
log_dist(
233236
f"{target_device} has not enough space for {need_bytes / 1e6} MB. "
234237
f"Device used Chunk Memory is {self.get_chunk_memory_used(target_device) / 1e6} MB. "
235-
f"Avaibale Chunk Memory is {ava_chunk_mem_size / 1e6} MB"
238+
f"Avaibale Chunk Memory is {ava_chunk_mem_size / 1e6} MB",
239+
level=logging.WARNING,
236240
)
237241
if self._time_profile:
238242
global_timer.my_timer.finish_profile("CHUNK_LIST_prepare_device")

patrickstar/core/client.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def __init__(self, rank: int, default_chunk_size: int, config=None):
7474
tracer_config = default_tracer_config
7575
opt_config = default_opt_config
7676

77-
self.mem_tracer = RuntimeMemTracer(self.local_rank, tracer_config)
77+
self.mem_tracer = RuntimeMemTracer(
78+
self.local_rank, tracer_config, opt_config["with_mem_saving_comm"]
79+
)
7880
self.opt_config = opt_config
7981

8082
self.chunk_eviction_strategy = LatestAccessChunkEvictionPolicy(
@@ -396,6 +398,7 @@ def _fetch_remote_chunks(
396398
# If the gpu owns the chunk (local rank), access it.
397399
# If the gpu do not own the chunk (remote chunk), allocate memory.
398400
if src_rank == rank:
401+
self.chunk_eviction_strategy.trace_access(chunk_id, compute_device)
399402
self.chunk_list.access_chunk(chunk_id, compute_device)
400403
else:
401404
self.chunk_list.try_best_allocate_payload(
@@ -447,6 +450,7 @@ def _fetch_remote_chunks(
447450

448451
# Use collective communication to achieve the most efficient communication.
449452
# However, it is memory consumping. world_size chunks on GPU simutaneously.
453+
self.chunk_eviction_strategy.trace_access(local_chunk_id, compute_device)
450454
self.chunk_list.access_chunk(local_chunk_id, compute_device)
451455
self.chunk_list[local_chunk_id].pin()
452456
allgather_payload_buff = []
@@ -493,6 +497,7 @@ def _fetch_remote_chunks(
493497
global_timer.my_timer.finish_profile("CLIENT_fetch_remote_chunks")
494498

495499
def _access_tensor_in_chunk(self, param, access_type, compute_device, chunk_id):
500+
self.chunk_eviction_strategy.trace_access(chunk_id, compute_device)
496501
self.chunk_list.access_chunk(chunk_id, compute_device)
497502
# 2. Locate the param on the chunk.
498503
tensor_id = param.ps_attr.get_tensor_id(access_type)
@@ -584,7 +589,7 @@ def access_dist(
584589
local_chunk_id = chunk_id
585590

586591
# collect the time a chunk has to be placed on compute-device
587-
self.chunk_eviction_strategy.trace_access(local_chunk_id, compute_device)
592+
# self.chunk_eviction_strategy.trace_access(local_chunk_id, compute_device)
588593

589594
ret = self._access_tensor_in_chunk(param, access_type, compute_device, chunk_id)
590595
if self._time_profile:
@@ -640,7 +645,7 @@ def access(
640645
chunk_id = self.chunk_tensor_index.get_chunk_id(param, access_type)
641646

642647
# collect the time a chunk has to be placed on compute-device
643-
self.chunk_eviction_strategy.trace_access(chunk_id, compute_device)
648+
# self.chunk_eviction_strategy.trace_access(chunk_id, compute_device)
644649

645650
if chunk_id is None:
646651
raise RuntimeError(
@@ -763,6 +768,7 @@ def release_dist(
763768
break
764769
if do_allreduce:
765770
# move the chunk_id to GPU
771+
self.chunk_eviction_strategy.trace_access(chunk_id, self.device)
766772
self.chunk_list.access_chunk(chunk_id, self.device)
767773
if self._time_profile:
768774
global_timer.my_timer.start_profile(
@@ -818,6 +824,7 @@ def release_dist(
818824
assert self.chunk_list[local_chunk_id].payload is not None
819825
input_list = []
820826
for i in chunk_id_list:
827+
self.chunk_eviction_strategy.trace_access(i, self.device)
821828
self.chunk_list.access_chunk(i, self.device)
822829
self.chunk_list[i].pin()
823830
input_list.append(self.chunk_list[i].payload)

patrickstar/core/eviction_policy.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
from queue import PriorityQueue
3333
from patrickstar.core.memtracer import Metronome
3434
from patrickstar.core.const import ChunkState
35-
from patrickstar.utils import logger
35+
from patrickstar.utils import log_dist
36+
import logging
3637

3738

3839
class ChunkEvictionPolicyBase(ABC):
@@ -112,6 +113,8 @@ def derive_eviction_list(self, id_to_chunk_map, need_bytes, target_device):
112113
chunk.get_device() is not None
113114
and chunk.get_device().type == target_device.type
114115
and chunk.get_state() != ChunkState.COMPUTE
116+
and chunk.get_state() != ChunkState.RELEASED
117+
and chunk.get_state() != ChunkState.FREE
115118
and not chunk.is_pin()
116119
):
117120
# The next moment when this chunk was accessed.
@@ -133,10 +136,12 @@ def derive_eviction_list(self, id_to_chunk_map, need_bytes, target_device):
133136

134137
# Raise error when failed to make enough room.
135138
if moved_bytes < need_bytes:
136-
logger.warning(
139+
log_dist(
137140
f"device {target_device} still needs {need_bytes / 1e6} MB, "
138141
f"but there is not enough space on it, only {moved_bytes / 1e6} MB available. "
139-
f"movable_chunk_info {movable_chunk_info}"
142+
f"movable_chunk_info {movable_chunk_info}",
143+
[0],
144+
logging.WARNING,
140145
)
141146
return moved_list
142147

0 commit comments

Comments
 (0)