Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ lightning_logs
wandb
**/test_data/**/**/*.tif
**/project_data
**/.venv
**/.venv*
**/wandb
37 changes: 37 additions & 0 deletions data/yemen_crop/exp_encoder_patch_2_is16.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Encoder patch_size=2, input_size=16 (so 8x8 tokens)
# Need 2x upsample in decoder to match 16x16 targets
rslp_project: 01_06_yemen_crop
rslp_experiment: yemen_crop_ps2_is16

model:
init_args:
model:
init_args:
encoder:
- class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth
init_args:
model_id: OLMOEARTH_V1_BASE
patch_size: 2
decoders:
segment:
- class_path: rslearn.models.conv.Conv
init_args:
in_channels: 768
out_channels: 512
kernel_size: 1
activation:
class_path: torch.nn.GELU
- class_path: rslearn.models.conv.Conv
init_args:
in_channels: 512
out_channels: 9
kernel_size: 1
activation:
class_path: torch.nn.Identity
- class_path: rslearn.models.upsample.Upsample
init_args:
scale_factor: 2
- class_path: rslearn.models.pick_features.PickFeatures
init_args:
indexes: [0]
- class_path: rslearn.train.tasks.segmentation.SegmentationHead
63 changes: 63 additions & 0 deletions data/yemen_crop/exp_encoder_patch_2_is32.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Encoder patch_size=2, input_size=32 (so 16x16 tokens)
# Need 2x upsample in decoder to match 32x32 targets
rslp_project: 01_06_yemen_crop
rslp_experiment: yemen_crop_ps2_is32

model:
init_args:
model:
init_args:
encoder:
- class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth
init_args:
model_id: OLMOEARTH_V1_BASE
patch_size: 2
decoders:
segment:
- class_path: rslearn.models.conv.Conv
init_args:
in_channels: 768
out_channels: 512
kernel_size: 1
activation:
class_path: torch.nn.GELU
- class_path: rslearn.models.conv.Conv
init_args:
in_channels: 512
out_channels: 9
kernel_size: 1
activation:
class_path: torch.nn.Identity
- class_path: rslearn.models.upsample.Upsample
init_args:
scale_factor: 2
- class_path: rslearn.models.pick_features.PickFeatures
init_args:
indexes: [0]
- class_path: rslearn.train.tasks.segmentation.SegmentationHead

data:
init_args:
default_config:
patch_size: 32
transforms:
- class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize
init_args:
band_names:
sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
train_config:
transforms:
- class_path: rslearn.train.transforms.flip.Flip
init_args:
image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"]
- class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize
init_args:
band_names:
sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
groups: ["spatial_split"]
tags:
split: "train"
predict_config:
load_all_patches: true
patch_size: 32
skip_targets: true
32 changes: 32 additions & 0 deletions data/yemen_crop/exp_head_attn_pool.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Attention pooling head
# Requires token_pooling=false to get TokenFeatureMaps from encoder
rslp_project: 01_06_yemen_crop
rslp_experiment: yemen_crop_attn_pool_head

model:
init_args:
model:
init_args:
encoder:
- class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth
init_args:
model_id: OLMOEARTH_V1_BASE
patch_size: 1
token_pooling: false
decoders:
segment:
- class_path: rslearn.models.attention_pooling.AttentionPool
init_args:
in_dim: 768
num_heads: 8
- class_path: rslearn.models.conv.Conv
init_args:
in_channels: 768
out_channels: 9
kernel_size: 1
activation:
class_path: torch.nn.Identity
- class_path: rslearn.models.pick_features.PickFeatures
init_args:
indexes: [0]
- class_path: rslearn.train.tasks.segmentation.SegmentationHead
35 changes: 35 additions & 0 deletions data/yemen_crop/exp_head_deep_mlp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Deep MLP head (768 -> 512 -> 256 -> 9)
rslp_project: 01_06_yemen_crop
rslp_experiment: yemen_crop_deep_mlp_head

