Commit 21c9e8b

[Auto-Parallel] adapt optimizer_sharded_state_dict in FlexCheckpoint (PaddlePaddle#76305)

1 parent: db57dc8
File tree: 3 files changed, +95 −21 lines


python/paddle/distributed/flex_checkpoint/dcp/sharded_weight.py

Lines changed: 2 additions & 0 deletions

@@ -61,6 +61,8 @@ def __init__(
         self.key = key
         if local_tensor.is_dist():
             self.local_tensor = local_tensor._local_value()
+            # Note: the local_tensor must keep the same name as the original tensor; otherwise the static_to_struct_mapping will be wrong.
+            self.local_tensor.name = local_tensor.name
             self.local_shape = local_tensor._local_shape
         else:
             self.local_tensor = local_tensor
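
Why the name matters: FlexCheckpoint keys its static_to_struct_mapping by the tensor's framework-internal name, so the local view extracted from a distributed tensor has to carry the original name or the lookup misses. A minimal standalone sketch of that failure mode (not the Paddle implementation; FakeTensor and build_static_to_struct_mapping are illustrative stand-ins):

# Illustrative sketch only: shows why the extracted local tensor must reuse
# the distributed tensor's name when a mapping is keyed by tensor name.
class FakeTensor:
    def __init__(self, name, value):
        self.name = name
        self.value = value

def build_static_to_struct_mapping(named_params):
    # Hypothetical helper: static (framework-internal) name -> structural name.
    return {p.name: struct_name for struct_name, p in named_params.items()}

dist_param = FakeTensor("linear_0.w_0", [1.0, 2.0])
local_view = FakeTensor("generated_tmp_123", [1.0])  # a fresh local tensor gets a new name

mapping = build_static_to_struct_mapping({"layers.0.weight": dist_param})
assert dist_param.name in mapping
assert local_view.name not in mapping   # lookup by the auto-generated name would miss

local_view.name = dist_param.name       # the fix from the diff: keep the original name
assert local_view.name in mapping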

python/paddle/optimizer/adamw.py

Lines changed: 26 additions & 8 deletions

@@ -797,11 +797,20 @@ def _generate_base_static_name(vname):
 
             # Determine tensor partitioning scheme
             if _MOMENT_NAME in optim_state_type:
-                optimizer_sharded_state_dict[unified_name] = (
-                    create_sharded_weight_with_new_local(
-                        unified_name, tensor, sharded_weight
+                if tensor.is_dist():
+                    optimizer_sharded_state_dict[unified_name] = ShardedWeight(
+                        key=unified_name,
+                        local_tensor=tensor,
+                        local_shape=tensor.shape,
+                        global_shape=tensor.shape,
+                        global_offset=sharded_weight.global_offset,
+                    )
+                else:
+                    optimizer_sharded_state_dict[unified_name] = (
+                        create_sharded_weight_with_new_local(
+                            unified_name, tensor, sharded_weight
+                        )
                     )
-                )
             else:  # Non-momentum parameters
                 optimizer_sharded_state_dict[unified_name] = ShardedWeight(
                     key=unified_name,
@@ -817,10 +826,19 @@ def _generate_base_static_name(vname):
            struct_name = static_to_struct_mapping[key]
            sharded_weight = model_sharded_state_dict[struct_name]
            unified_name = f"{struct_name}.w_0"
-            optimizer_sharded_state_dict[unified_name] = (
-                create_sharded_weight_with_new_local(
-                    unified_name, tensor, sharded_weight
+            if tensor.is_dist():
+                optimizer_sharded_state_dict[unified_name] = ShardedWeight(
+                    key=unified_name,
+                    local_tensor=tensor,
+                    local_shape=tensor.shape,
+                    global_shape=tensor.shape,
+                    global_offset=sharded_weight.global_offset,
+                )
+            else:
+                optimizer_sharded_state_dict[unified_name] = (
+                    create_sharded_weight_with_new_local(
+                        unified_name, tensor, sharded_weight
+                    )
                 )
-            )
 
    return optimizer_sharded_state_dict
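
Both hunks add the same branch: when the accumulator tensor (a moment in the first hunk, a master weight in the second) is already a distributed tensor, it is wrapped directly in a ShardedWeight whose local and global shapes are the tensor's own shape, reusing only the parameter's global offset; otherwise the existing create_sharded_weight_with_new_local path is kept. A condensed sketch of that decision, assuming ShardedWeight and create_sharded_weight_with_new_local behave as they are used in the diff (this is not a drop-in replacement for the adamw.py code):

# Condensed sketch of the branch introduced above. ShardedWeight and
# create_sharded_weight_with_new_local are assumed to be available in the
# surrounding adamw.py scope, as in the diff.
def _to_sharded_entry(unified_name, tensor, sharded_weight):
    if tensor.is_dist():
        # Accumulator is already a distributed tensor: describe it as-is and
        # inherit only the parameter's global offset.
        return ShardedWeight(
            key=unified_name,
            local_tensor=tensor,
            local_shape=tensor.shape,
            global_shape=tensor.shape,
            global_offset=sharded_weight.global_offset,
        )
    # Dense accumulator: keep the original path, which derives a new local
    # view from the parameter's sharding metadata.
    return create_sharded_weight_with_new_local(unified_name, tensor, sharded_weight)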

test/auto_parallel/semi_auto_parallel_for_flex_checkpoint.py

Lines changed: 67 additions & 13 deletions

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import os
 import random
 import tempfile
 
@@ -21,6 +21,10 @@
 import paddle.distributed as dist
 from paddle import nn
 
+MODEL_STATE_DIC = "model_state"
+OPTIMIZER_STATE_DIC = "optimizer_state"
+MASTER_WEIGHT_DIC = "master_weight"
+
 
 class SimpleModel(nn.Layer):
     def __init__(self, hidden_size=3072, layer_num=2):
@@ -50,44 +54,94 @@ def create_model_and_optimizer(self):
             learning_rate=0.001, parameters=model.parameters()
         )
         opt = dist.shard_optimizer(opt, dist.ShardingStage1("dp", self.mesh))
+        model, opt = paddle.amp.decorate(
+            model, optimizers=opt, level='O2', master_grad=True
+        )
         return model, opt
 
     def run_training_and_save(self):
         model, opt = self.create_model_and_optimizer()
+
         for step in range(3):
-            inputs = paddle.randn([self.batch_size, self.hidden_size])
-            labels = paddle.randn([self.batch_size, self.hidden_size])
+            inputs = paddle.ones(
+                [self.batch_size, self.hidden_size], dtype='float16'
+            )
+            labels = paddle.ones(
+                [self.batch_size, self.hidden_size], dtype='float16'
+            )
             inputs = dist.shard_tensor(inputs, self.mesh, [dist.Shard(0)])
             logits = model(inputs)
             loss = paddle.nn.functional.mse_loss(logits, labels)
             loss.backward()
-            opt.step()
+            if step == 2:
+                loss_md5 = loss._md5sum()
+            else:
+                opt.step()
             print(f"Train step {step}, loss: {loss.item()}")
+
         save_md5 = [p._md5sum() for p in model.parameters()]
+
+        # save model and optimizer
+        model_state_dict_path = os.path.join(self.ckpt_path, MODEL_STATE_DIC)
+        opt_state_dict_path = os.path.join(self.ckpt_path, OPTIMIZER_STATE_DIC)
+        master_weights_path = os.path.join(self.ckpt_path, MASTER_WEIGHT_DIC)
         sharded_state_dict = model.sharded_state_dict()
-        dist.save_state_dict(sharded_state_dict, self.ckpt_path)
-        return save_md5
+        dist.save_state_dict(sharded_state_dict, model_state_dict_path)
+        optimizer_states = {}
+        master_weights = {}
+        opt_sharded_state_dict = opt.sharded_state_dict(sharded_state_dict)
+        for k, v in opt_sharded_state_dict.items():
+            if k.endswith(".w_0"):
+                master_weights[k] = v
+            else:
+                optimizer_states[k] = v
+        dist.save_state_dict(optimizer_states, opt_state_dict_path)
+        dist.save_state_dict(master_weights, master_weights_path)
+        return save_md5, loss_md5
 
     def run_loading_and_validation(self):
         model, opt = self.create_model_and_optimizer()
+
+        # load model and optimizer
+        model_state_dict_path = os.path.join(self.ckpt_path, MODEL_STATE_DIC)
+        master_weights_path = os.path.join(self.ckpt_path, MASTER_WEIGHT_DIC)
+        opt_states_path = os.path.join(self.ckpt_path, OPTIMIZER_STATE_DIC)
         sharded_state_dict = model.sharded_state_dict()
-        dist.load_state_dict(sharded_state_dict, self.ckpt_path)
+        dist.load_state_dict(sharded_state_dict, model_state_dict_path)
+        opt_sharded_state_dict = opt.sharded_state_dict(sharded_state_dict)
+        opt_states = {}
+        master_weights = {}
+        for k, v in opt_sharded_state_dict.items():
+            if k.endswith(".w_0"):
+                master_weights[k] = v
+            else:
+                opt_states[k] = v
+        dist.load_state_dict(opt_states, opt_states_path)
+        dist.load_state_dict(master_weights, master_weights_path)
+
         load_md5 = [p._md5sum() for p in model.parameters()]
-        for step in range(3):
-            inputs = paddle.randn([self.batch_size, self.hidden_size])
-            labels = paddle.randn([self.batch_size, self.hidden_size])
+
+        for step in range(1):
+            inputs = paddle.ones(
+                [self.batch_size, self.hidden_size], dtype='float16'
+            )
+            labels = paddle.ones(
+                [self.batch_size, self.hidden_size], dtype='float16'
+            )
            inputs = dist.shard_tensor(inputs, self.mesh, [dist.Shard(0)])
            logits = model(inputs)
            loss = paddle.nn.functional.mse_loss(logits, labels)
            loss.backward()
            opt.step()
+            loss_md5 = loss._md5sum()
            print(f"Train step {step}, loss: {loss.item()}")
-        return load_md5
+        return load_md5, loss_md5
 
     def run_test(self):
-        save_param_md5sum = self.run_training_and_save()
-        load_param_md5sum = self.run_loading_and_validation()
+        save_param_md5sum, loss_md5 = self.run_training_and_save()
+        load_param_md5sum, loss_md5_reload = self.run_loading_and_validation()
         np.testing.assert_equal(save_param_md5sum, load_param_md5sum)
+        np.testing.assert_equal(loss_md5, loss_md5_reload)
 
 
 if __name__ == '__main__':
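
The reworked test now saves three separate checkpoints under ckpt_path (model state, optimizer states, master weights), splitting the optimizer's sharded state dict by the ".w_0" suffix that marks master weights, and compares parameter and loss md5 sums after reload. A hedged sketch of the save-side pattern, assuming a model/optimizer pair prepared as in create_model_and_optimizer above; directory names mirror the MODEL_STATE_DIC / OPTIMIZER_STATE_DIC / MASTER_WEIGHT_DIC constants:

# Hedged sketch of the save pattern used in the test above.
# dist.save_state_dict and the sharded_state_dict() methods are the APIs the
# diff itself calls; the helper name save_split_checkpoint is illustrative.
import os

import paddle.distributed as dist


def save_split_checkpoint(model, opt, ckpt_path):
    model_sd = model.sharded_state_dict()
    dist.save_state_dict(model_sd, os.path.join(ckpt_path, "model_state"))

    # Master weights are keyed "<param>.w_0"; everything else is an optimizer
    # accumulator (moments, beta pows, ...).
    opt_sd = opt.sharded_state_dict(model_sd)
    master_weights = {k: v for k, v in opt_sd.items() if k.endswith(".w_0")}
    opt_states = {k: v for k, v in opt_sd.items() if not k.endswith(".w_0")}

    dist.save_state_dict(opt_states, os.path.join(ckpt_path, "optimizer_state"))
    dist.save_state_dict(master_weights, os.path.join(ckpt_path, "master_weight"))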
