1919from paddlerec .core .utils .envs import get_runtime_environ
2020from paddlerec .core .reader import SlotReader
2121from paddlerec .core .trainer import EngineMode
22+ from paddlerec .core .utils .util import split_files
2223
2324
2425def dataloader_by_name (readerclass ,
@@ -39,7 +40,8 @@ def dataloader_by_name(readerclass,
3940
4041 files = [str (data_path ) + "/%s" % x for x in os .listdir (data_path )]
4142 if context ["engine" ] == EngineMode .LOCAL_CLUSTER :
42- files = context ["fleet" ].split_files (files )
43+ files = split_files (files , context ["fleet" ].worker_index (),
44+ context ["fleet" ].worker_num ())
4345 print ("file_list : {}" .format (files ))
4446
4547 reader = reader_class (yaml_file )
@@ -80,7 +82,8 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
8082
8183 files = [str (data_path ) + "/%s" % x for x in os .listdir (data_path )]
8284 if context ["engine" ] == EngineMode .LOCAL_CLUSTER :
83- files = context ["fleet" ].split_files (files )
85+ files = split_files (files , context ["fleet" ].worker_index (),
86+ context ["fleet" ].worker_num ())
8487 print ("file_list: {}" .format (files ))
8588
8689 sparse = get_global_env (name + "sparse_slots" , "#" )
@@ -133,7 +136,8 @@ def slotdataloader(readerclass, train, yaml_file, context):
133136
134137 files = [str (data_path ) + "/%s" % x for x in os .listdir (data_path )]
135138 if context ["engine" ] == EngineMode .LOCAL_CLUSTER :
136- files = context ["fleet" ].split_files (files )
139+ files = split_files (files , context ["fleet" ].worker_index (),
140+ context ["fleet" ].worker_num ())
137141 print ("file_list: {}" .format (files ))
138142
139143 sparse = get_global_env ("sparse_slots" , "#" , namespace )
0 commit comments