Skip to content

Commit ace7e63

Browse files
committed
feat:anysat dense/tile mode support
1 parent 09a2447 commit ace7e63

File tree

8 files changed

+379
-25
lines changed

8 files changed

+379
-25
lines changed

docs/assets/vis.png

181 KB
Loading

docs/models/anysat.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# AnySat (`anysat`)
22

3-
> Multi-frame Sentinel-2 time-series adapter that builds AnySat inputs (`s2` + `s2_dates`) from a temporal window and returns patch-grid features or pooled vectors.
3+
> Multi-frame Sentinel-2 time-series adapter that builds AnySat inputs (`s2` + `s2_dates`) from a temporal window and returns dense sub-patch grids by default or pooled vectors.
44
55
## Quick Facts
66

@@ -53,11 +53,13 @@ For `input_chw`, the adapter accepts either `CHW` or `TCHW` with `C=10`. A `CHW`
5353
<span class="pipeline-arrow">-&gt;</span> build AnySat side inputs
5454
<span class="pipeline-branch">s2:</span> [1,T,10,H,W]
5555
<span class="pipeline-branch">s2_dates:</span> [1,T] from frame-bin DOY midpoints
56-
<span class="pipeline-arrow">-&gt;</span> forward with output="patch" and patch_size=sensor.scale_m
56+
<span class="pipeline-arrow">-&gt;</span> forward with AnySat spatial output
57+
<span class="pipeline-branch">grid path:</span> default `output="dense"` (`grid_feature_mode="dense"`)
58+
<span class="pipeline-branch">pooled path:</span> default compatibility path keeps `output="patch"`; optional `pooled_source="tile"` uses native AnySat tile output
5759
<span class="pipeline-arrow">-&gt;</span> map [B,H,W,D] -&gt; rs-embed grid [D,H,W]
5860
<span class="pipeline-arrow">-&gt;</span> output projection
59-
<span class="pipeline-branch">pooled:</span> spatial mean / max over grid
60-
<span class="pipeline-branch">grid:</span> model patch grid</code></pre>
61+
<span class="pipeline-branch">pooled:</span> spatial mean / max over patch grid
62+
<span class="pipeline-branch">grid:</span> dense sub-patch grid by default (or patch grid when overridden)</code></pre>
6163

6264
Important constraint:
6365

@@ -73,6 +75,8 @@ Important constraint:
7375
| `RS_EMBED_ANYSAT_IMG` | `24` | Per-frame resize target (square) |
7476
| `RS_EMBED_ANYSAT_NORM` | `per_tile_zscore` | Series normalization mode |
7577
| `RS_EMBED_ANYSAT_MODEL_SIZE` | `base` | AnySat model size |
78+
| `RS_EMBED_ANYSAT_GRID_MODE` | `dense` | Grid path native AnySat spatial output (`dense` or `patch`) |
79+
| `RS_EMBED_ANYSAT_POOLED_SOURCE` | `patch` | Pooled path source (`patch` compatibility pooling or native `tile`) |
7680
| `RS_EMBED_ANYSAT_FLASH_ATTN` | `0` | Enable flash attention path if supported |
7781
| `RS_EMBED_ANYSAT_PRETRAINED` | `1` | Load pretrained checkpoint weights |
7882
| `RS_EMBED_ANYSAT_CKPT` | unset | Local checkpoint override |
@@ -85,7 +89,7 @@ Important constraint:
8589

8690
## Output Semantics
8791

88-
AnySat follows the standard patch-grid pattern for multi-frame adapters. `pooled` applies spatial pooling over the patch grid, and `grid` returns `(D,H,W)` in model patch space rather than georeferenced raster pixels. The more distinctive AnySat details, such as frame packaging and `doy0_values`, are recorded in metadata rather than requiring a long per-page output section.
92+
AnySat now uses two spatial output paths inside the adapter. `pooled` defaults to the historical rs-embed behavior and applies spatial pooling over the AnySat `patch` grid, which preserves the previous pooled vector dimensionality; pass `pooled_source="tile"` (or set `RS_EMBED_ANYSAT_POOLED_SOURCE=tile`) to use the native AnySat tile embedding instead. `grid` defaults to AnySat `dense`, so the returned `(D,H,W)` is a denser sub-patch feature map by default; pass `grid_feature_mode="patch"` to the public API (or set `RS_EMBED_ANYSAT_GRID_MODE=patch`) to recover the older patch-grid behavior. As with other on-the-fly models, this grid is model space rather than guaranteed georeferenced raster pixels.
8993

