Skip to content

Commit f9c8c0f

Browse files
authored
Merge pull request #57 from amrit110/feat/simplify-model-config-interface
feat: simplify model_config interface to direct keyword arguments
2 parents 7cb962e + 640e0df commit f9c8c0f

File tree

16 files changed

+273
-152
lines changed

16 files changed

+273
-152
lines changed

docs/api.md

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,11 @@ Most users only need these public functions:
3434

3535
Model-specific configuration:
3636

37-
- `get_embedding(...)` and `get_embeddings_batch(...)` accept `model_config`
38-
- `export_batch(...)` supports per-model `model_config` via `ExportModelRequest(...)`
39-
- currently documented model-level `model_config` usage includes `dofa`, `anysat`, `thor`, and `satmaepp_s2_10b`
40-
- for the currently documented variant-aware models, use a unified field: `model_config={"variant": "..."}`
41-
- valid `variant` values still depend on the selected model and currently exposed published checkpoints, so check the corresponding model detail page
42-
- unsupported `model_config` usage raises `ModelError` instead of being ignored silently
37+
- `get_embedding(...)` and `get_embeddings_batch(...)` accept model-specific settings as direct keyword arguments (e.g. `variant="large"`)
38+
- `export_batch(...)` supports per-model settings via `ExportModelRequest.configure("model", variant="large")`
39+
- currently documented variant-aware models: `dofa`, `anysat`, `thor`, `prithvi`, and `satmaepp_s2_10b`
40+
- valid `variant` values depend on the model — check the corresponding model detail page or call `describe_model(model_id)`
41+
- passing unsupported keyword arguments raises `ModelError`
4342

4443
Sampling / fetch configuration:
4544

docs/api_embedding.md

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ get_embedding(
4242
temporal: Optional[TemporalSpec] = None,
4343
sensor: Optional[SensorSpec] = None,
4444
fetch: Optional[FetchSpec] = None,
45-
model_config: Optional[dict[str, Any]] = None,
4645
modality: Optional[str] = None,
4746
output: OutputSpec = OutputSpec.pooled(),
4847
backend: str = "auto",
4948
device: str = "auto",
5049
input_prep: InputPrepSpec | str = "resize",
50+
**model_kwargs,
5151
) -> Embedding
5252
```
5353

@@ -60,7 +60,7 @@ Computes the embedding for a single ROI.
6060
- `temporal`: `TemporalSpec` or `None`
6161
- `sensor`: input descriptor for on-the-fly models; for most precomputed models this can be `None`
6262
- `fetch`: lightweight sampling override for common cases such as `scale_m`, `cloudy_pct`, `composite`, and `fill_value`
63-
- `model_config`: optional model-specific runtime settings; for the currently documented variant-aware models, use it mainly as `{"variant": ...}`
63+
- `**model_kwargs`: model-specific settings passed as direct keyword arguments (e.g. `variant="large"`); the accepted keys depend on the model — call `describe_model(model_id)` to see the schema
6464
- `modality`: optional model-facing modality selector (for example `s1`, `s2`, `s2_l2a`) for models that expose multiple input branches
6565
- `output`: `OutputSpec.pooled()` or `OutputSpec.grid(...)`
6666
- `backend`: access backend. `backend="auto"` is the public default and the recommended choice. For provider-backed on-the-fly models it resolves to a compatible provider backend; for precomputed models it lets rs-embed choose the model-compatible access path.
@@ -79,17 +79,18 @@ Modality contract:
7979
- Only models that explicitly expose a given modality can use it.
8080
- Unsupported modality selections raise a `ModelError`.
8181

82-
`model_config` contract:
82+
Model-specific settings contract:
8383

84-
- `model_config` is optional and model-specific
85-
- for the currently documented variant-aware models, the public field is unified as `variant`
86-
- examples currently documented in model pages include:
87-
- `dofa`: `{"variant": "base" | "large"}`
88-
- `anysat`: `{"variant": "base"}`
89-
- `thor`: `{"variant": "tiny" | "small" | "base" | "large"}`
90-
- `satmaepp_s2_10b`: `{"variant": "large"}`
91-
- if a model does not document `model_config`, leave it unset; unsupported usage raises `ModelError`
92-
- when available, `describe()["model_config"]` is the machine-readable schema for supported keys and values
84+
- model settings are optional and vary per model
85+
- pass them as direct keyword arguments rather than a dict (e.g. `variant="large"`)
86+
- variant-aware models currently documented:
87+
- `dofa`: `variant="base"` or `variant="large"`
88+
- `anysat`: `variant="base"`
89+
- `thor`: `variant="tiny"`, `"small"`, `"base"`, or `"large"`
90+
- `satmaepp_s2_10b`: `variant="large"`
91+
- `prithvi`: `variant="prithvi_eo_v2_100_tl"`, `"prithvi_eo_v2_300_tl"`, or `"prithvi_eo_v2_600_tl"`
92+
- if a model does not accept any keyword arguments, passing unknown keys raises `ModelError`
93+
- `describe_model(model_id)["model_config"]` is the machine-readable schema for supported keys and values
9394

9495
**Returns**
9596

@@ -113,7 +114,7 @@ emb = get_embedding(
113114
vec = emb.data # (D,)
114115
```
115116

