Let's say I sample points from a mesh and retain the relationship between points and mesh faces. I then process the points through some network (for example PointNet++) to produce pointwise predictions, and I would like to aggregate these predictions onto the corresponding mesh faces by pooling (for example, taking the mode of) the predictions of the points associated with each face. Is there existing capability for this? I know the pytorch_scatter functions can in theory be used for this aggregation, but it needs not only to pool points onto each face, but also to do so separately for each example in a batch. I was wondering if there is already something in PyG or its related dependencies that can do this before I try to implement it myself. Thanks
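For concreteness, here is a minimal sketch of the kind of data I mean. The attribute names sampled_pos and pointwise_face_idx are made up for illustration (they are not existing PyG attributes); the tensors are just dummy values showing the shapes involved.

import torch
from torch_geometric.data import Data

num_vertices, num_faces, num_points = 50, 100, 1024

data = Data(
    pos=torch.rand(num_vertices, 3),                                # mesh vertex positions
    face=torch.randint(0, num_vertices, (3, num_faces)),            # mesh faces as vertex-index triples (dummy values)
    sampled_pos=torch.rand(num_points, 3),                          # points sampled on the mesh surface (made-up attribute)
    pointwise_face_idx=torch.randint(0, num_faces, (num_points,)),  # face each point was sampled from (made-up attribute)
)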
-
This was actually not as difficult as I initially thought. For anyone stumbling across this in future for their own use case: essentially you make the point-face index unique across the batch by incrementing it (similar to what is done for the mesh face indices when collating) and then use scatter_add over one-hot vectors to get the modal prediction (similar to how the modal label is obtained here).

Sample implementation if interested:

import torch
import torch.nn.functional as F
import torch_scatter

# Offset each sample's face indices so they become unique across the batch
# (mirroring how the mesh face indices are incremented when collating).
offset = torch.cat([torch.tensor([0]), torch.cumsum(num_faces_per_sample, dim=0)])[:-1]
unique_face_idxs = pointwise_face_idxs + offset.repeat_interleave(num_points_per_sample)

# We also increment the labels so that mesh faces that were missed in the point cloud
# sampling end up with label -1 after the scatter_add and argmax steps.
onehot_labels = F.one_hot(pointwise_labels + 1)
onehot_sums = torch_scatter.scatter_add(onehot_labels, unique_face_idxs, dim=0)
aggregated_labels = onehot_sums.argmax(dim=-1) - 1

# If any of these 'missed' faces are at the very end, they will not have been
# accounted for by scatter_add, so pad the result with -1:
num_missed_faces = int(num_faces_per_sample.sum()) - aggregated_labels.size(0)
if num_missed_faces > 0:
    aggregated_labels = torch.cat([aggregated_labels, torch.full((num_missed_faces,), -1, dtype=aggregated_labels.dtype)])

This assumes that the undeclared variables can be obtained from the data batch or otherwise. Unsure if this is a common enough use case to warrant adding?
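In case it helps, here is a rough sketch of one way those undeclared variables might be pulled out of a collated batch. It assumes the made-up sampled_pos and pointwise_face_idx attributes from the setup sketch in the question, an existing Batch object called batch, and per-point class scores called predictions from the network; none of these names are existing PyG API.

import torch

# Recover per-sample objects; to_data_list undoes any collation offsets.
data_list = batch.to_data_list()

num_points_per_sample = torch.tensor([d.sampled_pos.size(0) for d in data_list])  # points per sample
num_faces_per_sample = torch.tensor([d.face.size(1) for d in data_list])          # faces per sample

# Per-point face indices, local to each sample, and per-point predicted labels.
pointwise_face_idxs = torch.cat([d.pointwise_face_idx for d in data_list])
pointwise_labels = predictions.argmax(dim=-1)  # assumed network output of per-point class scores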