9094
---
9195

docs/models_reference.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Use this table for a first-pass side-by-side comparison of input assumptions and
4848
| `satmaepp` | `rshf.satmaepp.SatMAEPP` | 10m | S2 RGB (`B4,B3,B2`) | raw SR -> `/10000` -> RGB `uint8`; SatMAE++ fMoW eval preprocessing (`Normalize + Resize(short side) + CenterCrop`), default channel order `bgr` | default 224; source-aligned short-side resize + center crop; no pad | token sequence -> pooled or patch-token grid | High |
4949
| `satmaepp_s2_10b` | SatMAE++ grouped-channel source branch (`models_mae_group_channels.py`, `base` / `large` runtime families) | 10m | S2 SR 10-band (`B2,B3,B4,B5,B6,B7,B8,B8A,B11,B12`) | clip `0..10000`; source Sentinel min/max mapping to `uint8`; `ToTensor + Resize(short side) + CenterCrop` | default 96 with patch size 8; source-style resize/crop; no pad | grouped token sequence -> pooled or group-reduced spatial token grid | High |
5050
| `scalemae` | `rshf.scalemae.ScaleMAE` (ViT style) | 10m | S2 RGB (`B4,B3,B2`) + `input_res_m` | raw SR -> `/10000` -> RGB `uint8`; CLIP norm tensor; pass `input_res_m` | default 224; CLIP path has `Resize + CenterCrop`; no pad | token sequence or pooled vector depending on wrapper output | Medium |
51-
| `anysat` | AnySat from upstream `hubconf.py` (`AnySat`, `tiny` / `small` / `base`) | 10m | S2 10-band TCHW (or CHW auto-expanded) | clip to `0..10000`; normalize mode default `per_tile_zscore`; builds per-frame `s2_dates` | resize TCHW to default 24; no crop, no pad | patch output `[D,H,W]`, pooled by spatial mean/max | Medium |
51+
| `anysat` | AnySat from upstream `hubconf.py` (`AnySat`, `tiny` / `small` / `base`) | 10m | S2 10-band TCHW (or CHW auto-expanded) | clip to `0..10000`; normalize mode default `per_tile_zscore`; builds per-frame `s2_dates` | resize TCHW to default 24; no crop, no pad | grid defaults to dense sub-patch output `[D,H,W]`; pooled defaults to patch-grid mean/max, optional native tile vector | Medium |
5252
| `galileo` | `Encoder` from official `single_file_galileo.py` | 10m | S2 10-band TCHW (or CHW auto-expanded) | clip to `0..10000`; normalize mode default `unit_scale`; constructs Galileo tensors with configurable `T` + per-frame `months`, optional NDVI channel | default 64 with patch 8; bilinear resize; no pad | pooled token vector and S2-group token grid | Medium |
5353
| `wildsat` | WildSAT backbone + optional image head from checkpoint | 10m | S2 RGB CHW | clip to `0..10000` then `/10000`; default normalization `minmax`; convert to `uint8` then unit tensor | default 224; resize RGB; no pad | pooled branch output and optional grid (token or feature path) | Medium-Low |
5454
| `prithvi` | Vendored `PrithviMAE` runtime with HF checkpoints | 30m | S2 6-band (`BLUE,GREEN,RED,NIR_NARROW,SWIR_1,SWIR_2`) | raw SR -> `/10000` -> clamp `[0,1]`; prep mode from env | default mode `resize` to 224; optional `pad` to patch multiple (legacy) | token sequence -> pooled or patch-token grid | Medium |
@@ -126,7 +126,7 @@ This table only lists env vars that materially change model input construction o
126126
| `satmaepp` | `RS_EMBED_SATMAEPP_ID`, `RS_EMBED_SATMAEPP_IMG`, `RS_EMBED_SATMAEPP_CHANNEL_ORDER`, `RS_EMBED_SATMAEPP_BGR` |
127127
| `satmaepp_s2_10b` | `RS_EMBED_SATMAEPP_S2_CKPT_REPO`, `RS_EMBED_SATMAEPP_S2_CKPT_FILE`, `RS_EMBED_SATMAEPP_S2_MODEL_FN`, `RS_EMBED_SATMAEPP_S2_IMG`, `RS_EMBED_SATMAEPP_S2_PATCH`, `RS_EMBED_SATMAEPP_S2_GRID_REDUCE`, `RS_EMBED_SATMAEPP_S2_WEIGHTS_ONLY` |
128128
| `scalemae` | `RS_EMBED_SCALEMAE_IMG` |
129-
| `anysat` | `RS_EMBED_ANYSAT_IMG`, `RS_EMBED_ANYSAT_NORM`, `RS_EMBED_ANYSAT_FRAMES` |
129+
| `anysat` | `RS_EMBED_ANYSAT_IMG`, `RS_EMBED_ANYSAT_NORM`, `RS_EMBED_ANYSAT_FRAMES`, `RS_EMBED_ANYSAT_GRID_MODE`, `RS_EMBED_ANYSAT_POOLED_SOURCE` |
130130
| `galileo` | `RS_EMBED_GALILEO_IMG`, `RS_EMBED_GALILEO_PATCH`, `RS_EMBED_GALILEO_NORM`, `RS_EMBED_GALILEO_INCLUDE_NDVI`, `RS_EMBED_GALILEO_FRAMES`, `RS_EMBED_GALILEO_MONTH` |
131131
| `wildsat` | `RS_EMBED_WILDSAT_IMG`, `RS_EMBED_WILDSAT_NORM` |
132132
| `prithvi` | `RS_EMBED_PRITHVI_PREP`, `RS_EMBED_PRITHVI_IMG`, `RS_EMBED_PRITHVI_PATCH_MULT` |

