-
Notifications
You must be signed in to change notification settings - Fork 97
Open
Description
Hi, nice work!
I am wondering how to use the sparse VAE to encode and decode meshes, given that the encoder reads its input like this:
`feat, xyz, batch_idx = batch['sparse_sdf'], batch['sparse_index'], batch['batch_idx']`
May I know how to prepare the input for the encoder?
I tried the following approach but got unexpected results.
import torch
import trimesh
from omegaconf import OmegaConf
from direct3d_s2.utils import instantiate_from_config
from direct3d_s2.utils.mesh import compute_valid_udf, normalize_mesh
def preprocess(mesh, size=512, device="cuda:0"):
    """Turn a mesh into the sparse (index, value) pair passed to the sparse VAE.

    Args:
        mesh: Object exposing ``vertices`` and ``faces`` arrays (the script
            below passes a ``trimesh`` mesh after ``normalize_mesh``).
        size: Grid resolution per axis; the field returned by
            ``compute_valid_udf`` is reshaped to ``(1, size, size, size)``.
        device: Torch device on which the tensors are created.

    Returns:
        Tuple ``(sparse_index, sparse_sdf)``. ``sparse_index`` is the
        ``nonzero()`` result over the ``(1, size, size, size)`` grid, i.e. one
        ``(batch, i, j, k)`` row per kept voxel; ``sparse_sdf`` holds the field
        values at those voxels, in the same order.
    """
    # Create the tensors with their final dtype. The legacy
    # ``torch.Tensor(...)`` constructor always produces float32, so converting
    # faces with ``.int()`` afterwards rounds any vertex index above 2**24.
    # NOTE(review): the 0.5 scale presumably maps the normalized mesh into the
    # range compute_valid_udf expects -- confirm against the library.
    vertices = torch.tensor(mesh.vertices, dtype=torch.float32, device=device) * 0.5
    faces = torch.tensor(mesh.faces, dtype=torch.int32, device=device)

    sdf = compute_valid_udf(vertices, faces, dim=size, threshold=4.0)
    sdf = sdf.reshape(size, size, size).unsqueeze(0)

    # Keep only voxels whose value is below 4 / size. Build the mask once so
    # the indices and the values are guaranteed to come from the same
    # selection.
    near_surface = sdf < 4 / size
    sparse_index = near_surface.nonzero()
    sparse_sdf = sdf[near_surface]
    return sparse_index, sparse_sdf
# --- Input mesh -------------------------------------------------------------
device = "cuda:0"
source_mesh = normalize_mesh(trimesh.load('output_512.obj'))
# source_mesh.show()  # uncomment to inspect the normalized input
voxel_index, voxel_sdf = preprocess(source_mesh, size=512, device=device)

# --- Sparse VAE (512 resolution) --------------------------------------------
checkpoint_path = 'wushuang98/Direct3D-S2/direct3d-s2-v-1-1/model_sparse_512.ckpt'
config_path = 'wushuang98/Direct3D-S2/direct3d-s2-v-1-1/config.yaml'
cfg = OmegaConf.load(config_path)
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
print(f"Load sparse vae 512: {cfg.sparse_vae_512}")
vae = instantiate_from_config(cfg.sparse_vae_512)
vae.load_state_dict(checkpoint["vae"], strict=True)
vae = vae.eval().to(device)

# Cast the voxel values to the dtype of the model parameters.
voxel_sdf = voxel_sdf.to(next(vae.parameters()).dtype)

# --- Encode, then decode ----------------------------------------------------
# Column 0 of voxel_index is the batch index; the remaining columns are the
# grid coordinates of each kept voxel.
batch = {
    "sparse_sdf": voxel_sdf,
    "sparse_index": voxel_index[:, 1:],
    "batch_idx": voxel_index[:, 0],
}
with torch.no_grad():
    z, posterior = vae.encode(batch)
    decoded = vae.decode_mesh(
        latents=z,
        voxel_resolution=512,
        mc_threshold=0.2,
        return_feat=False,
        factor=1.0,
    )

# decode_mesh returns a sequence; take the first mesh, normalize and save it.
reconstructed = normalize_mesh(decoded[0])
reconstructed.export('reconstructed.obj')
input:
output:

Metadata
Metadata
Assignees
Labels
No labels