Skip to content

How to use sparse vae? #61

@wusize

Description

@wusize

Hi, nice work!

I am wondering how to use the sparse VAE to encode and decode meshes. The encoder unpacks its input batch as shown by this line in the code:

feat, xyz, batch_idx = batch['sparse_sdf'], batch['sparse_index'], batch['batch_idx']

May I know how to prepare the input for the encoder?

I tried the following approach but got unexpected results.

import torch
import trimesh
from omegaconf import OmegaConf
from direct3d_s2.utils import instantiate_from_config

from direct3d_s2.utils.mesh import compute_valid_udf, normalize_mesh


def preprocess(mesh, size=512, device="cuda:0"):
    """Turn a mesh into the sparse (index, value) pair built for the sparse VAE.

    The vertices are scaled by 0.5, a distance grid of ``size**3`` cells is
    computed with ``compute_valid_udf``, and only the cells whose value is
    below ``4 / size`` are kept.

    Args:
        mesh: Object exposing ``vertices`` and ``faces`` (for example the
            result of ``trimesh.load``).
        size: Number of grid cells along each axis.
        device: Torch device the tensors are created on.

    Returns:
        A ``(sparse_index, sparse_sdf)`` tuple. ``sparse_index`` is the
        ``nonzero()`` result over the ``(1, size, size, size)`` grid, so each
        row holds the leading (batch) index followed by the three grid
        coordinates. ``sparse_sdf`` is the 1-D tensor of grid values at those
        cells, in the same order.
    """
    # Convert straight to the target dtype/device. Building a float32 tensor
    # from the integer face array first (as ``torch.Tensor(...)`` does) loses
    # exactness for indices above 2**24.
    vertices = torch.as_tensor(mesh.vertices, dtype=torch.float32, device=device) * 0.5
    faces = torch.as_tensor(mesh.faces, dtype=torch.int32, device=device)

    sdf = compute_valid_udf(vertices, faces, dim=size, threshold=4.0)
    sdf = sdf.reshape(size, size, size).unsqueeze(0)

    # Build the band mask once and reuse it, so indices and values are
    # guaranteed to come from the same selection.
    near_surface = sdf < 4 / size
    sparse_index = near_surface.nonzero()
    sparse_sdf = sdf[near_surface]

    return sparse_index, sparse_sdf



# ---- input mesh ------------------------------------------------------------
mesh = normalize_mesh(trimesh.load('output_512.obj'))

sparse_index, sparse_sdf = preprocess(mesh, size=512, device="cuda:0")

# ---- sparse VAE (512) ------------------------------------------------------
model_sparse_512_path = 'wushuang98/Direct3D-S2/direct3d-s2-v-1-1/model_sparse_512.ckpt'
config_path = 'wushuang98/Direct3D-S2/direct3d-s2-v-1-1/config.yaml'

cfg = OmegaConf.load(config_path)
state_dict_sparse_512 = torch.load(
    model_sparse_512_path,
    map_location='cpu',
    weights_only=True,
)
print(f"Load sparse vae 512: {cfg.sparse_vae_512}")

sparse_vae_512 = instantiate_from_config(cfg.sparse_vae_512)
sparse_vae_512.load_state_dict(state_dict_sparse_512["vae"], strict=True)
sparse_vae_512 = sparse_vae_512.eval().to("cuda:0")

# Cast the input values to whatever dtype the model weights use.
dtype = next(sparse_vae_512.parameters()).dtype
sparse_sdf = sparse_sdf.to(dtype)

# ---- encode / decode -------------------------------------------------------
# Column 0 of sparse_index is the leading (batch) index; the remaining
# columns are the grid coordinates.
batch_column = sparse_index[:, 0]
grid_coords = sparse_index[:, 1:]
batch = {
    'sparse_sdf': sparse_sdf,
    'sparse_index': grid_coords,
    'batch_idx': batch_column,
}

with torch.no_grad():
    z, posterior = sparse_vae_512.encode(batch)
    decoded_meshes = sparse_vae_512.decode_mesh(
        latents=z,
        voxel_resolution=512,
        mc_threshold=0.2,
        return_feat=False,
        factor=1.0,
    )

# Only the first decoded mesh is kept, then normalized and written out.
mesh = normalize_mesh(decoded_meshes[0])
mesh.export('reconstructed.obj')

input:

Image

output:

Image

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions