Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)

# Load and preprocess example images (replace with your own image paths)
image_names = ["path/to/imageA.png", "path/to/imageB.png", "path/to/imageC.png"]
images = load_and_preprocess_images(image_names).to(device)
images = load_and_preprocess_images(image_names)[0].to(device)

with torch.no_grad():
with torch.cuda.amp.autocast(dtype=dtype):
Expand Down
12 changes: 10 additions & 2 deletions demo_colmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,15 @@ def demo_fn(args):
vggt_fixed_resolution = 518
img_load_resolution = 1024

images, original_coords = load_and_preprocess_images_square(image_path_list, img_load_resolution)
images, alpha_masks, original_coords = load_and_preprocess_images_square(image_path_list, img_load_resolution)
images = images.to(device)
alpha_masks = F.interpolate(
alpha_masks,
size=(vggt_fixed_resolution, vggt_fixed_resolution),
mode="bilinear",
align_corners=False,
).squeeze(1)
alpha_masks = alpha_masks.cpu().numpy()
original_coords = original_coords.to(device)
print(f"Loaded {len(images)} images from {image_dir}")

Expand Down Expand Up @@ -209,7 +216,8 @@ def demo_fn(args):
# (S, H, W, 3), with x, y coordinates and frame indices
points_xyf = create_pixel_coordinate_grid(num_frames, height, width)

conf_mask = depth_conf >= conf_thres_value
# mask out points with low confidence and transparent pixels
conf_mask = (depth_conf >= conf_thres_value) & (alpha_masks > 0.5)
# at most writing 100000 3d points to colmap reconstruction object
conf_mask = randomly_limit_trues(conf_mask, max_points_for_colmap)

Expand Down
6 changes: 5 additions & 1 deletion demo_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def run_model(target_dir, model) -> dict:
if len(image_names) == 0:
raise ValueError("No images found. Check your upload.")

images = load_and_preprocess_images(image_names).to(device)
images, alpha_masks = load_and_preprocess_images(image_names)
images = images.to(device)
print(f"Preprocessed images shape: {images.shape}")

# Run inference
Expand Down Expand Up @@ -92,6 +93,8 @@ def run_model(target_dir, model) -> dict:
world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
predictions["world_points_from_depth"] = world_points

predictions["alpha_masks"] = alpha_masks.cpu().numpy()

# Clean up
torch.cuda.empty_cache()
return predictions
Expand Down Expand Up @@ -299,6 +302,7 @@ def update_visualization(
"extrinsic",
"intrinsic",
"world_points_from_depth",
"alpha_masks",
]

loaded = np.load(predictions_path)
Expand Down
15 changes: 14 additions & 1 deletion demo_viser.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def viser_wrapper(
"depth_conf": (S, H, W),
"extrinsic": (S, 3, 4),
"intrinsic": (S, 3, 3),
"alpha_masks": (S, H, W),
}
port (int): Port number for the viser server.
init_conf_threshold (float): Initial percentage of low-confidence points to filter out.
Expand All @@ -77,6 +78,8 @@ def viser_wrapper(
extrinsics_cam = pred_dict["extrinsic"] # (S, 3, 4)
intrinsics_cam = pred_dict["intrinsic"] # (S, 3, 3)

alpha_masks = pred_dict["alpha_masks"] # (S, H, W)

# Compute world points from depth if not using the precomputed point map
if not use_point_map:
world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam)
Expand All @@ -99,6 +102,12 @@ def viser_wrapper(
colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8)
conf_flat = conf.reshape(-1)

alpha_filter = alpha_masks.flatten() > .5 # visibility threshold at 50%

points = points[alpha_filter]
colors_flat = colors_flat[alpha_filter]
conf_flat = conf_flat[alpha_filter]

cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam) # shape (S, 4, 4) typically
# For convenience, we store only (3,4) portion
cam_to_world = cam_to_world_mat[:, :3, :]
Expand Down Expand Up @@ -356,7 +365,8 @@ def main():
image_names = glob.glob(os.path.join(args.image_folder, "*"))
print(f"Found {len(image_names)} images")

images = load_and_preprocess_images(image_names).to(device)
images, alpha_masks = load_and_preprocess_images(image_names)
images = images.to(device)
print(f"Preprocessed images shape: {images.shape}")

