import os
import subprocess
import sys

import numpy as np
from pydantic import PositiveInt

import data_juicer
from data_juicer.ops.load import load_ops
from data_juicer.utils.cache_utils import DATA_JUICER_ASSETS_CACHE
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.lazy_loader import LazyLoader

from ..base_op import OPERATORS, Mapper
from ..op_fusion import LOADED_VIDEOS

OP_NAME = "video_camera_pose_mapper"

# heavy dependencies are resolved lazily on first use
cv2 = LazyLoader("cv2", "opencv-python")
torch = LazyLoader("torch")


@OPERATORS.register_module(OP_NAME)
@LOADED_VIDEOS.register_module(OP_NAME)
class VideoCameraPoseMapper(Mapper):
    """Mapper to extract camera poses from videos by leveraging MegaSaM and MoGe-2."""

    _accelerator = "cuda"

    def __init__(
        self,
        moge_model_path: str = "Ruicheng/moge-2-vitl",
        frame_num: PositiveInt = 3,
        duration: float = 0,
        tag_field_name: str = MetaKeys.video_camera_pose_tags,
        frame_dir: str = DATA_JUICER_ASSETS_CACHE,
        if_output_moge_info: bool = False,
        moge_output_info_dir: str = DATA_JUICER_ASSETS_CACHE,
        if_save_info: bool = True,
        output_info_dir: str = DATA_JUICER_ASSETS_CACHE,
        max_frames: int = 1000,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param moge_model_path: The path to the MoGe-2 model.
        :param frame_num: The number of frames to be extracted uniformly from
            the video. If it's 1, only the middle frame will be extracted. If
            it's 2, only the first and the last frames will be extracted. If
            it's larger than 2, in addition to the first and the last frames,
            other frames will be extracted uniformly within the video duration.
            If "duration" > 0, frame_num is the number of frames per segment.
        :param duration: The duration of each segment in seconds.
            If 0, frames are extracted from the entire video.
            If duration > 0, the video is segmented into multiple segments
            based on duration, and frames are extracted from each segment.
        :param tag_field_name: The field name to store the tags. It's
            "video_camera_pose_tags" by default.
        :param frame_dir: Output directory to save extracted frames.
        :param if_output_moge_info: Whether to save the results from MoGe-2
            to a JSON file.
        :param moge_output_info_dir: Output directory for saving the MoGe-2
            camera parameters.
        :param if_save_info: Whether to save the results to an npz file.
        :param output_info_dir: Output directory for saving the results.
        :param max_frames: Maximum number of frames to keep in the results.
        :param args: extra args
        :param kwargs: extra args
        """
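        # e.g., frame_num=3 with duration=0 extracts the first, middle, and last
        # frames of the whole video; frame_num=3 with duration=2.0 extracts 3
        # frames from every 2-second segment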

        super().__init__(*args, **kwargs)

        # arguments for the fused MoGe-2 mapper that extracts frames and produces
        # per-frame depth, mask, and intrinsics
        self.video_camera_calibration_static_moge_mapper_args = {
            "model_path": moge_model_path,
            "frame_num": frame_num,
            "duration": duration,
            "frame_dir": frame_dir,
            "if_output_points_info": False,
            "if_output_depth_info": True,
            "if_output_mask_info": True,
            "if_output_info": if_output_moge_info,
            "output_info_dir": moge_output_info_dir,
        }
        self.fused_ops = load_ops(
            [{"video_camera_calibration_static_moge_mapper": self.video_camera_calibration_static_moge_mapper_args}]
        )

        # fetch the MegaSaM repo (and its submodules) into the assets cache on first use
        megasam_repo_path = os.path.join(DATA_JUICER_ASSETS_CACHE, "mega-sam")
        if not os.path.exists(megasam_repo_path):
            subprocess.run(["git", "clone", "https://github.com/mega-sam/mega-sam.git", megasam_repo_path], check=True)
            subprocess.run(
                ["git", "submodule", "update", "--init", "--recursive"],
                cwd=os.path.join(megasam_repo_path, "base"),
                check=True,
            )

        # patch deprecated ATen calls (.type() -> .scalar_type()) in the CUDA/C++
        # extension sources so they build against recent PyTorch versions
        src_files_to_patch = [
            os.path.join(megasam_repo_path, "base", "src", "altcorr_kernel.cu"),
            os.path.join(megasam_repo_path, "base", "src", "correlation_kernels.cu"),
            os.path.join(megasam_repo_path, "base", "src", "droid_kernels.cu"),
            os.path.join(megasam_repo_path, "base", "thirdparty", "lietorch", "lietorch", "src", "lietorch_gpu.cu"),
            os.path.join(megasam_repo_path, "base", "thirdparty", "lietorch", "lietorch", "src", "lietorch_cpu.cpp"),
        ]
        for src_file in src_files_to_patch:
            with open(src_file, "r") as f:
                file_content = f.read()
            with open(src_file, "w") as f:
                f.write(file_content.replace(".type()", ".scalar_type()"))

        try:
            import droid_backends
            import lietorch
        except ImportError:
            # build and install the patched DROID/lietorch CUDA extensions, then retry
            subprocess.run(
                ["python", "setup.py", "install"], cwd=os.path.join(megasam_repo_path, "base"), check=True
            )
            import droid_backends
            import lietorch

        self.droid_backends = droid_backends
        self.lietorch = lietorch

        try:
            import torch_scatter
        except ImportError:
            # Please refer to https://github.com/rusty1s/pytorch_scatter to locate the
            # installation link that is compatible with your PyTorch and CUDA versions.
            torch_version = "2.8.0"
            cuda_version = "cu128"
            subprocess.run(
                [
                    "pip",
                    "install",
                    "torch-scatter",
                    "-f",
                    f"https://data.pyg.org/whl/torch-{torch_version}+{cuda_version}.html",
                ],
                check=True,
            )
            import torch_scatter

        self.torch_scatter = torch_scatter

        # make the MegaSaM DROID-SLAM implementation importable
        sys.path.append(os.path.join(megasam_repo_path, "base", "droid_slam"))
        from droid import Droid
        from lietorch import SE3

        self.SE3 = SE3
        self.Droid = Droid

        self.tag_field_name = tag_field_name
        self.if_save_info = if_save_info
        self.output_info_dir = output_info_dir
        self.max_frames = max_frames
        self.frame_dir = frame_dir

    def image_stream(self, frames_path, depth_list, intrinsics_list):
        """Yield resized frames, depths, masks, and rescaled intrinsics for DROID."""
        for t, (image_path, depth, intrinsics) in enumerate(zip(frames_path, depth_list, intrinsics_list)):
            image = cv2.imread(image_path)
            h0, w0, _ = image.shape
            # resize so the frame covers roughly 384 * 512 pixels, keeping the aspect ratio
            h1 = int(h0 * np.sqrt((384 * 512) / (h0 * w0)))
            w1 = int(w0 * np.sqrt((384 * 512) / (h0 * w0)))

            image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_AREA)
            # crop height and width down to multiples of 8, as DROID requires
            image = image[: h1 - h1 % 8, : w1 - w1 % 8]
            image = torch.as_tensor(image).permute(2, 0, 1)
            image = image[None]

            depth = torch.as_tensor(depth)
            depth = torch.nn.functional.interpolate(depth[None, None], (h1, w1), mode="nearest-exact").squeeze()
            depth = depth[: h1 - h1 % 8, : w1 - w1 % 8]

            mask = torch.ones_like(depth)

            # flatten the 3x3 intrinsic matrix to [fx, fy, cx, cy] and rescale to the new size
            intrinsics = torch.as_tensor([intrinsics[0][0], intrinsics[1][1], intrinsics[0][2], intrinsics[1][2]])
            intrinsics[0::2] *= w1 / w0
            intrinsics[1::2] *= h1 / h0

            yield t, image, depth, intrinsics, mask

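    # Worked example for the rescaling in image_stream, assuming a hypothetical
    # 720x1280 input frame:
    #   scale = sqrt((384 * 512) / (720 * 1280)) ~= 0.4619
    #   h1 = int(720 * scale) = 332 -> cropped to 328 (nearest lower multiple of 8)
    #   w1 = int(1280 * scale) = 591 -> cropped to 584
    #   fx and cx are then scaled by w1 / w0 = 591 / 1280, fy and cy by h1 / h0.
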
    def process_single(self, sample=None, rank=None):
        # skip the sample if the tags have been generated already
        if self.tag_field_name in sample[Fields.meta]:
            return sample

        # there is no video in this sample
        if self.video_key not in sample or not sample[self.video_key]:
            return []

        # run the fused MoGe-2 mapper to extract frames and estimate per-frame
        # depth maps and intrinsics
        ds_list = [{"videos": sample[self.video_key]}]

        dataset = data_juicer.core.data.NestedDataset.from_list(ds_list)
        if Fields.meta not in dataset.features:
            dataset = dataset.add_column(name=Fields.meta, column=[{}] * dataset.num_rows)
        dataset = dataset.map(self.fused_ops[0].process, num_proc=1, with_rank=True)
        res_list = dataset.to_list()

        # collect the extracted frames, sorted by file name
        temp_frame_name = os.path.splitext(os.path.basename(sample[self.video_key][0]))[0]
        frames_root = os.path.join(self.frame_dir, temp_frame_name)
        frame_names = os.listdir(frames_root)
        frames_path = sorted([os.path.join(frames_root, frame_name) for frame_name in frame_names])

        depth_list = res_list[0][Fields.meta][MetaKeys.static_camera_calibration_moge_tags]["depth_list"]
        intrinsics_list = res_list[0][Fields.meta][MetaKeys.static_camera_calibration_moge_tags]["intrinsics_list"]

        valid_image_list = []
        valid_depth_list = []
        valid_intrinsics_list = []
        valid_mask_list = []

        # feed frames to DROID-SLAM for camera tracking
        for t, image, depth, intrinsics, mask in self.image_stream(frames_path, depth_list, intrinsics_list):
            valid_image_list.append(image[0])
            valid_depth_list.append(depth)
            valid_mask_list.append(mask)
            valid_intrinsics_list.append(intrinsics)

            # lazily initialize DROID once the working image size is known
            if t == 0:
                args = DroidArgs(image_size=[image.shape[2], image.shape[3]])
                droid = self.Droid(args)

            droid.track(t, image, depth, intrinsics=intrinsics, mask=mask)

            droid.track_final(t, image, depth, intrinsics=intrinsics, mask=mask)

        # final global bundle adjustment over the whole trajectory
        traj_est, depth_est, motion_prob = droid.terminate(
            self.image_stream(frames_path, depth_list, intrinsics_list),
            _opt_intr=True,
            full_ba=True,
            scene_name=temp_frame_name,
        )

        t = traj_est.shape[0]
        images = np.array(valid_image_list[:t])
        disps = 1.0 / (np.array(valid_depth_list[:t]) + 1e-6)

        poses = traj_est
        intrinsics = droid.video.intrinsics[:t].cpu().numpy()

        # DROID keeps intrinsics at 1/8 feature-map resolution; scale back to pixels
        intrinsics = intrinsics[0] * 8.0
        poses_th = torch.as_tensor(poses, device="cpu")
        # invert the estimated world-to-camera SE3 poses into camera-to-world matrices
        cam_c2w = self.SE3(poses_th).inv().matrix().numpy()

        K = np.eye(3)
        K[0, 0] = intrinsics[0]
        K[1, 1] = intrinsics[1]
        K[0, 2] = intrinsics[2]
        K[1, 2] = intrinsics[3]
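        # resulting pinhole intrinsic matrix (pixel units of the resized frames):
        #   K = [[fx, 0., cx],
        #        [0., fy, cy],
        #        [0., 0., 1.]]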

        max_frames = min(self.max_frames, images.shape[0])

        # BGR -> RGB (reverse the channel axis) and NCHW -> NHWC
        return_images = np.uint8(images[:max_frames, ::-1, ...].transpose(0, 2, 3, 1))
        return_depths = np.float32(1.0 / disps[:max_frames, ...])
        return_cam_c2w = cam_c2w[:max_frames]

        if self.if_save_info:
            os.makedirs(self.output_info_dir, exist_ok=True)

            np.savez(
                os.path.join(self.output_info_dir, f"{temp_frame_name}_droid.npz"),
                images=return_images,
                depths=return_depths,
                intrinsic=K,
                cam_c2w=return_cam_c2w,
            )

        sample[Fields.meta][self.tag_field_name] = {
            "frames_folder": frames_root,
            "frame_names": frame_names,
            "images": return_images,
            "depths": return_depths,
            "intrinsic": K,
            "cam_c2w": return_cam_c2w,
        }

        return sample


class DroidArgs:
    """Configuration mirroring the arguments expected by MegaSaM's DROID-SLAM."""

    def __init__(self, image_size):
        self.weights = os.path.join(DATA_JUICER_ASSETS_CACHE, "mega-sam", "checkpoints", "megasam_final.pth")
        self.disable_vis = True
        self.image_size = image_size
        self.buffer = 1024
        self.stereo = False
        self.filter_thresh = 2.0

        # frontend (keyframe tracking) parameters
        self.warmup = 8
        self.beta = 0.3
        self.frontend_nms = 1
        self.keyframe_thresh = 2.0
        self.frontend_window = 25
        self.frontend_thresh = 12.0
        self.frontend_radius = 2

        # backend (global bundle adjustment) parameters
        self.upsample = False
        self.backend_thresh = 16.0
        self.backend_radius = 2
        self.backend_nms = 3
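

# A minimal usage sketch, not part of the operator: "demo.mp4" is a hypothetical
# local video path, and a CUDA device, network access (for the first-run MegaSaM
# setup and MoGe-2 download), and the MegaSaM checkpoint are assumed available.
if __name__ == "__main__":
    op = VideoCameraPoseMapper(frame_num=16)
    sample = {op.video_key: ["demo.mp4"], Fields.meta: {}}
    result = op.process_single(sample)
    tags = result[Fields.meta][op.tag_field_name]
    print(tags["intrinsic"])  # 3x3 pinhole matrix K
    print(tags["cam_c2w"].shape)  # (num_frames, 4, 4) camera-to-world poses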