11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | - |
| 14 | +import os |
15 | 15 | import random |
16 | 16 | import tempfile |
17 | 17 |
21 | 21 | import paddle.distributed as dist |
22 | 22 | from paddle import nn |
23 | 23 |
| 24 | +MODEL_STATE_DIC = "model_state" |
| 25 | +OPTIMIZER_STATE_DIC = "optimizer_state" |
| 26 | +MASTER_WEIGHT_DIC = "master_weight" |
| 27 | + |
24 | 28 |
25 | 29 | class SimpleModel(nn.Layer): |
26 | 30 | def __init__(self, hidden_size=3072, layer_num=2): |
@@ -50,44 +54,94 @@ def create_model_and_optimizer(self): |
50 | 54 | learning_rate=0.001, parameters=model.parameters() |
51 | 55 | ) |
52 | 56 | opt = dist.shard_optimizer(opt, dist.ShardingStage1("dp", self.mesh)) |
| 57 | + model, opt = paddle.amp.decorate( |
| 58 | + model, optimizers=opt, level='O2', master_grad=True |
| 59 | + ) |
53 | 60 | return model, opt |
54 | 61 |
55 | 62 | def run_training_and_save(self): |
56 | 63 | model, opt = self.create_model_and_optimizer() |
| 64 | + |
57 | 65 | for step in range(3): |
58 | | - inputs = paddle.randn([self.batch_size, self.hidden_size]) |
59 | | - labels = paddle.randn([self.batch_size, self.hidden_size]) |
| 66 | + inputs = paddle.ones( |
| 67 | + [self.batch_size, self.hidden_size], dtype='float16' |
| 68 | + ) |
| 69 | + labels = paddle.ones( |
| 70 | + [self.batch_size, self.hidden_size], dtype='float16' |
| 71 | + ) |
60 | 72 | inputs = dist.shard_tensor(inputs, self.mesh, [dist.Shard(0)]) |
61 | 73 | logits = model(inputs) |
62 | 74 | loss = paddle.nn.functional.mse_loss(logits, labels) |
63 | 75 | loss.backward() |
64 | | - opt.step() |
| 76 | + if step == 2: |
| 77 | + loss_md5 = loss._md5sum() |
| 78 | + else: |
| 79 | + opt.step() |
65 | 80 | print(f"Train step {step}, loss: {loss.item()}") |
| 81 | + |
66 | 82 | save_md5 = [p._md5sum() for p in model.parameters()] |
| 83 | + |
| 84 | + # save model and optimizer |
| 85 | + model_state_dict_path = os.path.join(self.ckpt_path, MODEL_STATE_DIC) |
| 86 | + opt_state_dict_path = os.path.join(self.ckpt_path, OPTIMIZER_STATE_DIC) |
| 87 | + master_weights_path = os.path.join(self.ckpt_path, MASTER_WEIGHT_DIC) |
67 | 88 | sharded_state_dict = model.sharded_state_dict() |
68 | | - dist.save_state_dict(sharded_state_dict, self.ckpt_path) |
69 | | - return save_md5 |
| 89 | + dist.save_state_dict(sharded_state_dict, model_state_dict_path) |
| 90 | + optimizer_states = {} |
| 91 | + master_weights = {} |
| 92 | + opt_sharded_state_dict = opt.sharded_state_dict(sharded_state_dict) |
| 93 | + for k, v in opt_sharded_state_dict.items(): |
| 94 | + if k.endswith(".w_0"): |
| 95 | + master_weights[k] = v |
| 96 | + else: |
| 97 | + optimizer_states[k] = v |
| 98 | + dist.save_state_dict(optimizer_states, opt_state_dict_path) |
| 99 | + dist.save_state_dict(master_weights, master_weights_path) |
| 100 | + return save_md5, loss_md5 |
70 | 101 |
71 | 102 | def run_loading_and_validation(self): |
72 | 103 | model, opt = self.create_model_and_optimizer() |
| 104 | + |
| 105 | + # load model and optimizer |
| 106 | + model_state_dict_path = os.path.join(self.ckpt_path, MODEL_STATE_DIC) |
| 107 | + master_weights_path = os.path.join(self.ckpt_path, MASTER_WEIGHT_DIC) |
| 108 | + opt_states_path = os.path.join(self.ckpt_path, OPTIMIZER_STATE_DIC) |
73 | 109 | sharded_state_dict = model.sharded_state_dict() |
74 | | - dist.load_state_dict(sharded_state_dict, self.ckpt_path) |
| 110 | + dist.load_state_dict(sharded_state_dict, model_state_dict_path) |
| 111 | + opt_sharded_state_dict = opt.sharded_state_dict(sharded_state_dict) |
| 112 | + opt_states = {} |
| 113 | + master_weights = {} |
| 114 | + for k, v in opt_sharded_state_dict.items(): |
| 115 | + if k.endswith(".w_0"): |
| 116 | + master_weights[k] = v |
| 117 | + else: |
| 118 | + opt_states[k] = v |
| 119 | + dist.load_state_dict(opt_states, opt_states_path) |
| 120 | + dist.load_state_dict(master_weights, master_weights_path) |
| 121 | + |
75 | 122 | load_md5 = [p._md5sum() for p in model.parameters()] |
76 | | - for step in range(3): |
77 | | - inputs = paddle.randn([self.batch_size, self.hidden_size]) |
78 | | - labels = paddle.randn([self.batch_size, self.hidden_size]) |
| 123 | + |
| 124 | + for step in range(1): |
| 125 | + inputs = paddle.ones( |
| 126 | + [self.batch_size, self.hidden_size], dtype='float16' |
| 127 | + ) |
| 128 | + labels = paddle.ones( |
| 129 | + [self.batch_size, self.hidden_size], dtype='float16' |
| 130 | + ) |
79 | 131 | inputs = dist.shard_tensor(inputs, self.mesh, [dist.Shard(0)]) |
80 | 132 | logits = model(inputs) |
81 | 133 | loss = paddle.nn.functional.mse_loss(logits, labels) |
82 | 134 | loss.backward() |
83 | 135 | opt.step() |
| 136 | + loss_md5 = loss._md5sum() |
84 | 137 | print(f"Train step {step}, loss: {loss.item()}") |
85 | | - return load_md5 |
| 138 | + return load_md5, loss_md5 |
86 | 139 |
87 | 140 | def run_test(self): |
88 | | - save_param_md5sum = self.run_training_and_save() |
89 | | - load_param_md5sum = self.run_loading_and_validation() |
| 141 | + save_param_md5sum, loss_md5 = self.run_training_and_save() |
| 142 | + load_param_md5sum, loss_md5_reload = self.run_loading_and_validation() |
90 | 143 | np.testing.assert_equal(save_param_md5sum, load_param_md5sum) |
| 144 | + np.testing.assert_equal(loss_md5, loss_md5_reload) |
91 | 145 |
92 | 146 |
93 | 147 | if __name__ == '__main__': |
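
For reference, the save/load paths in this diff split the optimizer's sharded state dict into two groups before checkpointing. Below is a minimal sketch of that split, assuming (as the test does) that Paddle's "<param_name>.w_0" key suffix marks the fp32 master weights kept by paddle.amp.decorate(level='O2'); the helper name is illustrative, not part of the test.

def split_opt_sharded_state_dict(opt_sharded_state_dict):
    """Separate AMP-O2 master weights from the remaining optimizer states."""
    optimizer_states, master_weights = {}, {}
    for key, value in opt_sharded_state_dict.items():
        # "<param>.w_0" entries are the fp32 master copies kept alongside the
        # fp16 parameters under AMP O2; moment and beta-pow accumulators and
        # the like stay with the optimizer states.
        if key.endswith(".w_0"):
            master_weights[key] = value
        else:
            optimizer_states[key] = value
    return optimizer_states, master_weights

The model's sharded state dict is written with dist.save_state_dict to the model_state subdirectory of ckpt_path, while the two optimizer groups go to optimizer_state and master_weight; reloading mirrors the same split, so run_test checks both the restored parameter md5sums and the loss md5 reproduced by the single post-restore step.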