examples/playground.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@
656656
" ], # or you can spcify as [\"remoteclip\", \"terrafm\"]\n",
657657
" target=ExportTarget.per_item(\"exports\", names=[\"p1\", \"p2\"]),\n",
658658
" output=OutputSpec.grid(),\n",
659-
" config=ExportConfig(save_inputs=True, input_prep='tile'),\n",
659+
" config=ExportConfig(save_inputs=True, input_prep=\"tile\"),\n",
660660
" backend=\"gee\",\n",
661661
")"
662662
]

src/rs_embed/embedders/onthefly_anysat.py

Lines changed: 140 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,38 @@ def _normalize_anysat_model_size(model_size: Any) -> str:
101101
return resolved
102102

103103

104+
def _normalize_anysat_grid_feature_mode(value: Any) -> str:
105+
raw = str(value).strip().lower()
106+
aliases = {
107+
"dense": "dense",
108+
"subpatch": "dense",
109+
"subpatches": "dense",
110+
"patch": "patch",
111+
"patches": "patch",
112+
}
113+
resolved = aliases.get(raw)
114+
if resolved is None:
115+
raise ModelError(
116+
f"Unknown AnySat grid_feature_mode={value!r} (expected one of: dense, patch)."
117+
)
118+
return resolved
119+
120+
121+
def _normalize_anysat_pooled_source(value: Any) -> str:
122+
raw = str(value).strip().lower()
123+
aliases = {
124+
"patch": "patch",
125+
"grid": "patch",
126+
"tile": "tile",
127+
"native": "tile",
128+
"model": "tile",
129+
}
130+
resolved = aliases.get(raw)
131+
if resolved is None:
132+
raise ModelError(f"Unknown AnySat pooled_source={value!r} (expected one of: patch, tile).")
133+
return resolved
134+
135+
104136
def _resolve_anysat_runtime_config(
105137
*,
106138
model_config: dict[str, Any] | None,
@@ -129,6 +161,16 @@ def _resolve_anysat_runtime_config(
129161

130162
norm_mode = os.environ.get("RS_EMBED_ANYSAT_NORM", "per_tile_zscore").strip()
131163

164+
grid_feature_mode_v = model_config_value(model_config, "grid_feature_mode")
165+
if grid_feature_mode_v is None:
166+
grid_feature_mode_v = os.environ.get("RS_EMBED_ANYSAT_GRID_MODE", "dense")
167+
grid_feature_mode = _normalize_anysat_grid_feature_mode(grid_feature_mode_v)
168+
169+
pooled_source_v = model_config_value(model_config, "pooled_source")
170+
if pooled_source_v is None:
171+
pooled_source_v = os.environ.get("RS_EMBED_ANYSAT_POOLED_SOURCE", "patch")
172+
pooled_source = _normalize_anysat_pooled_source(pooled_source_v)
173+
132174
ckpt_path = os.environ.get("RS_EMBED_ANYSAT_CKPT")
133175
ckpt_path = ckpt_path or None
134176

@@ -156,6 +198,8 @@ def _resolve_anysat_runtime_config(
156198
"image_size": int(image_size),
157199
"n_frames": int(n_frames),
158200
"norm_mode": norm_mode,
201+
"grid_feature_mode": grid_feature_mode,
202+
"pooled_source": pooled_source,
159203
"ckpt_path": ckpt_path,
160204
"hf_repo": hf_repo,
161205
"hf_filename": hf_filename,
@@ -385,11 +429,12 @@ def _prepare_anysat_s2_input(
385429
}
386430

387431

388-
def _anysat_patch_features(
432+
def _anysat_spatial_features(
389433
model: Any,
390434
s2_input: dict[str, Any],
391435
*,
392436
patch_size_m: int,
437+
feature_mode: str,
393438
) -> tuple[np.ndarray, dict[str, Any]]:
394439
ensure_torch()
395440
import torch
@@ -399,27 +444,73 @@ def _anysat_patch_features(
399444
f"AnySat patch_size must be a positive multiple of 10 (meters), got {patch_size_m}"
400445
)
401446

447+
native_output = str(feature_mode).strip().lower()
448+
if native_output not in {"patch", "dense"}:
449+
raise ModelError(
450+
f"Unknown AnySat feature_mode={feature_mode!r} (expected 'patch' or 'dense')."
451+
)
452+
402453
with torch.no_grad():
403-
out = model(s2_input, patch_size=int(patch_size_m), output="patch")
454+
out = model(s2_input, patch_size=int(patch_size_m), output=native_output)
404455

405456
if not hasattr(out, "ndim") or int(out.ndim) != 4:
406457
raise ModelError(
407-
f"AnySat output='patch' returned unexpected shape/type: {type(out)} {getattr(out, 'shape', None)}"
458+
"AnySat output="
459+
f"{native_output!r} returned unexpected shape/type: {type(out)} {getattr(out, 'shape', None)}"
408460
)
409461

410-
# AnySat patch output: [B,H,W,D]
462+
# AnySat spatial outputs are [B,H,W,D] for both patch and dense.
411463
if int(out.shape[0]) != 1:
412464
raise ModelError(f"AnySat embedder expects B=1 per call, got {tuple(out.shape)}")
413465
arr = out[0].detach().float().cpu().numpy().astype(np.float32) # [H,W,D]
414466
grid = np.transpose(arr, (2, 0, 1)).astype(np.float32) # [D,H,W]
415467
meta = {
416-
"patch_output_hw": (int(arr.shape[0]), int(arr.shape[1])),
468+
"feature_mode": native_output,
469+
"native_output_hw": (int(arr.shape[0]), int(arr.shape[1])),
417470
"feature_dim": int(arr.shape[2]),
418471
"patch_size_m": int(patch_size_m),
419472
}
473+
if native_output == "patch":
474+
meta["patch_output_hw"] = meta["native_output_hw"]
475+
else:
476+
meta["dense_output_hw"] = meta["native_output_hw"]
420477
return grid, meta
421478

422479

480+
def _anysat_tile_features(
481+
model: Any,
482+
s2_input: dict[str, Any],
483+
*,
484+
patch_size_m: int,
485+
) -> tuple[np.ndarray, dict[str, Any]]:
486+
ensure_torch()
487+
import torch
488+
489+
if patch_size_m <= 0 or (patch_size_m % 10) != 0:
490+
raise ModelError(
491+
f"AnySat patch_size must be a positive multiple of 10 (meters), got {patch_size_m}"
492+
)
493+
494+
with torch.no_grad():
495+
out = model(s2_input, patch_size=int(patch_size_m), output="tile")
496+
497+
if not hasattr(out, "ndim") or int(out.ndim) != 2:
498+
raise ModelError(
499+
"AnySat output='tile' returned unexpected shape/type: "
500+
f"{type(out)} {getattr(out, 'shape', None)}"
501+
)
502+
if int(out.shape[0]) != 1:
503+
raise ModelError(f"AnySat embedder expects B=1 per call, got {tuple(out.shape)}")
504+
505+
vec = out[0].detach().float().cpu().numpy().astype(np.float32)
506+
meta = {
507+
"pooled_source": "tile",
508+
"feature_dim": int(vec.shape[0]),
509+
"patch_size_m": int(patch_size_m),
510+
}
511+
return vec, meta
512+
513+
423514
@register("anysat")
424515
class AnySatEmbedder(EmbedderBase):
425516
DEFAULT_FETCH_WORKERS = 8
@@ -458,19 +549,31 @@ def describe(self) -> dict[str, Any]:
458549
"cloudy_pct": self.input_spec.cloudy_pct,
459550
"composite": self.input_spec.composite,
460551
"normalization": "per_tile_zscore",
552+
"grid_feature_mode": "dense",
553+
"pooled_source": "patch",
461554
},
462555
"model_config": {
463556
"variant": {
464557
"type": "string",
465558
"default": "base",
466559
"choices": ["base"],
467-
}
560+
},
561+
"grid_feature_mode": {
562+
"type": "string",
563+
"default": "dense",
564+
"choices": ["dense", "patch"],
565+
},
566+
"pooled_source": {
567+
"type": "string",
568+
"default": "patch",
569+
"choices": ["patch", "tile"],
570+
},
468571
},
469572
"notes": [
470573
"AnySat expects S2 time-series + day-of-year dates.",
471574
"This adapter builds T frames by splitting TemporalSpec.range into equal sub-windows.",
472575
"Loads AnySat from a vendored local runtime and optional Hugging Face checkpoint.",
473-
"grid output maps AnySat output='patch' to [D,H,W].",
576+
"grid output defaults to AnySat output='dense'; pooled defaults to patch-grid pooling but can opt into native tile output.",
474577
],
475578
}
476579

@@ -514,6 +617,8 @@ def get_embedding(
514617
flash_attn = bool(runtime_cfg["flash_attn"])
515618
image_size = int(runtime_cfg["image_size"])
516619
norm_mode = str(runtime_cfg["norm_mode"])
620+
grid_feature_mode = str(runtime_cfg["grid_feature_mode"])
621+
pooled_source = str(runtime_cfg["pooled_source"])
517622
patch_size_m = int(ss.scale_m)
518623

519624
if input_chw is None:
@@ -557,11 +662,22 @@ def get_embedding(
557662
norm_mode=norm_mode,
558663
device=dev,
559664
)
560-
grid, fmeta = _anysat_patch_features(
561-
model,
562-
s2_input,
563-
patch_size_m=patch_size_m,
564-
)
665+
feature_mode = grid_feature_mode if output.mode == "grid" else "patch"
666+
pooled_vec = None
667+
if output.mode == "pooled" and pooled_source == "tile":
668+
pooled_vec, fmeta = _anysat_tile_features(
669+
model,
670+
s2_input,
671+
patch_size_m=patch_size_m,
672+
)
673+
grid = None
674+
else:
675+
grid, fmeta = _anysat_spatial_features(
676+
model,
677+
s2_input,
678+
patch_size_m=patch_size_m,
679+
feature_mode=feature_mode,
680+
)
565681

566682
meta = build_meta(
567683
model=self.model_name,
@@ -582,6 +698,8 @@ def get_embedding(
582698
"model_size": model_size,
583699
"flash_attn": bool(flash_attn),
584700
"normalization": norm_mode,
701+
"grid_feature_mode": grid_feature_mode,
702+
"pooled_source": pooled_source,
585703
"start": t.start,
586704
"end": t.end,
587705
"n_frames": int(raw_tchw.shape[0]),
@@ -595,21 +713,29 @@ def get_embedding(
595713
)
596714

597715
if output.mode == "pooled":
716+
if pooled_vec is not None:
717+
ometa = {
718+
**meta,
719+
"pooling": "tile_native",
720+
"requested_pooling": output.pooling,
721+
"pooled_shape": tuple(pooled_vec.shape),
722+
}
723+
return Embedding(data=pooled_vec, meta=ometa)
598724
if output.pooling == "max":
599725
vec = np.max(grid, axis=(1, 2)).astype(np.float32)
600726
else:
601727
vec = np.mean(grid, axis=(1, 2)).astype(np.float32)
602728
ometa = {
603729
**meta,
604-
"pooling": f"patch_{output.pooling}",
730+
"pooling": f"{feature_mode}_{output.pooling}",
605731
"pooled_shape": tuple(vec.shape),
606732
}
607733
return Embedding(data=vec, meta=ometa)
608734

609735
if output.mode == "grid":
610736
gmeta = {
611737
**meta,
612-
"grid_kind": "patch_tokens",
738+
"grid_kind": "dense_subpatch" if feature_mode == "dense" else "patch_tokens",
613739
"grid_hw": (int(grid.shape[1]), int(grid.shape[2])),
614740
"grid_shape": tuple(grid.shape),
615741
}

0 commit comments

Comments
 (0)