Skip to content

Commit 3c73421

Browse files
authored
Merge pull request #25 from GradientSpaces/anchor-free
- Add anchor-free mode for RPF. - Update Readme on camera-ready paper.
2 parents 4baf4bc + 8fe3d15 commit 3c73421

18 files changed

+203
-117
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818

1919
## 🔔 News
20+
- [Oct 26, 2025] Our NeurIPS camera-ready [paper](https://arxiv.org/abs/2506.05282v2) is available! 🎉
21+
- We include additional experiments on generalizability and a new **anchor-free** model, which aligns more with practical assembly assumptions.
22+
- We release **Version 1.1** to support the anchor-free model; see the [PR](https://github.com/GradientSpaces/Rectified-Point-Flow/pull/25) for more details.
23+
2024
- [Sept 18, 2025] Our paper has been accepted to **NeurIPS 2025 (Spotlight)**; see you in San Diego!
2125

2226
- [July 22, 2025] **Version 1.0**: We strongly recommend updating to this version, which includes:
@@ -292,6 +296,7 @@ Define parameters for Lightning's [Trainer](https://lightning.ai/docs/pytorch/la
292296

293297
**Dataloader workers killed**: Usually this is a signal of insufficient CPU memory or stack. You may try to reduce the `num_workers`.
294298

299+
295300
> [!NOTE]
296301
> Please don't hesitate to open an [issue](/issues) if you encounter any problems or bugs!
297302

config/RPF_base_main_10k.yaml

Lines changed: 0 additions & 35 deletions
This file was deleted.

config/RPF_base_predict_overlap.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ ckpt_path: null # when null, the checkpoint will be downloaded from Hug
2323
# Model settings
2424
model:
2525
compute_overlap_points: true
26+
build_overlap_head: true

config/RPF_base_pretrain.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ hydra:
2323

2424
model:
2525
compute_overlap_points: true
26+
build_overlap_head: true
2627

2728
data:
2829
limit_val_samples: 1000

config/data/ikea.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ num_points_to_sample: 5000
33
min_parts: 2
44
max_parts: 64
55
min_points_per_part: 20
6-
multi_anchor: true
6+
anchor_free: true
7+
multi_anchor: false
78

89
data_root: ${data_root}
910
dataset_names: ["ikea"]

config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ num_points_to_sample: 5000
33
min_parts: 2
44
max_parts: 64
55
min_points_per_part: 20
6-
multi_anchor: true
6+
anchor_free: true
7+
multi_anchor: false
78

89
data_root: ${data_root}
910
dataset_names: ["ikea", "partnet", "everyday", "twobytwo", "modelnet", "tudl"]

config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl_objverse.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ num_points_to_sample: 5000
33
min_parts: 2
44
max_parts: 64
55
min_points_per_part: 20
6-
multi_anchor: true
6+
anchor_free: true
7+
multi_anchor: false
78

89
data_root: ${data_root}
910
dataset_names: ["ikea", "partnet", "everyday", "twobytwo", "modelnet", "tudl", "objaverse_v1"]

config/data/ikea_twobytwo.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ num_points_to_sample: 5000
33
min_parts: 2
44
max_parts: 64
55
min_points_per_part: 20
6+
anchor_free: true
67
multi_anchor: false
78

89
data_root: ${data_root}

config/model/rectified_point_flow.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ timestep_sampling: "u_shaped"
2222
inference_sampler: "euler"
2323
inference_sampling_steps: 50
2424
n_generations: 1
25+
anchor_free: true

rectified_point_flow/data/datamodule.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333
up_axis: dict[str, str] = {},
3434
min_parts: int = 2,
3535
max_parts: int = 64,
36+
anchor_free: bool = True,
3637
num_points_to_sample: int = 5000,
3738
min_points_per_part: int = 20,
3839
min_dataset_size: int = 2000,
@@ -51,6 +52,9 @@ def __init__(
5152
If not provided, the up axis is assumed to be 'y'. This only affects the visualization.
5253
min_parts: Minimum number of parts in a point cloud.
5354
max_parts: Maximum number of parts in a point cloud.
55+
anchor_free: Whether to use anchor-free mode.
56+
If True, the anchor part is centered and randomly rotated, like the non-anchor parts (default).
57+
If False, the anchor part is not centered and thus its pose in the CoM frame of the GT point cloud is given (align with GARF).
5458
num_points_to_sample: Number of points to sample from each point cloud.
5559
min_points_per_part: Minimum number of points per part.
5660
min_dataset_size: Minimum number of point clouds in a dataset.
@@ -65,6 +69,7 @@ def __init__(
6569
self.up_axis = up_axis
6670
self.min_parts = min_parts
6771
self.max_parts = max_parts
72+
self.anchor_free = anchor_free
6873
self.num_points_to_sample = num_points_to_sample
6974
self.min_points_per_part = min_points_per_part
7075
self.batch_size = batch_size
@@ -120,6 +125,7 @@ def setup(self, stage: str):
120125
num_points_to_sample=self.num_points_to_sample,
121126
min_points_per_part=self.min_points_per_part,
122127
min_dataset_size=self.min_dataset_size,
128+
anchor_free=self.anchor_free,
123129
random_scale_range=self.random_scale_range,
124130
multi_anchor=self.multi_anchor,
125131
)
@@ -136,6 +142,7 @@ def setup(self, stage: str):
136142
dataset_name=dataset_name,
137143
min_parts=self.min_parts,
138144
max_parts=self.max_parts,
145+
anchor_free=self.anchor_free,
139146
num_points_to_sample=self.num_points_to_sample,
140147
min_points_per_part=self.min_points_per_part,
141148
limit_val_samples=self.limit_val_samples,
@@ -146,6 +153,7 @@ def setup(self, stage: str):
146153
logger.info(make_line())
147154
logger.info("Total Train Samples: " + str(self.train_dataset.cumulative_sizes[-1]))
148155
logger.info("Total Val Samples: " + str(self.val_dataset.cumulative_sizes[-1]))
156+
logger.info("Anchor-free Mode: " + str(self.anchor_free))
149157

150158
elif stage == "validate":
151159
self.val_dataset = ConcatDataset(
@@ -157,6 +165,7 @@ def setup(self, stage: str):
157165
up_axis=self.up_axis.get(dataset_name, "y"),
158166
min_parts=self.min_parts,
159167
max_parts=self.max_parts,
168+
anchor_free=self.anchor_free,
160169
num_points_to_sample=self.num_points_to_sample,
161170
min_points_per_part=self.min_points_per_part,
162171
limit_val_samples=self.limit_val_samples,
@@ -166,6 +175,7 @@ def setup(self, stage: str):
166175
)
167176
logger.info(make_line())
168177
logger.info("Total Val Samples: " + str(self.val_dataset.cumulative_sizes[-1]))
178+
logger.info("Anchor-free Mode: " + str(self.anchor_free))
169179

170180
elif stage in ["test", "predict"]:
171181
self.test_dataset = [
@@ -176,6 +186,7 @@ def setup(self, stage: str):
176186
up_axis=self.up_axis.get(dataset_name, "y"),
177187
min_parts=self.min_parts,
178188
max_parts=self.max_parts,
189+
anchor_free=self.anchor_free,
179190
num_points_to_sample=self.num_points_to_sample,
180191
min_points_per_part=self.min_points_per_part,
181192
limit_val_samples=self.limit_val_samples,
@@ -184,6 +195,7 @@ def setup(self, stage: str):
184195
]
185196
logger.info(make_line())
186197
logger.info("Total Test Samples: " + str(sum(len(dataset) for dataset in self.test_dataset)))
198+
logger.info("Anchor-free Mode: " + str(self.anchor_free))
187199

188200
def train_dataloader(self):
189201
"""Get training dataloader."""

0 commit comments

Comments
 (0)