diff --git a/pcdet/datasets/custom/custom_dataset.py b/pcdet/datasets/custom/custom_dataset.py index 3715210b1..42ba086f7 100644 --- a/pcdet/datasets/custom/custom_dataset.py +++ b/pcdet/datasets/custom/custom_dataset.py @@ -193,7 +193,8 @@ def create_groundtruth_database(self, info_path=None, used_classes=None, split=' for i in range(num_obj): filename = '%s_%s_%d.bin' % (sample_idx, names[i], i) filepath = database_save_path / filename - gt_points = points[point_indices[i] > 0] + gt_points = points[point_indices[i] > 0].astype(np.float32) + gt_boxes[i, :3] = gt_boxes[i, :3].astype(np.float32) gt_points[:, :3] -= gt_boxes[i, :3] with open(filepath, 'w') as f: