Skip to content

Commit 69e4495

Browse files
Merge pull request #18 from RadarML/dev/modules
GRT Model Fixes & Model Features
2 parents e7caf9f + 1ac7441 commit 69e4495

File tree

16 files changed

+3177
-2345
lines changed

16 files changed

+3177
-2345
lines changed

docs/grt/index.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,49 @@ The GRT reference implementation uses a [hydra](https://hydra.cc/docs/intro/) +
1414

1515
Pre-trained model checkpoints for the GRT reference implementation on the [I/Q-1M dataset](https://radarml.github.io/red-rover/iq1m/) can also be found [here](https://radarml.github.io/red-rover/iq1m/osn/#download-from-osn).
1616

17+
With a single GPU, these checkpoints can be reproduced with the following:
18+
19+
=== "Base Model"
20+
The 3D polar occupancy base model is provided as the default configuration (i.e., `sensors=[radar,lidar]`, `transforms@transforms.sample=[radar,lidar3d]`, `objective=lidar3d`, `model/decoder=lidar3d`).
21+
```sh
22+
uv run train.py meta.name=base meta.version=small size=small
23+
```
24+
25+
=== "2D Occupancy"
26+
```sh
27+
uv run train.py meta.name=occ2d meta.version=small size=small \
28+
+base=occ3d_to_occ2d \
29+
sensors=[radar,lidar] \
30+
transforms@transforms.sample=[radar,lidar2d] \
31+
objective=lidar2d \
32+
model/decoder=lidar2d
33+
```
34+
35+
=== "Semantic Segmentation"
36+
```sh
37+
uv run train.py meta.name=semseg meta.version=small size=small \
38+
+base=occ3d_to_semseg \
39+
sensors=[radar,semseg] \
40+
transforms@transforms.sample=[radar,semseg] \
41+
objective=semseg \
42+
model/decoder=semseg
43+
```
44+
45+
=== "Ego-Motion"
46+
```sh
47+
uv run train.py meta.name=vel meta.version=small size=small \
48+
+base=occ3d_to_vel \
49+
sensors=[radar,pose]
50+
transforms@transforms.sample=[radar,vel] \
51+
objective=vel \
52+
model/decoder=vel
53+
```
54+
55+
!!! tip
56+
57+
If you're not running in a "managed" environment (e.g., Slurm, LSF, AzureML), [`nq`](https://github.com/leahneukirchen/nq) is a lightweight way to run jobs in a queue. Just `sudo apt-get install -y nq`, and run with `nq uv run train.py ...`.
58+
59+
1760
## Quick Start
1861

1962
1. Create a new repository, and copy the contents of the `grt/` directory:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
path: results/example/base/weights.pth
1+
path: results/base/small/weights.pth
22
rename:
33
- decoder.occ3d.unpatch: null
44
- decoder.occ3d: decoder.occ2d
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
path: results/example/base/weights.pth
1+
path: results/base/small/weights.pth
22
rename:
33
- decoder.occ3d.unpatch: null
44
- decoder.occ3d: decoder.semseg

grt/config/base/occ3d_to_vel.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
path: results/example/base/weights.pth
1+
path: results/base/small/weights.pth
22
rename:
33
- decoder: null

grt/config/default.yaml

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@ defaults:
22
- _self_
33
- size: small
44
- aug@transforms.sample.augmentations: full
5-
- sensors@datamodule.dataset.sensors: ["radar", "lidar"]
5+
- sensors: ["radar", "lidar"]
66
- transforms@transforms.sample: ["radar", "lidar3d"]
77
- traces@datamodule.traces: ["bike", "indoor", "outdoor"]
88
- model: default
99
- trainer: default
1010
- objective: lidar3d
11+
- optimizer: default
12+
13+
hydra:
14+
run:
15+
dir: ${meta.results}/${meta.name}/${meta.version}
1116

1217
meta:
1318
dataset: data/
@@ -18,9 +23,11 @@ meta:
1823

1924
size: {}
2025

21-
hydra:
22-
run:
23-
dir: ${meta.results}/${meta.name}/${meta.version}
26+
model: {}
27+
objective: {}
28+
aug: {}
29+
optimizer: {}
30+
sensors: {}
2431

2532
datamodule:
2633
_target_: nrdk.roverd.datamodule
@@ -31,15 +38,15 @@ datamodule:
3138
_target_: abstract_dataloader.generic.Next
3239
margin: [1.0, 1.0]
3340
reference: "radar"
34-
sensors: {} # sensors/*.yaml
41+
sensors: ${sensors}
3542
traces:
3643
train:
3744
_target_: nrdk.config.expand
3845
path: ${meta.dataset}
3946
test:
4047
_target_: nrdk.config.expand
4148
path: ${meta.dataset}
42-
batch_size: 32
49+
batch_size: 64
4350
samples: 8
4451
num_workers: 16
4552
prefetch_factor: 1
@@ -48,31 +55,22 @@ datamodule:
4855
ptrain: 0.8
4956
pval: 0.2
5057

51-
model: {}
52-
53-
objective: {}
54-
5558
lightningmodule:
5659
_target_: nrdk.framework.NRDKLightningModule
5760
compile: false
5861
vis_interval: 1000
5962
vis_samples: 16
6063
objective: ${objective}
6164
model: ${model}
62-
optimizer:
63-
_target_: torch.optim.AdamW
64-
_partial_: true
65-
lr: 1e-4
66-
weight_decay: 1e-2
67-
betas: [0.9, 0.999]
68-
eps: 1e-8
65+
optimizer: ${optimizer}
6966

7067
transforms:
7168
_target_: abstract_dataloader.abstract.Pipeline
7269
sample:
7370
_target_: abstract_dataloader.ext.graph.Transform
7471
keep_all: false
7572
outputs: {} # sensors/*.yaml
73+
augmentations: ${aug}
7674
collate:
7775
_target_: abstract_dataloader.ext.torch.Collate
7876
mode: "concat"

grt/config/model/decoder/semseg.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@ semseg:
1717
scale: [16.0, 16.0, 10.666, 16.0]
1818
w_min: 0.2
1919
patch: [1, 5, 5, 1]
20-
out_dim: 8
20+
out_dim: 8

grt/config/objective/lidar2d.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ occupancy2d:
77
_target_: nrdk.objectives.Occupancy2D
88
range_weighted: True
99
positive_weight: 64.0
10+
bce_weight: 0.9
1011
vis_config:
1112
cols: 8
1213
width: 512

grt/config/optimizer/default.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
_target_: torch.optim.AdamW
2+
_partial_: true
3+
lr: 1e-4
4+
weight_decay: 1e-2
5+
betas: [0.9, 0.999]
6+
eps: 1e-8

grt/train.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ def _inst(path, *args, **kwargs):
6262
return hydra.utils.instantiate(
6363
cfg[path], _convert_="all", *args, **kwargs)
6464

65+
n_gpus = torch.cuda.device_count()
66+
if "batch_size" in cfg["datamodule"] and n_gpus > 1:
67+
batch_new = cfg["datamodule"]["batch_size"] // n_gpus
68+
logger.info(
69+
f"Auto-scaling batch size by n_gpus={n_gpus}: "
70+
f"{cfg["datamodule"]["batch_size"]} -> {batch_new}")
71+
cfg["datamodule"]["batch_size"] = batch_new
72+
6573
transforms = _inst("transforms")
6674
datamodule = _inst("datamodule", transforms=transforms)
6775
lightningmodule = _inst("lightningmodule", transforms=transforms)

0 commit comments

Comments
 (0)