print("Running inference...")
Expand All @@ -376,6 +386,9 @@ def main():
if isinstance(predictions[key], torch.Tensor):
predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension and convert to numpy

# Process alpha masks separately from predictions due to shape mismatch (no batch dim)
predictions["alpha_masks"] = alpha_masks.cpu().numpy().squeeze(1)

if args.use_point_map:
print("Visualizing 3D points from point map")
else:
Expand Down
138 changes: 101 additions & 37 deletions vggt/utils/load_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,38 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Literal

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms as TF
import numpy as np


def load_and_preprocess_images_square(image_path_list, target_size=1024):
_IMAGENET_MEAN: tuple[float, float, float] = (0.485, 0.456, 0.406)


def load_and_preprocess_images_square(
image_path_list: list[str],
target_size: int = 1024,
background_color: Literal["white", "black", "imagenet_mean"]="white",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Load and preprocess images by center padding to square and resizing to target size.
Also returns the position information of original pixels after transformation.

Args:
image_path_list (list): List of paths to image files
target_size (int, optional): Target size for both width and height. Defaults to 1024.
background_color (str, optional): Background color for transparent images to blend onto,
either "white", "black" or "imagenet_mean". Defaults to "white".

Returns:
tuple: (
torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size),
torch.Tensor: Array of shape (N, 5) containing [x1, y1, x2, y2, width, height] for each image
torch.Tensor: Batched tensor of alpha masks with shape (N, 1, H, W),
torch.Tensor: Array of shape (N, 6) containing [x1, y1, x2, y2, width, height] for each image.
)

Raises:
Expand All @@ -32,22 +44,21 @@ def load_and_preprocess_images_square(image_path_list, target_size=1024):
if len(image_path_list) == 0:
raise ValueError("At least 1 image is required")

bg_map = {
"black": (0.0, 0.0, 0.0),
"white": (1.0, 1.0, 1.0),
"imagenet_mean": _IMAGENET_MEAN,
}

images = []
masks = []
original_coords = [] # Renamed from position_info to be more descriptive
to_tensor = TF.ToTensor()

for image_path in image_path_list:
# Open image
img = Image.open(image_path)

# If there's an alpha channel, blend onto white background
if img.mode == "RGBA":
background = Image.new("RGBA", img.size, (255, 255, 255, 255))
img = Image.alpha_composite(background, img)

# Convert to RGB
img = img.convert("RGB")

# Get original dimensions
width, height = img.size

Expand All @@ -70,19 +81,37 @@ def load_and_preprocess_images_square(image_path_list, target_size=1024):
# Store original image coordinates and scale
original_coords.append(np.array([x1, y1, x2, y2, width, height]))

# Convert to tensor
img_tensor = to_tensor(img)

# Create a new black square image and paste original
square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
square_img.paste(img, (left, top))
square_tensor = img_tensor.new_zeros(img_tensor.shape[0], max_dim, max_dim)
square_tensor[:, top : top + height, left : left + width] = img_tensor

# Resize to target size
square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC)

# Convert to tensor
img_tensor = to_tensor(square_img)
images.append(img_tensor)
square_tensor = F.interpolate(
square_tensor.unsqueeze(0),
size=(target_size, target_size),
mode="bilinear",
align_corners=False,
).squeeze(0)

# Handle alpha channel
if square_tensor.shape[0] == 4:
mask = square_tensor[3]
square_tensor = square_tensor[:3]
bg_col = bg_map.get(background_color, (1.0, 1.0, 1.0))
bg = square_tensor.new_tensor(bg_col).view(3, 1, 1)
square_tensor = square_tensor * mask + bg * (1 - mask)
else:
mask = torch.ones_like(square_tensor[0])

images.append(square_tensor)
masks.append(mask.unsqueeze(0))

# Stack all images
images = torch.stack(images)
masks = torch.stack(masks)
original_coords = torch.from_numpy(np.array(original_coords)).float()

# Add additional dimension if single image to ensure correct shape
Expand All @@ -91,10 +120,17 @@ def load_and_preprocess_images_square(image_path_list, target_size=1024):
images = images.unsqueeze(0)
original_coords = original_coords.unsqueeze(0)

return images, original_coords
if masks.dim() == 3:
masks = masks.unsqueeze(0)

