
# Next Steps: Hyperparameter Tuning & Multi-Project Expansion

## Current Baseline (CGWAVES, 12 variables, lite model)

| Metric | Value |
|---|---|
| Best val loss | 3.22 |
| Train loss | 2.48 |
| Variables | 12 |
| Model params | ~2-3M |
| Epoch time (T4) | ~1-2 min |
| Probing accuracy | 90-100% |
| Flight artifact leakage | None (silhouette -0.21) |

## Phase 1: Hyperparameter Optimization

### 1.1 Model Architecture Experiments

Test these configurations sequentially, keeping the best performers:

| Experiment | d_model | Enc layers | Dec layers | Heads | Expected Impact |
|---|---|---|---|---|---|
| Baseline (current) | 128 | 4 | 2 | 4 | - |
| Wider | 192 | 4 | 2 | 6 | Better representation capacity |
| Deeper encoder | 128 | 6 | 2 | 4 | More complex patterns |
| Balanced | 192 | 4 | 3 | 6 | Better reconstruction |
| Full (original) | 256 | 6 | 4 | 8 | Maximum capacity |

Success criteria: Val loss < 3.0 without increasing flight artifact leakage.
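The grid above can be driven from one list of configs so runs stay comparable. A minimal sketch; the `train_eval` callable (taking the four table columns and returning a val loss) is a hypothetical stand-in for the notebook's training loop:

```python
# Hypothetical sweep driver; `train_eval` is assumed to train one model
# with the given architecture and return its best validation loss.
EXPERIMENTS = [
    # (name, d_model, enc_layers, dec_layers, heads)
    ("baseline", 128, 4, 2, 4),
    ("wider",    192, 4, 2, 6),
    ("deeper",   128, 6, 2, 4),
    ("balanced", 192, 4, 3, 6),
    ("full",     256, 6, 4, 8),
]

def run_sweep(train_eval):
    """Run every config and return the name with the lowest val loss."""
    results = {}
    for name, d_model, enc_layers, dec_layers, heads in EXPERIMENTS:
        results[name] = train_eval(d_model=d_model, enc_layers=enc_layers,
                                   dec_layers=dec_layers, heads=heads)
    best = min(results, key=results.get)
    return best, results
```

Running configurations through one driver also makes it easy to log each run against the flight-artifact check from the success criteria.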

### 1.2 Learning Rate Schedule

```python
import numpy as np
import torch

# Experiment A: warmup + cosine decay
warmup_epochs = 5

def lr_lambda(epoch):
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs  # avoid a zero LR at epoch 0
    return 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (N_EPOCHS - warmup_epochs)))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Experiment B: reduce on plateau
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
)
```
| Experiment | Base LR | Schedule | Epochs |
|---|---|---|---|
| Current | 3e-5 | Cosine | 30 |
| Higher + warmup | 1e-4 | Warmup(5) + Cosine | 50 |
| Lower + plateau | 1e-5 | ReduceOnPlateau | 50 |
| Cyclic | 1e-5 to 1e-4 | CyclicLR | 40 |
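The cyclic row can use PyTorch's built-in `CyclicLR`. A sketch; the step size is an assumption (it is counted in optimizer steps, i.e. batches, not epochs), and `cycle_momentum=False` is required with Adam-family optimizers:

```python
import torch

model = torch.nn.Linear(4, 4)  # stand-in for the MAE
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Triangular cycle between 1e-5 and 1e-4. step_size_up=500 batches per
# ramp is an assumed value; tune it to roughly half an epoch.
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=1e-5, max_lr=1e-4,
    step_size_up=500, mode='triangular', cycle_momentum=False
)

for step in range(1000):        # one full cycle: up 500, down 500
    optimizer.step()
    scheduler.step()            # stepped per batch, unlike the others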

### 1.3 Masking Strategy

| Experiment | Group prob | Var prob | Strategy |
|---|---|---|---|
| Current | 0.4 | 0.2 | Group + variable |
| Light | 0.2 | 0.15 | Easier task |
| Heavy | 0.6 | 0.3 | Harder task |
| Temporal only | 0.0 | 0.0 | Mask 20% of timesteps |
| Variable only | 0.0 | 0.3 | Random variable mask |
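The temporal-only row masks whole timesteps across all variables instead of masking variables. A minimal sketch, assuming the (batch, time, variables) window layout used elsewhere in the pipeline:

```python
import numpy as np

def temporal_mask(batch, mask_frac=0.2, rng=None):
    """Zero out a random fraction of timesteps across all variables.

    batch: array of shape (B, T, V). Returns (masked copy, boolean mask
    of shape (B, T)) where True marks masked timesteps.
    """
    rng = rng or np.random.default_rng()
    B, T, V = batch.shape
    mask = rng.random((B, T)) < mask_frac
    masked = batch.copy()
    masked[mask] = 0.0  # broadcasts the zero over the variable axis
    return masked, mask
```

The reconstruction loss should then be computed only where the mask is True, as with the group/variable strategies.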

### 1.4 Variable Count Scaling

| Experiment | Variables | Description |
|---|---|---|
| Current | 12 | Core atmospheric only |
| Medium | 20-25 | Add wind, chemistry |
| Large | 40-50 | Add valid derived variables |
| Full valid | All valid | Every variable with valid norm stats |

For each, monitor:

- Val loss convergence
- Epoch time
- Probing accuracy
- Flight artifact silhouette score
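The flight-artifact silhouette score can be computed with scikit-learn by treating flight IDs as cluster labels; a near-zero or negative score (like the current -0.21) means windows do not cluster by flight. A sketch:

```python
import numpy as np
from sklearn.metrics import silhouette_score

def flight_leakage(embeddings, flight_ids):
    """Silhouette of embeddings grouped by flight ID.

    Values near or below 0 indicate no per-flight clustering (good);
    values approaching 1 indicate the model memorizes flight artifacts.
    """
    return silhouette_score(embeddings, flight_ids)
```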

### 1.5 Recommended Testing Order

  1. Learning rate experiments (fastest to test)
  2. Masking strategy (same model, different task)
  3. Model size scaling (once LR is optimized)
  4. Variable count (once model size is chosen)

## Phase 2: Multi-Project Expansion

### 2.1 Project Selection

Query available projects:

```sql
SELECT p.project_name, p.aircraft, p.year, COUNT(f.id) AS n_flights
FROM projects p
JOIN flights f ON f.project_id = p.id
GROUP BY p.id
ORDER BY n_flights DESC;
```

Priority projects:

| Project | Aircraft | Why |
|---|---|---|
| CGWAVES | GV (N677F) | Current baseline |
| GOTHIC/GOTHAAM | GV | Similar aircraft, different science |
| CAESAR | C-130 | Cross-platform transfer test |
| FRAPPE | C-130 | Boundary layer focus |

### 2.2 Cross-Project Variable Alignment

1. Find common variables:

```python
# For each project, get its variable set
project_vars = {}
for project_id in [12, 15, 20]:  # Example IDs
    rows = db_conn.query("""
        SELECT vm.variable_name FROM variable_metadata vm
        JOIN variable_projects vp ON vm.id = vp.variable_id
        WHERE vp.project_id = %s
    """, (project_id,))
    project_vars[project_id] = set(r['variable_name'] for r in rows)

# Keep only variables present in every project
common_vars = set.intersection(*project_vars.values())
print(f"Common variables: {len(common_vars)}")
```

2. Categorize variables:
   - Universal: present in all projects (ATX, THETA, PSXC, MR)
   - Platform-specific: differ by aircraft (TASX ranges)
   - Campaign-specific: only in certain projects (specialized instruments)

### 2.3 Incremental Training Strategy

Option A: Sequential fine-tuning

1. Train on CGWAVES → baseline model
2. Fine-tune on GOTHIC (lower LR) → expanded model
3. Fine-tune on CAESAR → cross-platform model

Option B: Joint training from scratch

1. Combine all project data into single dataset
2. Add project embedding token (like BERT's [CLS])
3. Train jointly with project-balanced sampling
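Option B's project token can be implemented as a learned embedding prepended to each window's sequence, analogous to BERT's [CLS]. A sketch; the `d_model` of 128 and three projects are assumptions matching the baseline config:

```python
import torch
import torch.nn as nn

class ProjectToken(nn.Module):
    """Prepend a learned per-project token to each window's sequence."""
    def __init__(self, n_projects=3, d_model=128):
        super().__init__()
        self.embed = nn.Embedding(n_projects, d_model)

    def forward(self, x, project_idx):
        # x: (B, T, d_model) embedded windows; project_idx: (B,) long tensor
        tok = self.embed(project_idx).unsqueeze(1)  # (B, 1, d_model)
        return torch.cat([tok, x], dim=1)           # (B, T+1, d_model)
```

The encoder then sees the project token alongside the data, and the decoder (or probes) can drop position 0 when reconstructing.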

Option C: Domain adversarial training

1. Add discriminator head predicting project/platform
2. Use gradient reversal to learn project-invariant features
3. Goal: Embeddings useful across projects, not memorizing project-specific patterns
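Option C's gradient reversal is the standard DANN trick: identity in the forward pass, negated scaled gradient in the backward pass, so the encoder learns to confuse the project discriminator. A sketch:

```python
import torch

class GradReverse(torch.autograd.Function):
    """Identity forward; multiplies gradients by -lambda on the way back,
    so minimizing the discriminator loss *maximizes* encoder confusion."""
    @staticmethod
    def forward(ctx, x, lambd=1.0):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambd * grad_output, None

def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)
```

In use, the discriminator head takes `grad_reverse(embeddings)` as input while the reconstruction head takes the embeddings directly; `lambd` is typically ramped up over training.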

### 2.4 Data Pipeline Updates

1. Update local preprocessing:

```python
# In the normalization stats cell, loop over projects:
for project_id in PROJECT_IDS:
    stats = compute_normalization_stats(db_conn, project_id, common_vars)
    # Save per-project stats
```

2. Update cache export:

```python
# Export windows for each project
for project_id in PROJECT_IDS:
    windows_path = f'windows_{project_id}.npz'
    # ... export logic
```

3. Update Colab bootstrap:

```python
# Load and concatenate cached windows from multiple projects
all_windows = []
all_flight_ids = []
all_project_ids = []

for project_id in PROJECT_IDS:
    cached = np.load(f'windows_{project_id}.npz')
    all_windows.append(cached['data'][:, :, var_indices])
    all_flight_ids.append(cached['flight_ids'])
    all_project_ids.append(np.full(len(cached['flight_ids']), project_id))

combined_data = np.concatenate(all_windows)
```
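The project-balanced sampling from Option B can reuse `all_project_ids` with PyTorch's `WeightedRandomSampler`, weighting each window by the inverse of its project's window count. A sketch:

```python
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

def balanced_sampler(project_ids):
    """One weight per window: 1 / (size of that window's project),
    so each project is drawn with roughly equal probability."""
    project_ids = np.asarray(project_ids)
    _, inverse, counts = np.unique(project_ids, return_inverse=True,
                                   return_counts=True)
    weights = 1.0 / counts[inverse]
    return WeightedRandomSampler(
        torch.as_tensor(weights, dtype=torch.double),
        num_samples=len(project_ids), replacement=True)
```

Pass the sampler to the `DataLoader` in place of `shuffle=True`; a small campaign like a test project then contributes as many windows per epoch as a large one.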

### 2.5 Evaluation for Multi-Project

| Test | What it measures |
|---|---|
| Per-project val loss | Does the model work for each project? |
| Cross-project probing | Do embeddings transfer? |
| Platform silhouette | Are embeddings platform-invariant? |
| Project classification | Can the model distinguish projects? (want: low accuracy) |
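The project-classification row can be measured with a linear probe on frozen embeddings; accuracy near chance (1 / number of projects) is the desired outcome. A sketch with scikit-learn:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

def project_probe_accuracy(embeddings, project_ids, folds=5):
    """Cross-validated accuracy of a linear classifier predicting the
    project from frozen embeddings. Near-chance accuracy is the goal."""
    clf = LogisticRegression(max_iter=1000)
    return cross_val_score(clf, embeddings, project_ids, cv=folds).mean()
```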

## Phase 3: Validation & Science Readiness

### 3.1 Held-out Project Test

  1. Train on projects A, B, C
  2. Evaluate on project D (never seen)
  3. Measure: probing accuracy, cluster quality, reconstruction loss
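This is a leave-one-project-out protocol; a minimal split helper (the per-project dict layout is illustrative):

```python
def leave_one_project_out(windows_by_project, held_out):
    """Split a {project: windows} dict into train (all other projects)
    and test (the single held-out project)."""
    train = {p: w for p, w in windows_by_project.items() if p != held_out}
    test = windows_by_project[held_out]
    return train, test
```

Rotating `held_out` over every project gives a transfer estimate that does not depend on which campaign happens to be left out.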

### 3.2 Known Phenomena Tests

Verify embeddings capture known atmospheric science:

| Test | Method | Expected |
|---|---|---|
| Diurnal cycle | Cluster by time-of-day | Morning/afternoon separation |
| Altitude structure | PCA on embeddings | PC1 correlates with PSXC |
| Cloud detection | Probe for PLWCC > threshold | >90% accuracy |
| Turbulence | High reconstruction error | Correlates with vertical wind variance |
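The altitude-structure row reduces to correlating the first principal component of the embeddings with a per-window PSXC value. A sketch using plain SVD (the sign of PC1 is arbitrary, so compare absolute correlation):

```python
import numpy as np

def pc1_correlation(embeddings, target):
    """Pearson correlation between PC1 of the embeddings and a
    per-window target (e.g. mean PSXC over the window)."""
    X = embeddings - embeddings.mean(axis=0)
    _, _, vt = np.linalg.svd(X, full_matrices=False)
    pc1 = X @ vt[0]          # scores along the top principal direction
    return np.corrcoef(pc1, target)[0, 1]
```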

### 3.3 Downstream Task Prototypes

  1. Flight phase segmentation: Cluster embeddings within a flight, label phases
  2. Anomaly detection: Flag windows with high reconstruction error
  3. Cross-project retrieval: Find similar atmospheric states across campaigns
  4. Instrument QC: Detect when instruments disagree with expected patterns
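The anomaly-detection prototype can start as a simple percentile threshold on per-window reconstruction error; a sketch (the 99th percentile is an assumed starting point):

```python
import numpy as np

def flag_anomalies(recon_errors, percentile=99):
    """Flag windows whose reconstruction error exceeds the given
    percentile of the error distribution; returns (indices, threshold)."""
    threshold = np.percentile(recon_errors, percentile)
    return np.flatnonzero(recon_errors > threshold), threshold
```

Flagged windows can then be inspected against instrument QC flags to separate genuine atmospheric anomalies from sensor faults.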

## Recommended Execution Timeline

| Week | Focus | Deliverable |
|---|---|---|
| 1 | Hyperparameter tuning (LR, masking) | Optimized training config |
| 2 | Model scaling experiments | Best model architecture |
| 3 | Add second project (GOTHIC) | Two-project model |
| 4 | Cross-platform test (C-130) | Platform transfer results |
| 5 | Validation tests | Science readiness report |
| 6 | Downstream task prototype | Demo application |

## Files to Update

| File | Changes needed |
|---|---|
| Convergence_MAE.ipynb | Multi-project support in data loading |
| cache/metadata.json | Include project list, common variables |
| cache/norm_stats_*.json | Per-project normalization |
| cache/windows_*.npz | Per-project cached windows |

## Success Metrics

| Metric | Current | Target |
|---|---|---|
| Val loss | 3.22 | < 2.5 |
| Probing accuracy | 90-100% | Maintain |
| Platform leakage | -0.21 | < 0 (negative) |
| Projects supported | 1 | 3+ |
| Cross-project transfer | N/A | >80% probing on held-out |