@@ -49,8 +49,7 @@ def dataloader_by_name(readerclass,
     files.sort()
 
     # for local cluster: discard some files if files cannot be divided equally between GPUs
-    if (context["device"] == "GPU"
-        ) and os.getenv("PADDLEREC_GPU_NUMS") is not None:
+    if (context["device"] == "GPU") and "PADDLEREC_GPU_NUMS" in os.environ:
         selected_gpu_nums = int(os.getenv("PADDLEREC_GPU_NUMS"))
         discard_file_nums = len(files) % selected_gpu_nums
         if (discard_file_nums != 0):
@@ -122,8 +121,7 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
     files.sort()
 
    # for local cluster: discard some files if files cannot be divided equally between GPUs
-    if (context["device"] == "GPU"
-        ) and os.getenv("PADDLEREC_GPU_NUMS") is not None:
+    if (context["device"] == "GPU") and "PADDLEREC_GPU_NUMS" in os.environ:
         selected_gpu_nums = int(os.getenv("PADDLEREC_GPU_NUMS"))
         discard_file_nums = len(files) % selected_gpu_nums
         if (discard_file_nums != 0):
@@ -176,84 +174,3 @@ def gen_batch_reader():
     if hasattr(reader, 'generate_batch_from_trainfiles'):
         return gen_batch_reader()
     return gen_reader
-
-
-def slotdataloader(readerclass, train, yaml_file, context):
-    if train == "TRAIN":
-        reader_name = "SlotReader"
-        namespace = "train.reader"
-        data_path = get_global_env("train_data_path", None, namespace)
-    else:
-        reader_name = "SlotReader"
-        namespace = "evaluate.reader"
-        data_path = get_global_env("test_data_path", None, namespace)
-
-    if data_path.startswith("paddlerec::"):
-        package_base = get_runtime_environ("PACKAGE_BASE")
-        assert package_base is not None
-        data_path = os.path.join(package_base, data_path.split("::")[1])
-
-    hidden_file_list, files = check_filelist(
-        hidden_file_list=[], data_file_list=[], train_data_path=data_path)
-    if (hidden_file_list is not None):
-        print(
-            "Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}".
-            format(hidden_file_list))
-
-    files.sort()
-
-    # for local cluster: discard some files if files cannot be divided equally between GPUs
-    if (context["device"] == "GPU"
-        ) and os.getenv("PADDLEREC_GPU_NUMS") is not None:
-        selected_gpu_nums = int(os.getenv("PADDLEREC_GPU_NUMS"))
-        discard_file_nums = len(files) % selected_gpu_nums
-        if (discard_file_nums != 0):
-            warnings.warn(
-                "Because files cannot be divided equally between GPUs,discard these files:{}".
-                format(files[-discard_file_nums:]))
-            files = files[:len(files) - discard_file_nums]
-
-    need_split_files = False
-    if context["engine"] == EngineMode.LOCAL_CLUSTER:
-        # for local cluster: split files for multi process
-        need_split_files = True
-    elif context["engine"] == EngineMode.CLUSTER and context[
-            "cluster_type"] == "K8S":
-        # for k8s mount mode, split files for every node
-        need_split_files = True
-
-    if need_split_files:
-        files = split_files(files, context["fleet"].worker_index(),
-                            context["fleet"].worker_num())
-
-    sparse = get_global_env("sparse_slots", "#", namespace)
-    if sparse == "":
-        sparse = "#"
-    dense = get_global_env("dense_slots", "#", namespace)
-    if dense == "":
-        dense = "#"
-    padding = get_global_env("padding", 0, namespace)
-    reader = SlotReader(yaml_file)
-    reader.init(sparse, dense, int(padding))
-
-    def gen_reader():
-        for file in files:
-            with open(file, 'r') as f:
-                for line in f:
-                    line = line.rstrip('\n')
-                    iter = reader.generate_sample(line)
-                    for parsed_line in iter():
-                        if parsed_line is None:
-                            continue
-                        else:
-                            values = []
-                            for pased in parsed_line:
-                                values.append(pased[1])
-                            yield values
-
-    def gen_batch_reader():
-        return reader.generate_batch_from_trainfiles(files)
-
-    if hasattr(reader, 'generate_batch_from_trainfiles'):
-        return gen_batch_reader()
-    return gen_reader