diff --git a/scripts/evaluation/inference.py b/scripts/evaluation/inference.py
index 2beec8d..05a90a1 100644
--- a/scripts/evaluation/inference.py
+++ b/scripts/evaluation/inference.py
@@ -13,7 +13,13 @@
 from funcs import load_model_checkpoint, load_prompts, load_image_batch, get_filelist, save_videos
 from funcs import batch_ddim_sampling
 from utils.utils import instantiate_from_config
+import logging
+def setup_logging():
+    """Set up basic configuration for logging."""
+    logging.basicConfig(filename='inference_log.txt', level=logging.INFO,
+                        format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+    logging.info("Logging is configured. Starting the program...")
 
 
 def get_parser():
     parser = argparse.ArgumentParser()
@@ -61,6 +67,8 @@ def run_inference(args, gpu_num, gpu_no, **kwargs):
     ## saving folders
     os.makedirs(args.savedir, exist_ok=True)
 
+    logging.info(f"[rank:{gpu_no}] Model and configuration are loaded.")
+
     ## step 2: load data
     ## -----------------------------------------------------------------
     assert os.path.exists(args.prompt_file), "Error: prompt file NOT Found!"
@@ -70,7 +78,7 @@ def run_inference(args, gpu_num, gpu_no, **kwargs):
     samples_split = num_samples // gpu_num
     residual_tail = num_samples % gpu_num
-    print(f'[rank:{gpu_no}] {samples_split}/{num_samples} samples loaded.')
+    logging.info(f"[rank:{gpu_no}] {samples_split}/{num_samples} samples loaded.")
     indices = list(range(samples_split*gpu_no, samples_split*(gpu_no+1)))
     if gpu_no == 0 and residual_tail != 0:
         indices = indices + list(range(num_samples-residual_tail, num_samples))
 
@@ -86,13 +94,14 @@ def run_inference(args, gpu_num, gpu_no, **kwargs):
     filename_list_rank = [filename_list[i] for i in indices]
 
 
+
     ## step 3: run over samples
     ## -----------------------------------------------------------------
     start = time.time()
     n_rounds = len(prompt_list_rank) // args.bs
     n_rounds = n_rounds+1 if len(prompt_list_rank) % args.bs != 0 else n_rounds
     for idx in range(0, n_rounds):
-        print(f'[rank:{gpu_no}] batch-{idx+1} ({args.bs})x{args.n_samples} ...')
+        logging.info(f'[rank:{gpu_no}] batch-{idx+1} ({args.bs})x{args.n_samples} ...')
         idx_s = idx*args.bs
         idx_e = min(idx_s+args.bs, len(prompt_list_rank))
         batch_size = idx_e - idx_s
@@ -124,12 +133,13 @@ def run_inference(args, gpu_num, gpu_no, **kwargs):
     ## b,samples,c,t,h,w
     save_videos(batch_samples, args.savedir, filenames, fps=args.savefps)
 
-    print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds")
+    logging.info(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds")
 
 
 if __name__ == '__main__':
+    setup_logging()
     now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
-    print("@CoLVDM Inference: %s"%now)
+    logging.info(f"@CoLVDM Inference: {now}")
     parser = get_parser()
     args = parser.parse_args()
     seed_everything(args.seed)
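
A note on the patch above: setup_logging() points every rank at the same inference_log.txt, so in multi-GPU runs all processes open the file concurrently and their lines can interleave. A minimal sketch of a per-rank alternative is below; the setup_rank_logging name and the log_dir parameter are hypothetical, not part of the patch, and this is one possible design rather than what the patch implements.

import logging
import os

def setup_rank_logging(gpu_no: int, log_dir: str = "logs") -> None:
    """Hypothetical per-rank variant of setup_logging()."""
    os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[
            # One file per rank avoids interleaved writes from concurrent processes.
            logging.FileHandler(os.path.join(log_dir, f"inference_rank{gpu_no}.log")),
            # Mirror messages to the console, preserving the visibility of the old print() calls.
            logging.StreamHandler(),
        ],
    )
    logging.info("[rank:%d] Logging is configured.", gpu_no)

Swapping this in would mean calling setup_rank_logging(gpu_no) inside run_inference rather than setup_logging() in __main__, since the rank may not be known before argument parsing; how the launcher passes the rank is not shown in the patch.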