model:
init_args:
model:
init_args:
decoders:
segment:
- class_path: rslearn.models.conv.Conv
init_args:
in_channels: 768
out_channels: 512
kernel_size: 1
activation:
class_path: torch.nn.GELU
- class_path: rslearn.models.conv.Conv
init_args:
in_channels: 512
out_channels: 256
kernel_size: 1
activation:
class_path: torch.nn.GELU
- class_path: rslearn.models.conv.Conv
init_args:
in_channels: 256
out_channels: 9
kernel_size: 1
activation:
class_path: torch.nn.Identity
- class_path: rslearn.models.pick_features.PickFeatures
init_args:
indexes: [0]
- class_path: rslearn.train.tasks.segmentation.SegmentationHead
21 changes: 21 additions & 0 deletions data/yemen_crop/exp_head_linear.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Linear head only (single 1x1 conv: 768 -> 9)
rslp_project: 01_06_yemen_crop
rslp_experiment: yemen_crop_linear_head

model:
init_args:
model:
init_args:
decoders:
segment:
- class_path: rslearn.models.conv.Conv
init_args:
in_channels: 768
out_channels: 9
kernel_size: 1
activation:
class_path: torch.nn.Identity
- class_path: rslearn.models.pick_features.PickFeatures
init_args:
indexes: [0]
- class_path: rslearn.train.tasks.segmentation.SegmentationHead
25 changes: 25 additions & 0 deletions data/yemen_crop/exp_head_unet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Full UNet decoder head with encoder patch_size=4
# UNet upsamples from 1/4 resolution back to full resolution
rslp_project: 01_06_yemen_crop
rslp_experiment: yemen_crop_unet_head

model:
init_args:
model:
init_args:
encoder:
- class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth
init_args:
model_id: OLMOEARTH_V1_BASE
patch_size: 4
decoders:
segment:
- class_path: rslearn.models.unet.UNetDecoder
init_args:
in_channels: [[4, 768]]
out_channels: 9
conv_layers_per_resolution: 2
kernel_size: 3
num_channels: {4: 512, 2: 256, 1: 128}
target_resolution_factor: 1
- class_path: rslearn.train.tasks.segmentation.SegmentationHead
29 changes: 29 additions & 0 deletions data/yemen_crop/exp_input_size_32.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Input size 32x32 pixels
rslp_project: 01_06_yemen_crop
rslp_experiment: yemen_crop_is32

data:
init_args:
default_config:
patch_size: 32
transforms:
- class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize
init_args:
band_names:
sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
train_config:
transforms:
- class_path: rslearn.train.transforms.flip.Flip
init_args:
image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"]
- class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize
init_args:
band_names:
sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
groups: ["spatial_split"]
tags:
split: "train"
predict_config:
load_all_patches: true
patch_size: 32
skip_targets: true
29 changes: 29 additions & 0 deletions data/yemen_crop/exp_input_size_8.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Input size 8x8 pixels
rslp_project: 01_06_yemen_crop
rslp_experiment: yemen_crop_is8

data:
init_args:
default_config:
patch_size: 8
transforms:
- class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize
init_args:
band_names:
sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
train_config:
transforms:
- class_path: rslearn.train.transforms.flip.Flip
init_args:
image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"]
- class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize
init_args:
band_names:
sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
groups: ["spatial_split"]
tags:
split: "train"
predict_config:
load_all_patches: true
patch_size: 8
skip_targets: true
25 changes: 25 additions & 0 deletions data/yemen_crop/exp_no_patience_long_freeze.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# No LR patience, longer freeze period before unfreezing encoder
rslp_project: 01_06_yemen_crop
rslp_experiment: yemen_crop_no_patience_long_freeze

model:
init_args:
plateau: false

trainer:
max_epochs: 1000
callbacks:
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: "epoch"
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
save_top_k: 1
save_last: true
monitor: val_segment/macro_f1
mode: max
- class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze
init_args:
module_selector: ["model", "encoder", 0]
unfreeze_at_epoch: 300
unfreeze_lr_factor: 10
Loading
Loading