116-
**Example with `model_config`**
117+
**Example with variant selection**
117118

118119
```python
119120
from rs_embed import PointBuffer, TemporalSpec, OutputSpec, get_embedding
@@ -124,7 +125,7 @@ emb = get_embedding(
124125
temporal=TemporalSpec.range("2022-06-01", "2022-09-01"),
125126
output=OutputSpec.pooled(),
126127
backend="gee",
127-
model_config={"variant": "large"},
128+
variant="large",
128129
)
129130
```
130131

@@ -140,12 +141,12 @@ get_embeddings_batch(
140141
temporal: Optional[TemporalSpec] = None,
141142
sensor: Optional[SensorSpec] = None,
142143
fetch: Optional[FetchSpec] = None,
143-
model_config: Optional[dict[str, Any]] = None,
144144
modality: Optional[str] = None,
145145
output: OutputSpec = OutputSpec.pooled(),
146146
backend: str = "auto",
147147
device: str = "auto",
148148
input_prep: InputPrepSpec | str = "resize",
149+
**model_kwargs,
149150
) -> List[Embedding]
150151
```
151152

@@ -182,7 +183,7 @@ embs = get_embeddings_batch(
182183
)
183184
```
184185

185-
**Batch example with `model_config`**
186+
**Batch example with variant selection**
186187

187188
```python
188189
from rs_embed import PointBuffer, TemporalSpec, OutputSpec, get_embeddings_batch
@@ -198,7 +199,7 @@ embs = get_embeddings_batch(
198199
temporal=TemporalSpec.range("2022-01-01", "2023-01-01"),
199200
output=OutputSpec.pooled(),
200201
backend="gee",
201-
model_config={"variant": "base"},
202+
variant="base",
202203
)
203204
```
204205

docs/api_export.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,14 @@ models=[
197197
]
198198
```
199199

200-
`ExportModelRequest(...)` also carries per-model `model_config`, for example:
200+
`ExportModelRequest.configure(...)` also accepts model-specific settings as keyword arguments, for example:
201201

202202
```python
203203
from rs_embed import ExportModelRequest
204204

205205
models=[
206206
"remoteclip",
207-
ExportModelRequest("thor", model_config={"variant": "large"}),
207+
ExportModelRequest.configure("thor", variant="large"),
208208
]
209209
```
210210

@@ -213,7 +213,7 @@ Typical use cases:
213213
- one model needs a different `FetchSpec`
214214
- one model needs `modality="s1"`
215215
- one model needs a different `SensorSpec`
216-
- one model needs a different `model_config` such as `{"variant": "large"}`
216+
- one model needs a different variant (e.g. `variant="large"`)
217217
- one model should override the shared export settings
218218

219219
This also matches the implementation path: string model IDs are first converted into `ExportModelRequest(name=...)`, then resolved.
@@ -224,11 +224,11 @@ Modality rules:
224224
- one model can override it via `ExportModelRequest(...)`
225225
- unsupported modality choices raise `ModelError`
226226

227-
`model_config` rules:
227+
Per-model settings rules:
228228

229-
- `export_batch(...)` does not have one global `model_config` shared across all models
230-
- pass per-model runtime settings through `ExportModelRequest(..., model_config=...)`
231-
- unsupported `model_config` usage raises `ModelError`
229+
- `export_batch(...)` does not have a global model settings parameter shared across all models
230+
- pass per-model settings through `ExportModelRequest.configure("model", variant=...)`
231+
- unsupported keyword arguments raise `ModelError`
232232

233233
---
234234

@@ -328,7 +328,7 @@ from rs_embed import (
328328
export_batch(
329329
spatials=[PointBuffer(121.5, 31.2, 2048)],
330330
temporal=TemporalSpec.range("2022-06-01", "2022-09-01"),
331-
models=[ExportModelRequest("thor", model_config={"variant": "large"})],
331+
models=[ExportModelRequest.configure("thor", variant="large")],
332332
target=ExportTarget.combined("exports/thor_large_run"),
333333
backend="gee",
334334
)
@@ -356,6 +356,6 @@ If provider-backed export is used and both `save_inputs=True` and `save_embeddin
356356

