diff --git a/inference_video.py b/inference_video.py
index 1fa8cd4..be07cb6 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -30,6 +30,8 @@
 from torch.utils.data import DataLoader
 from torchvision import transforms as T
 from torchvision.transforms.functional import to_pil_image
+from multiprocessing import Process, Pipe
+from queue import Queue
 from threading import Thread
 from tqdm import tqdm
 from PIL import Image
@@ -79,15 +81,36 @@ class VideoWriter:
     def __init__(self, path, frame_rate, width, height):
-        self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
+        output_p, input_p = Pipe()
+        self.worker = Process(target=self.VideoWriterWorker, args=(path, frame_rate, width, height, (output_p, input_p)))
+        self.worker.start()
+        output_p.close()
+        self.input_p = input_p
 
     def add_batch(self, frames):
-        frames = frames.mul(255).byte()
-        frames = frames.cpu().permute(0, 2, 3, 1).numpy()
-        for i in range(frames.shape[0]):
-            frame = frames[i]
-            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-            self.out.write(frame)
+        frames = frames.mul(255).byte().permute(0, 2, 3, 1)
+        self.input_p.send(frames.cpu())
+
+    def close(self):
+        self.input_p.send(0)
+        self.worker.join()
+
+    @staticmethod
+    def VideoWriterWorker(path, frame_rate, width, height, pipe):
+        output_p, input_p = pipe
+        input_p.close()
+        out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
+        while True:
+            read_buffer = output_p.recv()
+            # an integer message is the shutdown sentinel sent by close()
+            if isinstance(read_buffer, int):
+                break
+            frames = read_buffer.numpy()
+            for i in range(frames.shape[0]):
+                frame = frames[i]
+                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+                out.write(frame)
+        out.release()
 
 
 class ImageSequenceWriter:
@@ -110,106 +133,136 @@ def _add_batch(self, frames, index):
 
 # --------------- Main ---------------
 
+if __name__ == '__main__':
+    device = torch.device(args.device)
 
-device = torch.device(args.device)
-
-# Load model
-if args.model_type == 'mattingbase':
-    model = MattingBase(args.model_backbone)
-if args.model_type == 'mattingrefine':
-    model = MattingRefine(
-        args.model_backbone,
-        args.model_backbone_scale,
-        args.model_refine_mode,
-        args.model_refine_sample_pixels,
-        args.model_refine_threshold,
-        args.model_refine_kernel_size)
-
-model = model.to(device).eval()
-model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
-
-
-# Load video and background
-vid = VideoDataset(args.video_src)
-bgr = [Image.open(args.video_bgr).convert('RGB')]
-dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
-    A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
-    HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
-    A.PairApply(T.ToTensor())
-]))
-if args.video_target_bgr:
-    dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])
-
-# Create output directory
-if os.path.exists(args.output_dir):
-    if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
-        shutil.rmtree(args.output_dir)
-    else:
-        exit()
-os.makedirs(args.output_dir)
-
-
-# Prepare writers
-if args.output_format == 'video':
-    h = args.video_resize[1] if args.video_resize is not None else vid.height
-    w = args.video_resize[0] if args.video_resize is not None else vid.width
-    if 'com' in args.output_types:
-        com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
-    if 'pha' in args.output_types:
-        pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
-    if 'fgr' in args.output_types:
-        fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
-    if 'err' in args.output_types:
-        err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
-    if 'ref' in args.output_types:
-        ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
-else:
-    if 'com' in args.output_types:
-        com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
-    if 'pha' in args.output_types:
-        pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
-    if 'fgr' in args.output_types:
-        fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
-    if 'err' in args.output_types:
-        err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
-    if 'ref' in args.output_types:
-        ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')
-
-
-# Conversion loop
-with torch.no_grad():
-    for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
-        if args.video_target_bgr:
-            (src, bgr), tgt_bgr = input_batch
-            tgt_bgr = tgt_bgr.to(device, non_blocking=True)
+    # Load model
+    if args.model_type == 'mattingbase':
+        model = MattingBase(args.model_backbone)
+    if args.model_type == 'mattingrefine':
+        model = MattingRefine(
+            args.model_backbone,
+            args.model_backbone_scale,
+            args.model_refine_mode,
+            args.model_refine_sample_pixels,
+            args.model_refine_threshold,
+            args.model_refine_kernel_size)
+
+    model = model.to(device).eval()
+    model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
+
+
+    # Load video and background
+    vid = VideoDataset(args.video_src)
+    bgr = Image.open(args.video_bgr).convert('RGB')
+
+    transforms = T.Compose([
+        T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity(),
+        T.ToTensor()
+    ])
+
+    bgr = transforms(bgr)
+    dataset = VideoDataset(args.video_src, transforms=transforms)
+
+    if args.video_target_bgr:
+        dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])
+
+    # Create output directory
+    if os.path.exists(args.output_dir):
+        if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
+            shutil.rmtree(args.output_dir)
         else:
-            src, bgr = input_batch
-            tgt_bgr = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
-        src = src.to(device, non_blocking=True)
-        bgr = bgr.to(device, non_blocking=True)
-
-        if args.model_type == 'mattingbase':
-            pha, fgr, err, _ = model(src, bgr)
-        elif args.model_type == 'mattingrefine':
-            pha, fgr, _, _, err, ref = model(src, bgr)
-        elif args.model_type == 'mattingbm':
-            pha, fgr = model(src, bgr)
+            exit()
+    os.makedirs(args.output_dir)
+
+    # Prepare writers
+    if args.output_format == 'video':
+        h = args.video_resize[1] if args.video_resize is not None else vid.height
+        w = args.video_resize[0] if args.video_resize is not None else vid.width
         if 'com' in args.output_types:
-            if args.output_format == 'video':
-                # Output composite with green background
-                com = fgr * pha + tgt_bgr * (1 - pha)
-                com_writer.add_batch(com)
-            else:
-                # Output composite as rgba png images
-                com = torch.cat([fgr * pha.ne(0), pha], dim=1)
-                com_writer.add_batch(com)
+            com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
         if 'pha' in args.output_types:
-            pha_writer.add_batch(pha)
+            pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
         if 'fgr' in args.output_types:
-            fgr_writer.add_batch(fgr)
+            fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
         if 'err' in args.output_types:
-            err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
+            err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
         if 'ref' in args.output_types:
-            ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
+            ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
+    else:
+        if 'com' in args.output_types:
+            com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
+        if 'pha' in args.output_types:
+            pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
+        if 'fgr' in args.output_types:
+            fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
+        if 'err' in args.output_types:
+            err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
+        if 'ref' in args.output_types:
+            ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')
+
+
+    # Conversion loop
+    with torch.no_grad():
+        queue = Queue(1)
+        def load_worker():
+            tgt_bgr = torch.tensor([120/255, 255/255, 155/255]).view(1, 3, 1, 1)
+            for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
+                if args.video_target_bgr:
+                    src, tgt_bgr = input_batch
+                else:
+                    src = input_batch
+                queue.put((src, tgt_bgr))
+            queue.put(None)
+        loader = Thread(target=load_worker)
+        loader.start()
+        # move background to device
+        bgr = (bgr[None]).to(device, non_blocking=False)
+        while True:
+            task = queue.get()
+            if task is None:
+                break
+            src, tgt_bgr = task
+            # move frame to device
+            src = src.to(device, non_blocking=True)
+            tgt_bgr = tgt_bgr.to(device, non_blocking=True)
+
+            if args.model_type == 'mattingbase':
+                pha, fgr, err, _ = model(src, bgr)
+            elif args.model_type == 'mattingrefine':
+                pha, fgr, _, _, err, ref = model(src, bgr)
+            elif args.model_type == 'mattingbm':
+                pha, fgr = model(src, bgr)
+
+            if 'com' in args.output_types:
+                if args.output_format == 'video':
+                    # Output composite with green background
+                    com = fgr * pha + tgt_bgr * (1 - pha)
+                    com_writer.add_batch(com)
+                else:
+                    # Output composite as rgba png images
+                    com = torch.cat([fgr * pha.ne(0), pha], dim=1)
+                    com_writer.add_batch(com)
+            if 'pha' in args.output_types:
+                pha_writer.add_batch(pha)
+            if 'fgr' in args.output_types:
+                fgr_writer.add_batch(fgr)
+            if 'err' in args.output_types:
+                err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
+            if 'ref' in args.output_types:
+                ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
+        # wait for the loader thread, then shut down the writer processes
+        loader.join()
+        if args.output_format == 'video':
+            if 'com' in args.output_types:
+                com_writer.close()
+            if 'pha' in args.output_types:
+                pha_writer.close()
+            if 'fgr' in args.output_types:
+                fgr_writer.close()
+            if 'err' in args.output_types:
+                err_writer.close()
+            if 'ref' in args.output_types:
+                ref_writer.close()
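
For reference, a minimal standalone sketch of the writer pattern this patch introduces: the `cv2.VideoWriter` lives in a child process that receives uint8 NHWC frame batches over a `multiprocessing.Pipe` and shuts down when it receives an integer sentinel. The names (`writer_worker`), file name, and frame size below are illustrative, not part of the patch; the patch ships both pipe ends to the child and closes the unused ones, which this sketch simplifies with a one-way pipe.

```python
import cv2
import torch
from multiprocessing import Process, Pipe


def writer_worker(path, frame_rate, width, height, output_p):
    # Runs in the child process, which owns the cv2.VideoWriter exclusively.
    out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
    while True:
        msg = output_p.recv()
        if isinstance(msg, int):  # integer message = shutdown sentinel
            break
        frames = msg.numpy()  # NHWC uint8 tensor sent by the parent
        for frame in frames:
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    out.release()


if __name__ == '__main__':
    # Pipe(duplex=False) returns (receive end, send end).
    output_p, input_p = Pipe(duplex=False)
    worker = Process(target=writer_worker, args=('out.mp4', 30, 64, 64, output_p))
    worker.start()
    output_p.close()  # parent keeps only the send end

    # Send two batches of random NHWC uint8 frames, then the sentinel.
    for _ in range(2):
        input_p.send(torch.randint(0, 256, (4, 64, 64, 3), dtype=torch.uint8))
    input_p.send(0)
    worker.join()
```

This keeps the (slow, GIL-free) encoding work off the inference process; the parent only pickles a CPU tensor into the pipe and moves on.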
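The conversion loop uses a second, simpler pattern: a bounded `queue.Queue(1)` lets a loader thread stage the next batch while the main thread runs the model, with `None` as the end-of-stream sentinel. A sketch under the same caveats (`run` and `process` are illustrative names):

```python
from queue import Queue
from threading import Thread


def run(batches, process):
    queue = Queue(maxsize=1)  # bound the queue so the loader stays one step ahead

    def load_worker():
        for batch in batches:
            queue.put(batch)  # blocks while the queue is full
        queue.put(None)       # sentinel: no more work

    loader = Thread(target=load_worker)
    loader.start()
    while True:
        task = queue.get()
        if task is None:
            break
        process(task)
    loader.join()


if __name__ == '__main__':
    run(range(5), lambda b: print('processed batch', b))
```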