Skip to content

Commit a764c56

Browse files
authored
Save recognition inference logs to file (#12542)
* Save recognition inference logs to file * Formatted code with black
1 parent 3286d29 commit a764c56

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

tools/infer/predict_rec.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737

3838

3939
class TextRecognizer(object):
40-
def __init__(self, args):
40+
def __init__(self, args, logger=None):
41+
if logger is None:
42+
logger = get_logger()
4143
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
4244
self.rec_batch_num = args.rec_batch_num
4345
self.rec_algorithm = args.rec_algorithm
@@ -157,7 +159,7 @@ def __init__(self, args):
157159
model_precision=args.precision,
158160
batch_size=args.rec_batch_num,
159161
data_shape="dynamic",
160-
save_path=None, # args.save_log_path,
162+
save_path=None, # not used if logger is not None
161163
inference_config=self.config,
162164
pids=pid,
163165
process_name=None,
@@ -701,14 +703,25 @@ def __call__(self, img_list):
701703

702704
def main(args):
703705
image_file_list = get_image_file_list(args.image_dir)
704-
text_recognizer = TextRecognizer(args)
705706
valid_image_file_list = []
706707
img_list = []
707708

709+
# logger
710+
log_file = args.save_log_path
711+
if os.path.is_dir(args.save_log_path) or (
712+
not os.path.exists(args.save_log_path) and args.save_log_path.endswith("/")
713+
):
714+
log_file = os.path.join(log_file, "benchmark_recognition.log")
715+
logger = get_logger(log_file=log_file)
716+
717+
# create text recognizer
718+
text_recognizer = TextRecognizer(args)
719+
708720
logger.info(
709721
"In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
710722
"if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320"
711723
)
724+
712725
# warmup 2 times
713726
if args.warmup:
714727
img = np.random.uniform(0, 255, [48, 320, 3]).astype(np.uint8)

0 commit comments

Comments
 (0)