357357
!!! tip "Simple rule"
358358
Start with `ExportTarget.combined(...)` + `ExportConfig()`.
359-
Add `ExportModelRequest(...)` only for the few models that need per-model sensor, modality, or `model_config` overrides.
359+
Add `ExportModelRequest.configure(...)` only for the few models that need per-model sensor, modality, or variant overrides.
360360

361361
---

docs/api_specs.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,12 +355,12 @@ ExportConfig(
355355

356356
ExportModelRequest("remoteclip")
357357
ExportModelRequest("terrafm", modality="s1", sensor=my_s1_sensor)
358-
ExportModelRequest("thor", model_config={"variant": "large"})
358+
ExportModelRequest.configure("thor", variant="large")
359359
```
360360

361361
- `ExportTarget`: where outputs should be written
362362
- `ExportConfig`: how the export should run
363-
- `ExportModelRequest`: optional per-model overrides when one export job mixes different model-specific settings such as sensor, modality, or `model_config`
363+
- `ExportModelRequest`: optional per-model overrides when one export job mixes different model-specific settings such as sensor, modality, or variant; use `ExportModelRequest.configure(...)` to pass model settings as keyword arguments
364364

365365
Legacy `out + layout`, `out_dir` / `out_path`, and per-model dict overrides are still accepted for backward compatibility.
366366

docs/extending.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ class EmbedderBase:
104104
If your model exposes public `model_config` keys, document them in `describe()["model_config"]`
105105
with a JSON-serializable schema.
106106
For model detail docs, surface those public keys near the top as well, for example in the
107-
`Quick Facts` table as `Model config keys | model_config["variant"]`.
107+
`Quick Facts` table as `Model config keys | \`variant\` (default: \`base\`; choices: \`base\`, \`large\`)`.
108108

109109
---
110110

docs/models/anysat.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
| Default resolution | 10m default provider fetch (`sensor.scale_m`) |
1515
| Temporal mode | `range` in practice (adapter normalizes `year`/`None` to range) |
1616
| Output modes | `pooled`, `grid` |
17-
| Model config keys | `model_config["variant"]` (default: `base`; choices: `base`) |
17+
| Model config keys | `variant` (default: `base`; choices: `base`) |
1818
| Extra side inputs | **required** `s2_dates` (per-frame DOY values) |
1919
| Training alignment (adapter path) | Medium (depends on frame count, normalization mode, and image size) |
2020

docs/models/dofa.md

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
| Default resolution | 10m default provider fetch (`sensor.scale_m`) |
1515
| Temporal mode | provider path requires `TemporalSpec.range(...)` |
1616
| Output modes | `pooled`, `grid` |
17-
| Model config keys | `model_config["variant"]` (default: `base`; choices: `base`, `large`) |
17+
| Model config keys | `variant` (default: `base`; choices: `base`, `large`) |
1818
| Extra side inputs | **required** wavelength vector (`wavelengths_um`) |
1919
| Training alignment (adapter path) | Medium-High (when wavelengths and band semantics are correct) |
2020

@@ -110,19 +110,16 @@ Fixed adapter behavior:
110110

111111
Non-env model selection knobs:
112112

113-
- `model_config["variant"]`: `base` / `large` (default: `base`)
113+
- `variant`: `base` / `large` (default: `base`)
114114
- `sensor.bands`: channel semantics for provider fetch and wavelength inference
115115
- `sensor.wavelengths`: explicit wavelength vector (µm)
116116

117-
If `model_config["variant"]` is omitted, rs-embed uses the `base` DOFA checkpoint by default. Set `model_config={"variant": "large"}` to switch to the larger model.
117+
If `variant` is omitted, rs-embed uses the `base` DOFA checkpoint by default. Pass `variant="large"` to switch to the larger model.
118118

119119
Quick reminder:
120120

121-
- DOFA supports `variant` directly through `model_config`
122-
- current public usage is:
123-
- `model_config={"variant": "base"}`
124-
- `model_config={"variant": "large"}`
125-
- for export jobs, pass the same setting via `ExportModelRequest("dofa", model_config={"variant": ...})`
121+
- pass `variant` as a keyword argument directly: `get_embedding("dofa", ..., variant="base")`
122+
- for export jobs, use `ExportModelRequest.configure("dofa", variant="large")`
126123

127124
---
128125

@@ -152,7 +149,6 @@ emb = get_embedding(
152149
"dofa",
153150
spatial=PointBuffer(lon=121.5, lat=31.2, buffer_m=2048),
154151
temporal=TemporalSpec.range("2022-06-01", "2022-09-01"),
155-
model_config={"variant": "base"},
156152
output=OutputSpec.pooled(),
157153
backend="gee",
158154
)
@@ -167,9 +163,9 @@ emb = get_embedding(
167163
"dofa",
168164
spatial=PointBuffer(lon=121.5, lat=31.2, buffer_m=2048),
169165
temporal=TemporalSpec.range("2022-06-01", "2022-09-01"),
170-
model_config={"variant": "large"},
171166
output=OutputSpec.pooled(),
172167
backend="gee",
168+
variant="large",
173169
)
174170
```
175171

docs/models/prithvi.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Prithvi-EO v2 (`prithvi`)
22

3-
> Vendored Prithvi runtime for Sentinel-2 6-band inputs, with required temporal/location coordinate side inputs derived by rs-embed and `model_config["variant"]` support for TL checkpoints.
3+
> Vendored Prithvi runtime for Sentinel-2 6-band inputs, with required temporal/location coordinate side inputs derived by rs-embed and `variant` keyword support for TL checkpoints.
44
55
## Quick Facts
66

@@ -15,7 +15,7 @@
1515
| Default resolution | 30m default provider fetch (`sensor.scale_m`) |
1616
| Temporal mode | `range` preferred; adapter normalizes `year`/`None` to a range |
1717
| Output modes | `pooled`, `grid` |
18-
| Model config keys | `model_config["variant"]` (default: `prithvi_eo_v2_100_tl`) |
18+
| Model config keys | `variant` (default: `prithvi_eo_v2_100_tl`) |
1919
| Extra side inputs | **required** temporal coords + location coords (derived by adapter) |
2020
| Training alignment (adapter path) | Medium (depends on preprocessing mode and resize/pad choices) |
2121

@@ -95,15 +95,15 @@ Default `SensorSpec` if omitted:
9595

9696
---
9797

98-
## `model_config`
98+
## Model-specific Settings
9999

100100
| Key | Type | Default | Choices |
101101
|---|---|---|---|
102102
| `variant` | `string` | `prithvi_eo_v2_100_tl` | `prithvi_eo_v2_100_tl`, `prithvi_eo_v2_300_tl`, `prithvi_eo_v2_600_tl` |
103103

104104
Notes:
105105

106-
- `model_config["variant"]` overrides `RS_EMBED_PRITHVI_KEY`.
106+
- `variant` overrides `RS_EMBED_PRITHVI_KEY`.
107107
- Short aliases `100_tl`, `300_tl`, and `600_tl` are also accepted in code.
108108

109109
---
@@ -148,7 +148,7 @@ emb = get_embedding(
148148
# export RS_EMBED_PRITHVI_PRETRAINED=1
149149
```
150150

151-
### With `model_config["variant"]`
151+
### With variant selection
152152

153153
```python
154154
from rs_embed import get_embedding, PointBuffer, TemporalSpec, OutputSpec
@@ -159,7 +159,7 @@ emb = get_embedding(
159159
temporal=TemporalSpec.range("2022-06-01", "2022-09-01"),
160160
output=OutputSpec.pooled(),
161161
backend="gee",
162-
model_config={"variant": "prithvi_eo_v2_300_tl"},
162+
variant="prithvi_eo_v2_300_tl",
163163
)
164164
```
165165

docs/models/satmaepp.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
| Default resolution | 10m default provider fetch (`sensor.scale_m`) | 10m default provider fetch (`sensor.scale_m`) |
1515
| Temporal mode | range window + single composite | range window + single composite |
1616
| Output modes | `pooled`, `grid` | `pooled`, `grid` |
17-
| Model config keys | none | `model_config["variant"]` (default: `large`; choices: `large`) |
17+
| Model config keys | none | `variant` (default: `large`; choices: `large`) |
1818
| Core extraction | `forward_encoder(mask_ratio=0.0)` | `forward_encoder(mask_ratio=0.0)` |
1919

2020
---
@@ -185,7 +185,7 @@ emb_s2 = get_embedding(
185185
# export RS_EMBED_SATMAEPP_S2_GRID_REDUCE=mean
186186
```
187187

188-
### Example with `model_config`
188+
### Example with variant selection
189189

190190
```python
191191
from rs_embed import get_embedding, PointBuffer, TemporalSpec, OutputSpec
@@ -199,12 +199,12 @@ emb_s2 = get_embedding(
199199
temporal=temporal,
200200
output=OutputSpec.grid(),
201201
backend="gee",
202-
model_config={"variant": "large"},
202+
variant="large",
203203
)
204204
```
205205

206206
For export jobs, the same setting goes through
207-
`ExportModelRequest("satmaepp_s2_10b", model_config={"variant": "large"})`.
207+
`ExportModelRequest.configure("satmaepp_s2_10b", variant="large")`.
208208

209209
---
210210

docs/models/thor.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
| Default resolution | 10m default provider fetch (`sensor.scale_m`) |
1616
| Temporal mode | `range` in practice (composite window) |
1717
| Output modes | `pooled`, `grid` |
18-
| Model config keys | `model_config["variant"]` (default: `base`; choices: `tiny`, `small`, `base`, `large`) |
18+
| Model config keys | `variant` (default: `base`; choices: `tiny`, `small`, `base`, `large`) |
1919
| Extra side inputs | none required in current adapter |
2020
| Training alignment (adapter path) | High when `thor_stats` normalization and default S2 SR setup are preserved |
2121

@@ -99,10 +99,10 @@ Notes:
9999
- `RS_EMBED_THOR_PATCH_SIZE` and `RS_EMBED_THOR_IMG` jointly affect token layout and `ground_cover_m`.
100100
- Changing `group_merge` changes grid channel semantics and dimensionality (especially `concat`).
101101

102-
## `model_config`
102+
## Model-specific Settings
103103

104-
- `model_config["variant"]`: `tiny` / `small` / `base` / `large`
105-
- for export jobs, pass it via `ExportModelRequest("thor", model_config={"variant": ...})`
104+
- `variant`: `tiny` / `small` / `base` / `large`
105+
- for export jobs, pass it via `ExportModelRequest.configure("thor", variant=...)`
106106

107107
Example:
108108

@@ -115,7 +115,7 @@ emb = get_embedding(
115115
temporal=TemporalSpec.range("2022-06-01", "2022-09-01"),
116116
output=OutputSpec.pooled(),
117117
backend="gee",
118-
model_config={"variant": "large"},
118+
variant="large",
119119
)
120120
```
121121

@@ -163,7 +163,7 @@ emb = get_embedding(
163163
# export RS_EMBED_THOR_PATCH_SIZE=16
164164
```
165165

166-
### Example with `model_config`
166+
### Example with variant selection
167167

168168
```python
169169
from rs_embed import get_embedding, PointBuffer, TemporalSpec, OutputSpec
@@ -174,7 +174,7 @@ emb = get_embedding(
174174
temporal=TemporalSpec.range("2022-06-01", "2022-09-01"),
175175
output=OutputSpec.grid(pooling="mean"),
176176
backend="gee",
177-
model_config={"variant": "small"},
177+
variant="small",
178178
)
179179
```
180180

0 commit comments

Comments
 (0)