return images, masks, original_coords

def load_and_preprocess_images(image_path_list, mode="crop"):

def load_and_preprocess_images(
image_path_list: list[str],
mode: Literal["crop", "pad"]="crop",
background_color: Literal["white", "black", "imagenet_mean"]="white",
) -> tuple[torch.Tensor, torch.Tensor]:
"""
A quick start function to load and preprocess images for model input.
This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
Expand All @@ -105,9 +141,13 @@ def load_and_preprocess_images(image_path_list, mode="crop"):
- "crop" (default): Sets width to 518px and center crops height if needed.
- "pad": Preserves all pixels by making the largest dimension 518px
and padding the smaller dimension to reach a square shape.
background_color (str, optional): Background color for transparent images to blend onto,
either "white", "black" or "imagenet_mean". Defaults to "white".

Returns:
torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
tuple[torch.Tensor, torch.Tensor]: Two element tuple containing:
Batched tensor of preprocessed images with shape (N, 3, H, W).
Batched tensor of alpha masks with shape (N, 1, H, W).

Raises:
ValueError: If the input list is empty or if mode is invalid
Expand All @@ -129,7 +169,14 @@ def load_and_preprocess_images(image_path_list, mode="crop"):
if mode not in ["crop", "pad"]:
raise ValueError("Mode must be either 'crop' or 'pad'")

bg_map = {
"black": (0.0, 0.0, 0.0),
"white": (1.0, 1.0, 1.0),
"imagenet_mean": _IMAGENET_MEAN,
}

images = []
masks = []
shapes = set()
to_tensor = TF.ToTensor()
target_size = 518
Expand All @@ -138,17 +185,6 @@ def load_and_preprocess_images(image_path_list, mode="crop"):
for image_path in image_path_list:
# Open image
img = Image.open(image_path)

# If there's an alpha channel, blend onto white background:
if img.mode == "RGBA":
# Create white background
background = Image.new("RGBA", img.size, (255, 255, 255, 255))
# Alpha composite onto the white background
img = Image.alpha_composite(background, img)

# Now convert to "RGB" (this step assigns white for transparent areas)
img = img.convert("RGB")

width, height = img.size

if mode == "pad":
Expand Down Expand Up @@ -185,13 +221,24 @@ def load_and_preprocess_images(image_path_list, mode="crop"):
pad_left = w_padding // 2
pad_right = w_padding - pad_left

# Pad with white (value=1.0)
img = torch.nn.functional.pad(
img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
# Pad with black (value=0.0); padded regions are masked out via the alpha mask
img = F.pad(
img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=0.0
)

# Handle alpha channel
if img.shape[0] == 4:
mask = img[3]
img = img[:3]
bg_col = bg_map.get(background_color, (1.0, 1.0, 1.0))
bg = img.new_tensor(bg_col).view(3, 1, 1)
img = img * mask + bg * (1 - mask)
else:
mask = torch.ones_like(img[0])

shapes.add((img.shape[1], img.shape[2]))
images.append(img)
masks.append(mask.unsqueeze(0))

# Check if we have different shapes
# In theory our model can also work well with different shapes
Expand All @@ -203,7 +250,8 @@ def load_and_preprocess_images(image_path_list, mode="crop"):

# Pad images if necessary
padded_images = []
for img in images:
padded_masks = []
for img, mask in zip(images, masks):
h_padding = max_height - img.shape[1]
w_padding = max_width - img.shape[2]

Expand All @@ -213,18 +261,34 @@ def load_and_preprocess_images(image_path_list, mode="crop"):
pad_left = w_padding // 2
pad_right = w_padding - pad_left

img = torch.nn.functional.pad(
if mask is None:
mask = torch.ones_like(img[0])

mask = F.pad(
mask, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=0.0
)

img = F.pad(
img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
)

padded_images.append(img)
padded_masks.append(mask)

images = padded_images
masks = padded_masks

images = torch.stack(images) # concatenate images
masks = torch.stack(masks)

# Ensure correct shape when single image
if len(image_path_list) == 1:
# Verify shape is (1, C, H, W)
if images.dim() == 3:
images = images.unsqueeze(0)

return images
# Verify mask shape is (1, 1, H, W)
if masks.dim() == 3:
masks = masks.unsqueeze(0)

return images, masks
Loading