Skip to content

Commit d464d2d

Browse files
Merge pull request #3 from RadarML/dev
GRT reference / template overhaul
2 parents b66ea7f + 2dc9b3e commit d464d2d

38 files changed

+1712
-1845
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ jobs:
2323
uses: astral-sh/setup-uv@v6
2424
with:
2525
python-version: "3.12"
26-
uv-version: 0.8.13
2726

2827
- name: Lint with ruff
2928
run: |

docs/design.md

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,10 @@ The NRDK framework is based on Pytorch Lightning, and centered around a core [`N
3131

3232
## Data Loading
3333

34-
The NRDK leans heavily on the core [abstract dataloader specifications](https://wiselabcmu.github.io/abstract-dataloader/), which describe a set of modular and composable data loading —
35-
> [`Sensor`][abstract_dataloader.spec.Sensor] → [`Trace`][abstract_dataloader.spec.Trace] → [`Dataset`][abstract_dataloader.spec.Dataset]
36-
37-
— and preprocessing —
38-
> [`Pipeline`][abstract_dataloader.spec.Pipeline] := [`.sample:Transform`][abstract_dataloader.spec.Transform] → [`Collate`][abstract_dataloader.spec.Collate] → [`.batch:Transform`][abstract_dataloader.spec.Transform]
39-
40-
— component interfaces.
34+
The NRDK leans heavily on the core [abstract dataloader specifications](https://wiselabcmu.github.io/abstract-dataloader/), which describe a set of modular and composable data loading and preprocessing component interfaces:
35+
> Load: [`Sensor`][abstract_dataloader.spec.Sensor] → [`Trace`][abstract_dataloader.spec.Trace] → [`Dataset`][abstract_dataloader.spec.Dataset]
36+
> <br>
37+
> Process: [`Pipeline`][abstract_dataloader.spec.Pipeline] := [`.sample:Transform`][abstract_dataloader.spec.Transform] &#8594; [`Collate`][abstract_dataloader.spec.Collate] &#8594; [`.batch:Transform`][abstract_dataloader.spec.Transform]
4138
4239
- When using data collected by [`red-rover`](https://radarml.github.io/red-rover/) and stored in the [`roverd`](https://radarml.github.io/red-rover/roverd/) format, read data using the [`red-rover/roverd`](https://radarml.github.io/red-rover/roverd/) library, which provides `abstract-dataloader` compliant [`Dataset`][roverd.Dataset], [`Trace`][roverd.Trace], and [`Sensor`][roverd.sensors] implementations.
4340

@@ -72,6 +69,26 @@ We currently provide the following objectives:
7269

7370
In addition to implementing the abstract specification, each provided objective includes a type specification for their expected model predictions and ground truth data. Implementations using these objectives out-of-the-box only need to provide data which fits these type interfaces.
7471

72+
## Model & Module Zoo
73+
74+
We include a number of reusable implementations of common (and stable) model modules and architectures with the `nrdk`; note that active research code should go in separate repositories, and only merged here once stable &mdash; and ready to publicly release!
75+
76+
<div class="grid cards" markdown>
77+
78+
- [`nrdk.models`](nrdk/models.md)
79+
80+
---
81+
82+
reusable, stable model architectures.
83+
84+
- [`nrdk.modules`](nrdk/modules.md)
85+
86+
---
87+
88+
modules beyond of the standard library.
89+
90+
</div>
91+
7592

7693
## Other Modules
7794

@@ -91,12 +108,6 @@ The `nrdk` includes a number of other submodules intended as reusable libraries
91108

92109
training and evaluation metrics
93110

94-
- [`nrdk.modules`](nrdk/modules.md)
95-
96-
---
97-
98-
pytorch [`nn.Module`][torch.nn.Module] implementations outside of the standard library
99-
100111
- [`nrdk.visualization`](nrdk/vis.md)
101112

102113
---

docs/grt/config.md

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Included Configurations
2+
3+
Hydra is organized around "configuration groups", which are collections of related configuration files. The reference training script is organized around several configuration groups which are intended to be used in routine experiments, as well as some others which should not generally need to be modified (`hydra`, `lightningmodule`).
4+
5+
## `meta` &mdash; Metadata
6+
7+
```sh
8+
uv run train.py meta.name=example_experiment meta.version=test_version_0 resume=path/to/previous/checkpoint.ckpt
9+
```
10+
11+
Experiment metadata is configured using the `meta` config group, including an optional `resume` field which can be used to resume training from a previous checkpoint (with identical model settings).
12+
13+
!!! warning
14+
15+
Make sure to set the `meta.name` and `meta.version` fields!
16+
17+
## `size` &mdash; Model Dimensions
18+
19+
```sh
20+
uv run train.py size.d_model=256 size.d_feedforward=1024 size.nhead=4 size.enc_layers=3 size.dec_layers=3
21+
```
22+
23+
Certain model dimensions can be configured globally using a `size` config group, which is referenced by other model configurations:
24+
25+
```yaml
26+
size:
27+
d_model: 512
28+
d_feedforward: 2048
29+
nhead: 8
30+
enc_layers: 4
31+
dec_layers: 4
32+
```
33+
34+
## `base` &mdash; Base Model
35+
36+
```sh
37+
uv run train.py +base=occ3d_to_semseg
38+
```
39+
40+
Load a base model using the specified configuration; see [`NRDKLightningModule.load_weights`][nrdk.framework.NRDKLightningModule.load_weights] for details about how to configure this behavior.
41+
42+
=== "Base &rarr; 2D Occupancy"
43+
44+
```sh
45+
--8<-- "grt/config/base/occ3d_to_occ2d.yaml"
46+
```
47+
48+
=== "Base &rarr; Semseg"
49+
50+
```sh
51+
--8<-- "grt/config/base/occ3d_to_semseg.yaml"
52+
```
53+
54+
=== "Base &rarr; Odometry"
55+
56+
```sh
57+
--8<-- "grt/config/base/occ3d_to_vel.yaml"
58+
```
59+
60+
## `datamodule` &mdash; Dataloader
61+
62+
To configure the sensors to load:
63+
```sh
64+
uv run train.py sensors@datamodule.dataset.sensors=[radar,lidar,semseg,pose]
65+
```
66+
67+
To configure which traces to include in the dataset:
68+
```sh
69+
uv run train.py traces@datamodule.traces=[bike,indoor,outdoor]
70+
```
71+
72+
See [`nrdk.roverd.datamodule`][nrdk.roverd.datamodule] for more details about the dataloader configuration.
73+
74+
## `model` &mdash; Model Architecture
75+
76+
```sh
77+
uv run train.py model=default
78+
uv run train.py model/decoder=semseg
79+
```
80+
81+
The model is just any `torch.nn.Module`; the included [`TokenizerEncoderDecoder`][nrdk.framework.TokenizerEncoderDecoder] is a good starting point.
82+
83+
!!! warning
84+
85+
The reference configs are built around [`TokenizerEncoderDecoder`][nrdk.framework.TokenizerEncoderDecoder], with the tokenizer, encoder, and decoder as nested sub-configs.
86+
87+
Note that these sub-configs must be specified as `model/decoder=...`, not `model.decoder=...`!
88+
89+
---
90+
91+
## `objective` &mdash; Training Objective
92+
93+
```sh
94+
uv run train.py objective=lidar3d
95+
```
96+
97+
Training objectives are expected to implement the the [`Objective`][abstract_dataloader.ext.objective] interface.
98+
99+
## `transforms` &mdash; Data Processing
100+
101+
```sh
102+
uv run train.py transforms@transforms.sample=[radar,lidar3d]
103+
```
104+
105+
## `lightningmodule` &mdash; Training Loop
106+
107+
The default `lightningmodule` config should not need to be modified, and pulls in the `${objective}` and `{$model}` configs.

docs/grt/index.md

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,64 @@
1-
# GRT Reference Implementation
1+
# GRT Reference Implementation & Project Template
22

3-
!!! warning "Required Extras"
3+
The GRT reference implementation uses a [hydra](https://hydra.cc/docs/intro/) + [pytorch lightning](https://lightning.ai/docs/pytorch/stable/)-based stack on top of the NRDK; use this reference implementation to get started on a new project.
44

5-
The GRT reference implementation requires the following extras:
5+
!!! tip
66

7-
- [`roverd`](https://radarml.github.io/red-rover/roverd/): a dataloader for data collected by the [red-rover system](https://radarml.github.io/red-rover/)
8-
- [`xwr`](https://radarml.github.io/xwr/): radar signal processing for TI mmWave radars
7+
The included reference implementation can be run out-of-the-box:
98

10-
Install/run with with
11-
```sh
12-
uv sync --extra xwr --extra roverd
13-
uv run --extra xwr --extra roverd train.py ...
14-
# or just
15-
uv sync --all-extras
16-
```
9+
1. Obtain a copy of the [I/Q-1M](https://radarml.github.io/red-rover/iq1m/), and save it (or link it) to `nrdk/grt/data/`.
10+
2. Create a virtual environment in `nrdk/grt` with `uv sync`.
11+
3. Run with `uv run train.py`; see the hydra config files in `nrdk/grt/config/` for options.
12+
13+
## Quick Start
14+
15+
1. Create a new repository, and copy the contents of the `grt/` directory:
16+
```
17+
example-project/
18+
├── config/
19+
│ ├── aug/
20+
| ...
21+
├── grt/
22+
│ ├── __init__.py
23+
| ...
24+
├── pyproject.toml
25+
├── train.py
26+
└── train_minimal.py
27+
```
28+
29+
!!! tip
30+
31+
Don't forget to change the `name`, `authors`, and `description`!
32+
33+
2. Set up the `nrdk` dependency.
34+
35+
!!! warning "Required Extras"
36+
37+
Make sure you include the `roverd` extra, which installs the following:
38+
39+
- [`roverd`](https://radarml.github.io/red-rover/roverd/): a dataloader for data collected by the [red-rover system](https://radarml.github.io/red-rover/)
40+
- [`xwr`](https://radarml.github.io/xwr/): radar signal processing for TI mmWave radars
41+
42+
If using `uv`, uncomment one of the corresponding lines in the supplied `pyproject.toml` (and comment out the included `nrdk = { path = "../" }` line):
43+
44+
=== "Via Github"
45+
46+
```toml
47+
[tool.uv.sources]
48+
nrdk = { git = "ssh://git@github.com/radarml/nrdk.git" }
49+
```
50+
51+
=== "Via Submodule"
52+
53+
After `git submodule add git@github.com:RadarML/nrdk.git`:
54+
```toml
55+
[tool.uv.sources]
56+
nrdk = { path = "./nrdk" }
57+
```
58+
59+
## Training Script
60+
61+
The GRT template includes reference training scripts which can be used for high level training and fine tuning control flow. You can use these scripts as-is, or modify them to suit your needs; where possible, stick to the same general structure to maintain compatibility.
1762

1863
??? quote "Reference Training Script"
1964

@@ -26,16 +71,3 @@
2671
```python title="grt/train_minimal.py"
2772
--8<-- "grt/train_minimal.py"
2873
```
29-
30-
Current train command (WIP):
31-
```sh
32-
uv run grt/train_minimal.py +objectives@objectives=lidar2d decoder@model.decoder=lidar2d sensors=[radar,lidar2d] aug@transforms.sample.augmentations=full
33-
```
34-
35-
```sh
36-
uv run grt/train.py trainer=debug globals.d_feedforward=1024 model/decoder=semseg --cfg job
37-
```
38-
39-
```sh
40-
uv run python grt/train.py sensors@datamodule.dataset.sensors=[radar,semseg] transforms@transforms.sample=[radar,semseg] objective=semseg model/decoder@lightningmodule.model.decoder=semseg +base=occ3d_to_semseg meta.name=example meta.version=version_0
41-
```

docs/nrdk/models.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:::nrdk.models

docs/tss/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
Use a effective sample size-corrected paired z-test to compare methods which have been evaluated on moderately-sized time series data.
66

7-
The neural radar development kit ships with a small statistical module (`nrdk.tss`) & CLI tool (`tss`) for analyzing time series performance metrics.
7+
The neural radar development kit ships with a small, self-contained statistical module (`nrdk.tss`) & CLI tool (`tss`) for analyzing time series performance metrics.
88

99
!!! warning "Time-Correlated Samples"
1010

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
path: results/example/baseline3/weights.pth
1+
path: results/example/base/weights.pth
22
rename:
33
- decoder.occ3d.unpatch: null
44
- decoder.occ3d: decoder.occ2d
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
path: results/example/baseline3/weights.pth
1+
path: results/example/base/weights.pth
22
rename:
33
- decoder.occ3d.unpatch: null
44
- decoder.occ3d: decoder.semseg

grt/config/base/occ3d_to_vel.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
path: results/example/baseline3/weights.pth
1+
path: results/example/base/weights.pth
22
rename:
33
- decoder: null

grt/config/default.yaml

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
defaults:
2+
- _self_
3+
- size: small
24
- aug@transforms.sample.augmentations: full
35
- sensors@datamodule.dataset.sensors: ["radar", "lidar"]
46
- transforms@transforms.sample: ["radar", "lidar3d"]
57
- traces@datamodule.traces: ["bike", "indoor", "outdoor"]
6-
- model@lightningmodule.model: default
8+
- model: default
79
- trainer: default
810
- objective: lidar3d
9-
- _self_
1011

1112
meta:
1213
dataset: data/
@@ -15,12 +16,7 @@ meta:
1516
version: null
1617
resume: null
1718

18-
globals:
19-
d_model: 512
20-
d_feedforward: 2048
21-
nhead: 8
22-
enc_layers: 4
23-
dec_layers: 4
19+
size: {}
2420

2521
hydra:
2622
run:
@@ -42,18 +38,24 @@ datamodule:
4238
path: ${meta.dataset}
4339
batch_size: 32
4440
samples: 8
45-
num_workers: 12
46-
prefetch_factor: 2
41+
num_workers: 16
42+
prefetch_factor: 1
4743
subsample:
4844
val: 16384
4945
ptrain: 0.8
5046
pval: 0.2
5147

48+
model: {}
49+
50+
objective: {}
51+
5252
lightningmodule:
5353
_target_: nrdk.framework.NRDKLightningModule
54-
vis_interval: 10
54+
compile: false
55+
vis_interval: 1000
5556
vis_samples: 16
5657
objective: ${objective}
58+
model: ${model}
5759
optimizer:
5860
_target_: torch.optim.AdamW
5961
_partial_: true
@@ -69,5 +71,5 @@ transforms:
6971
keep_all: false
7072
outputs: {} # sensors/*.yaml
7173
collate:
72-
_target_: abstract_dataloader.torch.Collate
74+
_target_: abstract_dataloader.ext.torch.Collate
7375
mode: "concat"

0 commit comments

Comments
 (0)