Skip to content

Commit 4e78ad7

Browse files
committed
calc valid data in data visitor op
1 parent e88f72c commit 4e78ad7

File tree

2 files changed

+65
-14
lines changed

2 files changed

+65
-14
lines changed

fedlearner/common/common.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,36 @@ def convert_time_string_to_datetime(value):
253253
return date_time
254254

255255

256+
def get_process_dates(start_date, end_date=None, fmt='%Y%m%d'):
    """Return one formatted string per calendar day from start to end.

    The range is inclusive on both sides and never extends past today:
    a missing ``end_date`` or one that lies in the future is replaced by
    today's date.

    Bounds are compared at day granularity, so a time-of-day component on
    either argument cannot drop the last day of the range or make a
    ``start_date`` that falls on today look like it is in the future.

    Args:
        start_date: ``datetime.datetime`` marking the first day.
        end_date: optional ``datetime.datetime`` marking the last day.
        fmt: ``strftime`` pattern applied to every day in the range.

    Returns:
        A list of formatted date strings in ascending order.

    Raises:
        ValueError: if ``start_date`` falls on a later day than the
            (possibly clamped) ``end_date``.
    """
    today = datetime.date.today()
    if end_date is None:
        last_day = today
    else:
        # Never schedule days that have not happened yet.
        last_day = min(end_date.date(), today)

    if start_date.date() > last_day:
        raise ValueError("start_date should be less than or equal to end_date")

    one_day = datetime.timedelta(days=1)
    process_dates = []
    current_date = start_date
    # Compare calendar days only; current_date keeps start_date's time so
    # the strftime output is unchanged for callers with a custom fmt.
    while current_date.date() <= last_day:
        process_dates.append(current_date.strftime(fmt))
        current_date += one_day
    return process_dates
269+
270+
271+
def end_with_valid_date(path: str) -> bool:
    """Tell whether the last component of *path* parses as a date.

    Trailing slashes are ignored, then the text after the final ``/`` is
    tried against each supported date layout in turn.

    Args:
        path: a ``/``-separated path.

    Returns:
        True if the last component matches one of the supported layouts,
        False otherwise (including for an empty path).
    """
    last_field = path.rstrip('/').split('/')[-1]
    # NOTE: '%Y/%m/%d' is kept from the original list, but it can never
    # match because last_field contains no '/' after the split above.
    for fmt in ('%Y-%m-%d', '%Y%m%d', '%Y/%m/%d', '%Y.%m.%d'):
        try:
            # `datetime` is the module here, so the parser lives on the
            # datetime.datetime class, not on the module itself.
            datetime.datetime.strptime(last_field, fmt)
        except ValueError:
            continue
        return True
    return False
284+
285+
256286
def set_logger():
257287
verbosity = int(os.environ.get('VERBOSITY', 1))
258288
if verbosity == 0:

fedlearner/trainer/data_visitor.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from fedlearner.common import fl_logging
3131
from fedlearner.common import trainer_master_service_pb2 as tm_pb
3232
from fedlearner.common.common import convert_time_string_to_datetime
33+
from fedlearner.common.common import end_with_valid_date
34+
from fedlearner.common.common import get_process_dates
3335
from fedlearner.data_join.data_block_visitor import DataBlockVisitor
3436
from fedlearner.trainer.utils import match_date
3537

@@ -351,22 +353,41 @@ def __init__(self,
351353
if end_date:
352354
end_date = convert_time_string_to_datetime(str(end_date))
353355
datablocks = []
354-
for dirname, _, filenames in tf.io.gfile.walk(data_path):
355-
for filename in filenames:
356-
if not fnmatch(os.path.join(dirname, filename), wildcard):
356+
if start_date and not end_with_valid_date(data_path):
357+
process_dates = get_process_dates(start_date, end_date)
358+
miss_dates = []
359+
for process_date in process_dates:
360+
dir_path = os.path.join(data_path, process_date)
361+
if not tf.io.gfile.exists(dir_path):
362+
miss_dates.append(process_date)
357363
continue
358-
subdirname = os.path.relpath(dirname, data_path)
359-
try:
360-
cur_date = datetime.strptime(subdirname, '%Y%m%d')
361-
if not match_date(cur_date, start_date, end_date):
364+
for _, _, filenames in tf.io.gfile.walk(dir_path):
365+
for filename in filenames:
366+
if not fnmatch(os.path.join(dir_path, filename), wildcard):
367+
continue
368+
block_id = os.path.join(process_date, filename)
369+
datablock = _RawDataBlock(
370+
id=block_id, data_path=os.path.join(dir_path, filename),
371+
start_time=None, end_time=None, type=tm_pb.JOINED)
372+
datablocks.append(datablock)
373+
fl_logging.info('miss_dates: [%s]', ",".join(miss_dates))
374+
else:
375+
for dirname, _, filenames in tf.io.gfile.walk(data_path):
376+
for filename in filenames:
377+
if not fnmatch(os.path.join(dirname, filename), wildcard):
362378
continue
363-
except Exception:
364-
fl_logging.info('subdirname is not the format of time')
365-
block_id = os.path.join(subdirname, filename)
366-
datablock = _RawDataBlock(
367-
id=block_id, data_path=os.path.join(dirname, filename),
368-
start_time=None, end_time=None, type=tm_pb.JOINED)
369-
datablocks.append(datablock)
379+
subdirname = os.path.relpath(dirname, data_path)
380+
try:
381+
cur_date = datetime.strptime(subdirname, '%Y%m%d')
382+
if not match_date(cur_date, start_date, end_date):
383+
continue
384+
except Exception:
385+
fl_logging.info('subdirname is not the format of time')
386+
block_id = os.path.join(subdirname, filename)
387+
datablock = _RawDataBlock(
388+
id=block_id, data_path=os.path.join(dirname, filename),
389+
start_time=None, end_time=None, type=tm_pb.JOINED)
390+
datablocks.append(datablock)
370391
datablocks.sort(key=lambda x: x.id)
371392

372393
fl_logging.info("create DataVisitor by local_data_path: %s",

0 commit comments

Comments
 (0)