1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- #
16- # Licensed under the Apache License, Version 2.0 (the "License");
17- # you may not use this file except in compliance with the License.
18- # You may obtain a copy of the License at
19- #
20- # http://www.apache.org/licenses/LICENSE-2.0
21- #
22- # Unless required by applicable law or agreed to in writing, software
23- # distributed under the License is distributed on an "AS IS" BASIS,
24- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25- # See the License for the specific language governing permissions and
26- # limitations under the License.
27-
2815import paddle
2916import os
3017import paddle .nn as nn
@@ -68,6 +55,10 @@ def main(args):
6855 for parameter in args .opt :
6956 parameter = parameter .strip ()
7057 key , value = parameter .split ("=" )
58+ if type (config .get (key )) is int :
59+ value = int (value )
60+ if type (config .get (key )) is bool :
61+ value = (True if value .lower () == "true" else False )
7162 config [key ] = value
7263
7364 # tools.vars
@@ -79,6 +70,7 @@ def main(args):
7970 train_batch_size = config .get ("runner.train_batch_size" , None )
8071 model_save_path = config .get ("runner.model_save_path" , "model_output" )
8172 model_init_path = config .get ("runner.model_init_path" , None )
73+ use_fleet = config .get ("runner.use_fleet" , False )
8274
8375 logger .info ("**************common.configs**********" )
8476 logger .info (
@@ -102,6 +94,14 @@ def main(args):
10294 # to do : add optimizer function
10395 optimizer = dy_model_class .create_optimizer (dy_model , config )
10496
97+ # use fleet run collective
98+ if use_fleet :
99+ from paddle .distributed import fleet
100+ strategy = fleet .DistributedStrategy ()
101+ fleet .init (is_collective = True , strategy = strategy )
102+ optimizer = fleet .distributed_optimizer (optimizer )
103+ dy_model = fleet .distributed_model (dy_model )
104+
105105 logger .info ("read data" )
106106 train_dataloader = create_data_loader (config = config , place = place )
107107
@@ -186,8 +186,18 @@ def main(args):
186186 tensor_print_str + " epoch time: {:.2f} s" .format (
187187 time .time () - epoch_begin ))
188188
189- save_model (
190- dy_model , optimizer , model_save_path , epoch_id , prefix = 'rec' )
189+ if use_fleet :
190+ trainer_id = paddle .distributed .get_rank ()
191+ if trainer_id == 0 :
192+ save_model (
193+ dy_model ,
194+ optimizer ,
195+ model_save_path ,
196+ epoch_id ,
197+ prefix = 'rec' )
198+ else :
199+ save_model (
200+ dy_model , optimizer , model_save_path , epoch_id , prefix = 'rec' )
191201
192202
193203if __name__ == '__main__' :
0 commit comments