
Commit 4797086

Demo update + single inference instance retrieval
- Refactoring change + single inference args fix
- Demo cleanup
- Adding single inference instance retrieval
- README update
1 parent 35f9c5b commit 4797086

File tree: 9 files changed, +670 −178 lines changed


DATA.md

Lines changed: 2 additions & 2 deletions
@@ -46,8 +46,8 @@ File structure below:
 ## :wrench: Data Preprocessing
 To speed up training and inference, we preprocess the 1D (referral), 2D (RGB + floorplan), and 3D (point cloud + CAD) data for both object instances and scenes. Note that since the 3RScan dataset does not provide frame-wise RGB segmentations, we project the 3D data to 2D and store it in `.npz` format for every scan; we provide the scripts for this projection. Here's an overview of which data features are precomputed:
 
-- Object Instance: Referral, Multi-view RGB images, Point Cloud & CAD (only for ScanNet)
-- Scene: Referral, Multi-view RGB images, Floorplan (only for ScanNet) Point Cloud
+- Object Instance: Referral, Multi-view RGB images, Point Cloud, & CAD (only for ScanNet)
+- Scene: Referral, Multi-view RGB images, Floorplan (only for ScanNet), & Point Cloud
 
 We provide the preprocessing scripts, which should be easily customizable for new datasets. Further instructions below.

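As a quick sanity check after preprocessing, one can list the arrays stored in a per-scan `.npz` archive. The file path below is a placeholder (the actual naming and layout come from the preprocessing scripts), so treat this as a sketch rather than a documented command:

```bash
# Hypothetical path; prints the array keys stored for one projected scan.
$ python -c "import numpy as np; d = np.load('path/to/<scan_id>.npz', allow_pickle=True); print(list(d.keys()))"
```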
TRAIN.md

Lines changed: 28 additions & 3 deletions
@@ -55,7 +55,31 @@ We provide all available checkpoints on huggingface 👉 [here](https://huggingf
 
 
 # :shield: Single Inference
-We release script to perform inference (generate scene-level embeddings) on a single scan of 3RScan/Scannet. Detailed usage in the file. Quick instructions below:
+
+## Instance Inference
+We provide a script to perform instance-level cross-modal retrieval on a single scan; it reports retrieval metrics and the matched objects within the scene across all available modality pairs. Detailed usage in the file. Quick instructions below:
+
+```bash
+$ python single_inference/instance_inference.py
+```
+
+Various configurable parameters:
+
+- `--dataset`: Dataset name - Options: `scannet`, `scan3r`, `arkitscenes`, `multiscan`
+- `--process_dir`: Path to the processed features directory containing the preprocessed object data
+- `--ckpt`: Path to the pre-trained instance crossover model checkpoint (details [here](TRAIN.md#checkpoint-inventory)), example path: `./checkpoints/instance_crossover_scannet+scan3r+multiscan+arkitscenes.pth`
+- `--scan_id`: Scan ID to run inference on (e.g., `scene_00004_00`)
+- `--modalities`: List of modalities to use (default: `['rgb', 'point', 'cad', 'referral']`)
+- `--input_dim_3d`: Input dimension for 3D features (default: 384)
+- `--input_dim_2d`: Input dimension for 2D features (default: 1536)
+- `--input_dim_1d`: Input dimension for 1D features (default: 768)
+- `--out_dim`: Output embedding dimension (default: 768)
+
+
+> **Note**: This script requires preprocessed object data for the target scene, namely the `objectsDataMultimodal.npz` files generated during data preprocessing as described in [DATA.md](DATA.md/#wrench-data-preprocessing). The scan must have valid object instances across the specified modalities.
+
+## Scene Inference
+We release a script to perform inference (generate scene-level embeddings) on a single scan of any supported dataset. Detailed usage in the file. Quick instructions below:
 
 ```bash
 $ python single_inference/scene_inference.py
@@ -65,12 +89,13 @@ Various configurable parameters:
 
 - `--dataset`: dataset name, Scannet/Scan3R
 - `--data_dir`: data directory (e.g., `./datasets/Scannet`; assumes a similar structure as in `preprocess.md`)
-- `--floorplan_dir`: directory consisting of the rasterized floorplans (this can point to the downloaded preprocessed directory), only for Scannet
-- `--ckpt`: Path to the pre-trained scene crossover model checkpoint (details [here](TRAIN.md#checkpoint-inventory)), example_path: `./checkpoints/scene_crossover_scannet+scan3r.pth/`).
+- `--process_dir`: preprocessed data directory (this can point to the downloaded preprocessed directory)
+- `--ckpt`: Path to the pre-trained scene crossover model checkpoint (details [here](TRAIN.md#checkpoint-inventory)), example path: `./checkpoints/scene_crossover_scannet+scan3r.pth`
 - `--scan_id`: the scan ID from the dataset you'd like to calculate embeddings for (if not provided, embeddings for all scans are calculated)
 
 The script will output embeddings in the same format as provided [here](DATA.md/#generated-embedding-data).
 
+
 # :bar_chart: Evaluation
 #### Cross-Modal Object Retrieval
 Run the following script (refer to the script to run the instance baseline/instance crossover) for object instance + scene retrieval results using the instance-based methods. Detailed usage inside the script.
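
For orientation, a hypothetical end-to-end invocation of the two new scripts could look like the sketch below. The paths, scan IDs, and the space-separated form of `--modalities` are assumptions based on the flag descriptions above, not verbatim repository commands; adjust them to your local setup.

```bash
# Instance-level cross-modal retrieval on one scan (illustrative paths).
$ python single_inference/instance_inference.py \
    --dataset scannet \
    --process_dir ./preprocess_feats/Scannet \
    --ckpt ./checkpoints/instance_crossover_scannet+scan3r+multiscan+arkitscenes.pth \
    --scan_id scene_00004_00 \
    --modalities rgb point referral

# Scene-level embedding generation for a single scan (illustrative paths).
$ python single_inference/scene_inference.py \
    --dataset Scannet \
    --data_dir ./datasets/Scannet \
    --process_dir ./preprocess_feats/Scannet \
    --ckpt ./checkpoints/scene_crossover_scannet+scan3r.pth \
    --scan_id scene0568_00
```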

data/datasets/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
 from .scannet import *
 from .scan3r import *
 from .arkit import *
-from .multiscan import *
-from .structured3d import *
+from .multiscan import *

demo/demo_instance_retrieval.py

Lines changed: 15 additions & 32 deletions
@@ -26,13 +26,13 @@
 
 DEFAULT_CONFIG = {
     'dataset': 'scannet', # scannet, scan3r, arkitscenes, multiscan
-    'data_dir': '/drive/datasets/Scannet', # Update this with your data path
-    'process_dir': '/drive/dumps/multimodal-spaces/preprocess_feats/Scannet', # Update this with your processed data path
-    'ckpt': '/drive/dumps/multimodal-spaces/runs/new_runs/instance_crossover_scannet+scan3r+multiscan+arkitscenes.pth', # Update this with your model checkpoint
-    'scan_id': 'scene0568_00', # Default scan to search in
-    'query_modality': 'point', # point, rgb, referral
-    'target_modality': 'referral', # point, rgb, referral, cad
-    'query_path': './demo_data/kitchen/scene.ply', # Path to your query file
+    'data_dir': '/drive/datasets/Scannet',
+    'process_dir': '/drive/dumps/multimodal-spaces/preprocess_feats/Scannet',
+    'ckpt': '/drive/dumps/multimodal-spaces/runs/new_runs/instance_crossover_scannet+scan3r+multiscan+arkitscenes.pth',
+    'scan_id': 'scene0568_00',
+    'query_modality': 'point',
+    'target_modality': 'point',
+    'query_path': './demo_data/kitchen/scene.ply', # Path to your query file - refers to query object PCL
     'top_k': 5
 }
 # =============================================================================
@@ -49,7 +49,6 @@ def __init__(self, args):
         self.args = args
         self.setup_model()
 
-        # Setup image transforms
         self.image_transform = tvf.Compose([
             tvf.ToTensor(),
             tvf.Normalize(mean=[0.485, 0.456, 0.406],
@@ -62,7 +61,6 @@ def setup_model(self):
         kwargs = [init_kwargs]
         self.accelerator = Accelerator(kwargs_handlers=kwargs)
 
-        # Convert args to DictConfig format expected by model
         model_args = DictConfig({
             'out_dim': self.args.out_dim,
             'input_dim_3d': self.args.input_dim_3d,
@@ -111,7 +109,7 @@ def _encode_point_query(self, path: str) -> torch.Tensor:
         points = np.asarray(pcd.points)
 
         # Send raw point cloud as list (like datasets) - model will handle sampling
-        point_clouds = [points] # List of raw point clouds
+        point_clouds = [points]
         point_masks = torch.ones(1, 1).bool() # (1, 1)
 
         data_dict = {
@@ -132,11 +130,9 @@ def _encode_rgb_query(self, path: str) -> torch.Tensor:
 
         image = Image.open(path)
         image = image.resize((224, 224), Image.BICUBIC)
-        image_pt = self.image_transform(image).unsqueeze(0) # (1, C, H, W)
-
-        # Convert to model expected format: (batch_size, num_objects, num_views, C, H, W)
-        rgb_data = image_pt.unsqueeze(0).unsqueeze(0) # (1, 1, 1, C, H, W)
-        rgb_masks = torch.ones(1, 1).bool() # (1, 1)
+        image_pt = self.image_transform(image).unsqueeze(0)
+        rgb_data = image_pt.unsqueeze(0).unsqueeze(0)
+        rgb_masks = torch.ones(1, 1).bool()
 
         data_dict = {
             'objects': {
@@ -152,11 +148,9 @@ def _encode_rgb_query(self, path: str) -> torch.Tensor:
 
     def _encode_referral_query(self, path: str) -> torch.Tensor:
         """Encode text referral query"""
-        if os.isfile(path):
-            with open(path, 'r') as f:
-                text = f.read().strip()
-        else:
-            text = path # Assume path is the text itself
+        assert os.path.isfile(path), 'Referral Path should be a text file'
+        with open(path, 'r') as f:
+            text = f.read().strip()
 
         data_dict = {'referral_texts': [[[text]]]}
 
@@ -168,29 +162,23 @@
     def encode_scene(self, scan_id: str) -> Dict[str, torch.Tensor]:
         """Encode all objects in the scene and return embeddings by modality"""
 
-        # Setup dataset for this specific scan
         self.setup_dataset(scan_id)
-
-        # Get the data for this scan
         data_dict = self.dataset.get_data()
 
 
         with torch.no_grad():
             output = self.model(data_dict)
 
-        # Extract embeddings and masks for each modality
         scene_embeddings = {}
         for modality in output['embeddings']:
             embeddings = output['embeddings'][modality].cpu()
             masks = data_dict['masks'][modality].cpu()
 
-            # Remove batch dimension
             if len(embeddings.shape) == 3:
                 embeddings = embeddings.squeeze(0)
             if len(masks.shape) == 2:
                 masks = masks.squeeze(0)
 
-            # Store embeddings and masks
             scene_embeddings[modality] = {
                 'embeddings': embeddings,
                 'masks': masks,
@@ -229,7 +217,6 @@ def retrieve(
         target_embeddings = scene_data[target_modality]['embeddings']
         target_masks = scene_data[target_modality]['masks']
 
-        # Filter valid objects only
         valid_mask = target_masks.bool()
         if valid_mask.sum() == 0:
             log.warning("No valid objects found in target modality")
@@ -278,7 +265,6 @@ def main():
                         choices=['point', 'rgb', 'referral', 'cad'],
                         help=f'Target modality to match against - default: {DEFAULT_CONFIG["target_modality"]}')
 
-    # Dataset arguments with defaults from config
     parser.add_argument('--dataset', type=str, default=DEFAULT_CONFIG['dataset'],
                         choices=['scannet', 'scan3r', 'arkitscenes', 'multiscan'],
                         help=f'Dataset name - default: {DEFAULT_CONFIG["dataset"]}')
@@ -289,7 +275,6 @@
     parser.add_argument('--ckpt', type=str, default=DEFAULT_CONFIG['ckpt'],
                         help=f'Path to model checkpoint - default: {DEFAULT_CONFIG["ckpt"]}')
 
-    # Optional arguments
     parser.add_argument('--top_k', type=int, default=DEFAULT_CONFIG['top_k'],
                         help=f'Number of top results to return - default: {DEFAULT_CONFIG["top_k"]}')
 
@@ -301,7 +286,6 @@
 
     args = parser.parse_args()
 
-    # Print configuration being used
     log.info("=== Instance Retrieval Configuration ===")
     log.info(f"Dataset: {args.dataset}")
     log.info(f"Data directory: {args.data_dir}")
@@ -329,15 +313,14 @@
 
     # Run retrieval
     retriever = InstanceRetrieval(args)
-    results = retriever.retrieve(
+    retriever.retrieve(
         args.query_path,
         args.query_modality,
         args.scan_id,
         args.target_modality,
         args.top_k
     )
 
-    return results
 
 
 if __name__ == '__main__':
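
A hypothetical invocation of the updated demo, overriding the `DEFAULT_CONFIG` values via CLI flags (the `--query_path`, `--query_modality`, `--target_modality`, and `--scan_id` flags are inferred from `DEFAULT_CONFIG` and the argparse lines visible above, and may differ slightly in the actual script):

```bash
# Match a query object point cloud against point-cloud instances in one scan.
$ python demo/demo_instance_retrieval.py \
    --query_path ./demo_data/kitchen/scene.ply \
    --query_modality point \
    --target_modality point \
    --scan_id scene0568_00 \
    --top_k 5
```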
