The Training & ML Pipeline is the subsystem of WiFi-DensePose that turns raw public CSI datasets into a trained pose estimation model and its downstream derivatives: contrastive embeddings, domain-generalized weights, and deterministic proof bundles. It is the bridge between research data and deployable inference.
This document defines the system using Domain-Driven Design (DDD): bounded contexts that own their data and rules, aggregate roots that enforce invariants, value objects that carry meaning, and domain events that connect everything. The goal is to make the pipeline's structure match the physics and mathematics it implements -- so that anyone reading the code (or an AI agent modifying it) understands why each piece exists, not just what it does.
Bounded Contexts:
| # | Context | Responsibility | Key ADRs | Code |
|---|---|---|---|---|
| 1 | Dataset Management | Load, validate, normalize, and preprocess training data from MM-Fi and Wi-Pose | ADR-015 | train/src/dataset.rs, train/src/subcarrier.rs |
| 2 | Model Architecture | Define the neural network, forward pass, attention mechanisms, and spatial decoding | ADR-016, ADR-020 | train/src/model.rs, train/src/graph_transformer.rs |
| 3 | Training Orchestration | Run the training loop, compute composite loss, checkpoint, and verify deterministic proofs | ADR-015, ADR-016 | train/src/trainer.rs, train/src/losses.rs, train/src/metrics.rs, train/src/proof.rs |
| 4 | Embedding & Transfer | Produce AETHER contrastive embeddings, MERIDIAN domain-generalized features, and LoRA adapters | ADR-024, ADR-027 | train/src/embedding.rs, train/src/domain.rs, train/src/sona.rs |
All code paths shown are relative to rust-port/wifi-densepose-rs/crates/wifi-densepose- unless otherwise noted.
Glossary:

| Term | Definition |
|---|---|
| Training Run | A complete training session: configuration, epoch loop, checkpoint history, and final model weights |
| Epoch | One full pass through the training dataset; produces train loss and validation metrics |
| Checkpoint | A snapshot of model weights at a given epoch, identified by SHA-256 hash and validation PCK |
| CSI Sample | A single observation: amplitude + phase tensors, ground-truth keypoints, and visibility flags |
| Subcarrier Interpolation | Resampling CSI from source subcarrier count to the canonical 56 (114->56 for MM-Fi, 30->56 for Wi-Pose) |
| Teacher-Student | Training regime where a camera-based RGB model generates pseudo-labels; at inference the camera is removed |
| Pseudo-Label | DensePose UV surface coordinates generated by Detectron2 from paired RGB frames |
| PCK@0.2 | Percentage of Correct Keypoints within 20% of torso diameter; primary accuracy metric |
| OKS | Object Keypoint Similarity; per-keypoint Gaussian-weighted distance used in COCO evaluation |
| MPJPE | Mean Per Joint Position Error in millimeters; 3D accuracy metric |
| Hungarian Assignment | Bipartite matching of predicted persons to ground-truth using min-cost assignment |
| Dynamic Min-Cut | Dynamic person-to-GT assignment with O(n^1.5 log n) update cost, maintained across frames |
| Compressed CSI Buffer | Tiered-quantization temporal window: hot frames at 8-bit, warm at 5/7-bit, cold at 3-bit |
| Proof Verification | Deterministic check: fixed seed -> N training steps -> loss decreases AND SHA-256 hash matches |
| AETHER Embedding | 128-dim L2-normalized contrastive vector from the CsiToPoseTransformer backbone |
| InfoNCE Loss | Contrastive loss that pushes same-identity embeddings together and different-identity apart |
| HNSW Index | Hierarchical Navigable Small World graph for approximate nearest-neighbor embedding search |
| Domain Factorizer | Splits latent features into pose-invariant (h_pose) and environment-specific (h_env) components |
| Gradient Reversal Layer | Identity in forward pass; multiplies gradient by -lambda in backward pass to force domain invariance |
| GRL Lambda | Adversarial weight annealed from 0.0 to 1.0 over the first 20 epochs |
| FiLM Conditioning | Feature-wise Linear Modulation: gamma * features + beta, conditioned on geometry encoding |
| Hardware Normalizer | Resamples CSI from any chipset to canonical 56 subcarriers with z-score amplitude normalization |
| LoRA Adapter | Low-Rank Adaptation weights (rank r, alpha) for few-shot environment-specific fine-tuning |
| Rapid Adaptation | 10-second unlabeled calibration producing a per-room LoRA adapter via contrastive test-time training |
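The Compressed CSI Buffer's tiered quantization can be illustrated with uniform min-max quantization at a per-tier bit width. This is a sketch, not the `TemporalTensorCompressor` implementation; the function names and error-bound arithmetic are illustrative:

```rust
/// Uniform min-max quantization of a CSI frame at a per-tier bit width
/// (hot = 8-bit, warm = 5/7-bit, cold = 3-bit in the tiered buffer).
/// Returns the codes plus the (min, scale) pair needed to dequantize.
fn quantize(frame: &[f32], bits: u32) -> (Vec<u8>, f32, f32) {
    let min = frame.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = frame.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let levels = ((1u32 << bits) - 1) as f32;
    let scale = if max > min { (max - min) / levels } else { 1.0 };
    let codes = frame.iter().map(|v| ((v - min) / scale).round() as u8).collect();
    (codes, min, scale)
}

/// Reconstruct approximate amplitudes from the stored codes.
fn dequantize(codes: &[u8], min: f32, scale: f32) -> Vec<f32> {
    codes.iter().map(|&c| min + c as f32 * scale).collect()
}
```

At 8 bits over the frame's min-max range, the worst-case round-trip error is half a quantization step (about 0.2% of range), consistent with the hot-tier <1% bound; at 3 bits it grows to roughly 7% of range.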
Context 1: Dataset Management

Responsibility: Load raw CSI data from public datasets (MM-Fi, Wi-Pose), validate structural invariants, resample subcarriers to the canonical 56, apply phase sanitization, and present typed samples to the training loop. Memory efficiency comes from tiered temporal compression.
+----------------------------------------------------------+
| Dataset Management Context |
+----------------------------------------------------------+
| |
| +---------------+ +---------------+ |
| | MM-Fi Loader | | Wi-Pose | |
| | (.npy files, | | Loader | |
| | 114 sub, | | (.mat files, | |
| | 40 subjects)| | 30 sub, | |
| +-------+-------+ | 12 subjects)| |
| | +-------+-------+ |
| | | |
| +--------+-----------+ |
| v |
| +----------------+ |
| | Subcarrier | |
| | Interpolator | |
| | (114->56 or | |
| | 30->56) | |
| +--------+-------+ |
| v |
| +----------------+ |
| | Phase | |
| | Sanitizer | |
| | (SOTA algs | |
| | from signal) | |
| +--------+-------+ |
| v |
| +----------------+ |
| | Compressed CSI |--> CsiSample |
| | Buffer | |
| | (tiered quant) | |
| +----------------+ |
| |
+----------------------------------------------------------+
Aggregates:
- `MmFiDataset` (Aggregate Root) -- Manages the MM-Fi data lifecycle
- `WiPoseDataset` (Aggregate Root) -- Manages the Wi-Pose data lifecycle
Value Objects:
- `CsiSample` -- Single observation with amplitude, phase, keypoints, visibility
- `SubcarrierConfig` -- Source count, target count, interpolation method
- `DatasetSplit` -- Train / Validation / Test subject partitioning
- `CompressedCsiBuffer` -- Tiered temporal window backed by `TemporalTensorCompressor`
Domain Services:
- `SubcarrierInterpolationService` -- Resamples subcarriers via sparse least-squares or linear fallback
- `PhaseSanitizationService` -- Applies SpotFi / MUSIC phase correction from `wifi-densepose-signal`
- `TeacherLabelService` -- Runs Detectron2 on paired RGB frames to produce DensePose UV pseudo-labels
- `HardwareNormalizerService` -- Z-score normalization + chipset-invariant phase sanitization
RuVector Integration:
- `ruvector-solver` -> `NeumannSolver` for sparse O(sqrt(n)) subcarrier interpolation (114->56)
- `ruvector-temporal-tensor` -> `TemporalTensorCompressor` for 50-75% memory reduction in CSI windows
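The linear fallback in `SubcarrierInterpolationService` amounts to resampling each antenna pair's 114-point (or 30-point) subcarrier vector onto 56 evenly spaced positions. A minimal sketch of that fallback (the sparse `NeumannSolver` path is the preferred route and is not shown; the function name is illustrative):

```rust
/// Linear resampling of a per-antenna subcarrier vector from its source
/// length to `target_count` points (e.g. 114 -> 56 for MM-Fi).
fn interpolate_linear(source: &[f32], target_count: usize) -> Vec<f32> {
    let n = source.len();
    if n == 0 || target_count == 0 {
        return vec![0.0; target_count];
    }
    if n == 1 || target_count == 1 {
        return vec![source[0]; target_count];
    }
    (0..target_count)
        .map(|i| {
            // Map target index i onto the source index axis [0, n-1].
            let pos = i as f32 * (n - 1) as f32 / (target_count - 1) as f32;
            let lo = pos.floor() as usize;
            let hi = (lo + 1).min(n - 1);
            let frac = pos - lo as f32;
            source[lo] * (1.0 - frac) + source[hi] * frac
        })
        .collect()
}
```

Endpoints are preserved exactly (the first and last target subcarriers coincide with the first and last source subcarriers), which keeps the band edges stable across chipsets.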
Context 2: Model Architecture

Responsibility: Define the `WiFiDensePoseModel`: CSI embedding, cross-attention between keypoint queries and CSI features, GNN message passing, attention-gated modality fusion, and spatial decoding heads for keypoints and DensePose UV.
+----------------------------------------------------------+
| Model Architecture Context |
+----------------------------------------------------------+
| |
| +---------------+ +---------------+ |
| | CSI Embed | | Keypoint | |
| | (Linear | | Queries | |
| | 56 -> d) | | (17 learned | |
| +-------+-------+ | embeddings) | |
| | +-------+-------+ |
| | | |
| +--------+-----------+ |
| v |
| +----------------+ |
| | Cross-Attention| |
| | (Q=queries, | |
| | K,V=csi) | |
| +--------+-------+ |
| v |
| +----------------+ |
| | GNN Stack | |
| | (2-layer GCN | |
| | skeleton | |
| | adjacency) | |
| +--------+-------+ |
| v |
| body_part_features [17 x d_model] |
| | |
| +-------+--------+--------+ |
| v v v v |
| +----------+ +------+ +-----+ +-------+ |
| | Modality | | xyz | | UV | |Spatial| |
| | Transl. | | Head | | Head| |Attn | |
| | (attn | | | | | |Decoder| |
| | mincut) | | | | | | | |
| +----------+ +------+ +-----+ +-------+ |
| |
+----------------------------------------------------------+
Aggregates:
- `WiFiDensePoseModel` (Aggregate Root) -- The complete model graph
Entities:
- `ModalityTranslator` -- Attention-gated CSI fusion using min-cut
- `CsiToPoseTransformer` -- Cross-attention + GNN backbone
- `KeypointHead` -- Regresses 17 x (x, y, z, confidence) from body_part_features
- `DensePoseHead` -- Predicts body part labels and UV surface coordinates
Value Objects:
- `ModelConfig` -- Architecture hyperparameters (d_model, n_heads, n_gnn_layers)
- `AttentionOutput` -- Attended values + gating result from min-cut attention
- `BodyPartFeatures` -- [17 x d_model] intermediate representation
Domain Services:
- `AttentionGatingService` -- Applies `attn_mincut` to prune irrelevant antenna paths
- `SpatialDecodingService` -- Graph-based spatial attention among feature map locations
RuVector Integration:
- `ruvector-attn-mincut` -> `attn_mincut` for antenna-path gating in ModalityTranslator
- `ruvector-attention` -> `ScaledDotProductAttention` for spatial decoder long-range dependencies
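The cross-attention step above can be sketched at the tensor-shape level: the 17 learned keypoint queries attend over the embedded CSI features, so each body part pools the subcarrier evidence most relevant to it. A single-head, plain-`Vec` sketch (the real model uses multi-head attention; all names here are illustrative):

```rust
/// Single-head scaled dot-product cross-attention:
/// Q = keypoint queries [17 x d], K = V = embedded CSI features [n x d].
fn cross_attention(q: &[Vec<f32>], kv: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let d = q[0].len() as f32;
    q.iter()
        .map(|qi| {
            // Attention logits of this query against every CSI feature.
            let logits: Vec<f32> =
                kv.iter().map(|kj| dot(qi, kj) / d.sqrt()).collect();
            // Softmax over the CSI axis (max-shifted for stability).
            let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = logits.iter().map(|l| (l - max).exp()).collect();
            let sum: f32 = exps.iter().sum();
            // Weighted sum of values (V = K here).
            let mut out = vec![0.0; qi.len()];
            for (w, vj) in exps.iter().zip(kv) {
                for (o, v) in out.iter_mut().zip(vj) {
                    *o += (w / sum) * v;
                }
            }
            out
        })
        .collect()
}

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
```

The output is the [17 x d_model] body_part_features tensor consumed by the GNN stack and the decoding heads.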
Context 3: Training Orchestration

Responsibility: Run the training loop across epochs, compute the composite loss (keypoint MSE + DensePose part CE + UV Smooth L1 + transfer MSE), evaluate validation metrics (PCK@0.2, OKS, MPJPE), manage checkpoints, and verify deterministic proof correctness.
+----------------------------------------------------------+
| Training Orchestration Context |
+----------------------------------------------------------+
| |
| +---------------+ +---------------+ |
| | Training Loop | | Loss Computer | |
| | (epoch iter, | | (composite: | |
| | batch fwd/ | | kp_mse + | |
| | bwd, optim) | | part_ce + | |
| +-------+-------+ | uv_l1 + | |
| | | transfer) | |
| | +-------+-------+ |
| +--------+-----------+ |
| v |
| +----------------+ |
| | Metric | |
| | Evaluator | |
| | (PCK, OKS, | |
| | MPJPE, | |
| | Hungarian) | |
| +--------+-------+ |
| v |
| +-------------+-------------+ |
| v v |
| +----------------+ +----------------+ |
| | Checkpoint | | Proof Verifier | |
| | Manager | | (fixed seed, | |
| | (best-by-PCK, | | 50 steps, | |
| | SHA-256 hash) | | loss + hash) | |
| +----------------+ +----------------+ |
| |
+----------------------------------------------------------+
Aggregates:
- `TrainingRun` (Aggregate Root) -- The complete training session
Entities:
- `CheckpointManager` -- Persists and selects model snapshots
- `ProofVerifier` -- Deterministic verification against stored hashes
Value Objects:
- `TrainingConfig` -- Epochs, batch_size, learning_rate, loss_weights, optimizer params
- `Checkpoint` -- Epoch number, model weights SHA-256, validation PCK at that epoch
- `LossWeights` -- Relative weights for each loss component
- `CompositeTrainingLoss` -- Combined scalar loss with per-component breakdown
- `OksScore` -- Per-keypoint Object Keypoint Similarity with sigma values
- `PckScore` -- Percentage of Correct Keypoints at threshold 0.2
- `MpjpeScore` -- Mean Per Joint Position Error in millimeters
- `ProofResult` -- Seed, steps, loss_decreased flag, hash_matches flag
Domain Services:
- `LossComputationService` -- Computes composite loss from model outputs and ground truth
- `MetricEvaluationService` -- Computes PCK, OKS, MPJPE over validation set
- `HungarianAssignmentService` -- Bipartite matching for multi-person evaluation
- `DynamicPersonMatcherService` -- Frame-persistent assignment via `ruvector-mincut`
- `ProofVerificationService` -- Fixed-seed training + SHA-256 verification
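The composite loss reduces to a weighted sum of its components, with the AETHER and MERIDIAN terms contributing only when those regimes are enabled. A sketch using the documented default weights (the struct shape here is illustrative; the real types live in train/src/losses.rs):

```rust
/// Loss weights mirroring the documented TrainingConfig defaults.
struct LossWeights {
    keypoint_mse: f64,       // 1.0
    densepose_part_ce: f64,  // 0.5
    uv_smooth_l1: f64,       // 0.5
    transfer_mse: f64,       // 0.2
    contrastive: f64,        // 0.1 (AETHER, optional)
    domain_adversarial: f64, // annealed 0.0 -> 1.0 (MERIDIAN, optional)
}

/// Composite training loss: weighted sum of the always-on components,
/// plus the optional AETHER/MERIDIAN terms when enabled.
fn total_loss(
    w: &LossWeights,
    kp_mse: f64,
    part_ce: f64,
    uv_l1: f64,
    transfer: f64,
    contrastive: Option<f64>,
    adversarial: Option<f64>,
) -> f64 {
    w.keypoint_mse * kp_mse
        + w.densepose_part_ce * part_ce
        + w.uv_smooth_l1 * uv_l1
        + w.transfer_mse * transfer
        + contrastive.map_or(0.0, |c| w.contrastive * c)
        + adversarial.map_or(0.0, |a| w.domain_adversarial * a)
}
```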
RuVector Integration:
- `ruvector-mincut` -> `DynamicMinCut` for O(n^1.5 log n) multi-person assignment in metrics
- Original `hungarian_assignment` kept for single-frame static matching in proof verification
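PCK@0.2 from the glossary can be made concrete: a keypoint counts as correct when its error is within 20% of the torso diameter. A sketch over 2D keypoints, with the torso reference taken as the left-shoulder-to-right-hip distance (COCO indices 5 and 12) -- that exact torso definition is an assumption here, as implementations vary:

```rust
/// PCK@t: fraction of visible keypoints whose error is within
/// t * torso diameter (torso = left shoulder to right hip distance).
fn pck(pred: &[[f32; 2]; 17], gt: &[[f32; 2]; 17], vis: &[f32; 17], t: f32) -> f32 {
    let torso = dist(gt[5], gt[12]);
    let mut correct = 0.0;
    let mut total = 0.0;
    for k in 0..17 {
        if vis[k] > 0.0 {
            total += 1.0;
            if dist(pred[k], gt[k]) <= t * torso {
                correct += 1.0;
            }
        }
    }
    if total > 0.0 { correct / total } else { 0.0 }
}

fn dist(a: [f32; 2], b: [f32; 2]) -> f32 {
    ((a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2)).sqrt()
}
```

Invisible keypoints are excluded from both numerator and denominator, so occlusion does not penalize the score.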
Context 4: Embedding & Transfer

Responsibility: Produce AETHER contrastive embeddings from the model backbone, train domain-adversarial features via MERIDIAN, manage the HNSW embedding index for re-ID and fingerprinting, and generate LoRA adapters for few-shot environment adaptation.
+----------------------------------------------------------+
| Embedding & Transfer Context |
+----------------------------------------------------------+
| |
| body_part_features [17 x d_model] |
| | |
| +--------+-----------+ |
| v v |
| +---------------+ +---------------+ |
| | AETHER | | MERIDIAN | |
| | Projection | | Domain | |
| | Head | | Factorizer | |
| | (MeanPool -> | | (PoseEncoder | |
| | fc -> 128d) | | + EnvEncoder)| |
| +-------+-------+ +-------+-------+ |
| | | |
| v v |
| +---------------+ +---------------+ |
| | InfoNCE Loss | | Gradient | |
| | + Hard Neg | | Reversal | |
| | Mining (HNSW) | | Layer (GRL) | |
| +-------+-------+ +-------+-------+ |
| | | |
| v v |
| +---------------+ +---------------+ |
| | Embedding | | Geometry | |
| | Index (HNSW) | | Encoder + | |
| | (fingerprint | | FiLM Cond. | |
| | store) | | (zero-shot) | |
| +---------------+ +-------+-------+ |
| | |
| v |
| +---------------+ |
| | Rapid Adapt. | |
| | (LoRA + TTT, | |
| | 10-sec cal.) | |
| +---------------+ |
| |
+----------------------------------------------------------+
Aggregates:
- `EmbeddingIndex` (Aggregate Root) -- HNSW-indexed store of AETHER fingerprints
- `DomainAdaptationState` (Aggregate Root) -- Tracks GRL lambda, domain classifier accuracy, factorization quality
Entities:
- `ProjectionHead` -- MLP mapping body_part_features to 128-dim embedding space
- `DomainFactorizer` -- Splits features into h_pose and h_env
- `DomainClassifier` -- Classifies domain from h_pose (trained adversarially via GRL)
- `GeometryEncoder` -- Fourier positional encoding + DeepSets for AP positions
- `LoraAdapter` -- Low-rank adaptation weights for environment-specific fine-tuning
Value Objects:
- `AetherEmbedding` -- 128-dim L2-normalized contrastive vector
- `FingerprintType` -- ReIdentification / RoomFingerprint / PersonFingerprint
- `DomainLabel` -- Environment identifier for adversarial training
- `GrlSchedule` -- Lambda annealing parameters (max_lambda, warmup_epochs)
- `GeometryInput` -- AP positions in meters relative to room origin
- `FilmParameters` -- Gamma (scale) and beta (shift) vectors from geometry conditioning
- `LoraConfig` -- Rank, alpha, target layers
- `AdaptationLoss` -- ContrastiveTTT / EntropyMin / Combined
Domain Services:
- `ContrastiveLossService` -- Computes InfoNCE loss with temperature scaling
- `HardNegativeMiningService` -- HNSW k-NN search for difficult negative pairs
- `DomainAdversarialService` -- Manages GRL annealing and domain classification
- `GeometryConditioningService` -- Encodes AP layout and produces FiLM parameters
- `VirtualDomainAugmentationService` -- Generates synthetic environment shifts for training diversity
- `RapidAdaptationService` -- Produces LoRA adapter from 10-second unlabeled calibration
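`ContrastiveLossService`'s InfoNCE objective, for one anchor with a single positive and a set of negatives over L2-normalized embeddings (so dot product equals cosine similarity), is -log(exp(s_pos/tau) / sum_j exp(s_j/tau)). A numerically stable sketch (function names are illustrative):

```rust
/// InfoNCE for one anchor. logits[0] is the positive pair; the rest are
/// negatives. The log-sum-exp is shifted by the max logit for stability.
fn info_nce(anchor: &[f32], positive: &[f32], negatives: &[Vec<f32>], tau: f32) -> f32 {
    let mut logits = vec![dot(anchor, positive) / tau];
    logits.extend(negatives.iter().map(|n| dot(anchor, n) / tau));
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum = logits.iter().map(|l| (l - max).exp()).sum::<f32>().ln() + max;
    log_sum - logits[0]
}

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
```

With an orthogonal negative the loss is near zero; with a negative identical to the anchor it rises to ln 2 -- which is exactly why `HardNegativeMiningService` searches the HNSW index for the most similar wrong-identity embeddings.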
pub struct TrainingRun {
/// Unique run identifier
pub id: TrainingRunId,
/// Full training configuration
pub config: TrainingConfig,
/// Datasets loaded for this run
pub datasets: Vec<DatasetHandle>,
/// Ordered history of per-epoch metrics
pub epoch_history: Vec<EpochRecord>,
/// Best checkpoint by validation PCK
pub best_checkpoint: Option<Checkpoint>,
/// Current epoch (0-indexed)
pub current_epoch: usize,
/// Run status
pub status: RunStatus,
/// Proof verification result (if run)
pub proof_result: Option<ProofResult>,
}
pub enum RunStatus {
Initializing,
Training,
Completed,
Failed { reason: String },
ProofVerified,
}

Invariants:
- Must have at least 1 dataset loaded before transitioning to `Training`
- `best_checkpoint` is updated only when a new epoch's validation PCK exceeds all prior epochs
- `proof_result` can only be set once and is immutable after verification
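The best-checkpoint invariant is a strict-improvement rule over validation PCK. A sketch with an (epoch, val_pck) pair standing in for the full Checkpoint value object:

```rust
/// Replace the best checkpoint only when the new epoch's validation PCK
/// strictly exceeds every prior epoch's (ties do not win).
/// Returns true when the checkpoint was updated.
fn update_best(best: &mut Option<(usize, f64)>, epoch: usize, val_pck: f64) -> bool {
    match best {
        Some((_, prev)) if val_pck <= *prev => false,
        _ => {
            *best = Some((epoch, val_pck));
            true
        }
    }
}
```

Using strict `>` keeps checkpoint selection deterministic: re-running the same seed can never flip which epoch wins on a tie.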
pub struct MmFiDataset {
/// Root directory containing .npy files
pub data_root: PathBuf,
/// Subject IDs in this split
pub subject_ids: Vec<u32>,
/// Number of action classes
pub n_actions: usize, // 27
/// Source subcarrier count
pub source_subcarriers: usize, // 114
/// Target subcarrier count after interpolation
pub target_subcarriers: usize, // 56
/// Antenna configuration: 1 TX x 3 RX
pub antenna_pairs: usize, // 3
/// Sampling rate in Hz
pub sample_rate_hz: f32, // 100.0
/// Temporal window size (frames per sample)
pub window_frames: usize, // 10
/// Compressed buffer for memory-efficient storage
pub buffer: CompressedCsiBuffer,
/// Total loaded samples
pub n_samples: usize,
}

pub struct WiPoseDataset {
/// Root directory containing .mat files
pub data_root: PathBuf,
/// Subject IDs in this split
pub subject_ids: Vec<u32>,
/// Source subcarrier count
pub source_subcarriers: usize, // 30
/// Target subcarrier count after zero-padding
pub target_subcarriers: usize, // 56
/// Antenna configuration: 3 TX x 3 RX
pub antenna_pairs: usize, // 9
/// Keypoint count (18 AlphaPose, mapped to 17 COCO)
pub source_keypoints: usize, // 18
/// Compressed buffer
pub buffer: CompressedCsiBuffer,
/// Total loaded samples
pub n_samples: usize,
}

pub struct WiFiDensePoseModel {
/// CSI embedding layer: Linear(56, d_model)
pub csi_embed: Linear,
/// Learned keypoint query embeddings [17 x d_model]
pub keypoint_queries: Tensor,
/// Cross-attention: Q=queries, K,V=csi_embed
pub cross_attention: MultiHeadAttention,
/// GNN message passing on skeleton graph
pub gnn_stack: GnnStack,
/// Modality translator with attention-gated fusion
pub modality_translator: ModalityTranslator,
/// Keypoint regression head
pub keypoint_head: KeypointHead,
/// DensePose UV prediction head
pub densepose_head: DensePoseHead,
/// Spatial attention decoder
pub spatial_decoder: SpatialAttentionDecoder,
/// Model dimensionality
pub d_model: usize, // 64
}

pub struct EmbeddingIndex {
/// HNSW graph for approximate nearest-neighbor search
pub hnsw: HnswIndex,
/// Stored embeddings with metadata
pub entries: Vec<EmbeddingEntry>,
/// Embedding dimensionality
pub dim: usize, // 128
/// Number of indexed embeddings
pub count: usize,
/// HNSW construction parameters
pub ef_construction: usize, // 200
pub m_connections: usize, // 16
}
pub struct EmbeddingEntry {
pub id: EmbeddingId,
pub embedding: Vec<f32>, // [128], L2-normalized
pub fingerprint_type: FingerprintType,
pub source_domain: Option<DomainLabel>,
pub created_at: u64,
}
pub enum FingerprintType {
ReIdentification,
RoomFingerprint,
PersonFingerprint,
}

pub struct CsiSample {
/// Amplitude tensor [n_antenna_pairs x n_subcarriers x n_time_frames]
pub amplitude: Vec<f32>,
/// Phase tensor [n_antenna_pairs x n_subcarriers x n_time_frames]
pub phase: Vec<f32>,
/// Ground-truth 3D keypoints [17 x 3] (x, y, z in meters)
pub keypoints: [[f32; 3]; 17],
/// Per-keypoint visibility flags
pub visibility: [f32; 17],
/// DensePose UV pseudo-labels (optional, from teacher model)
pub densepose_uv: Option<DensePoseLabels>,
/// Domain label for adversarial training
pub domain_label: Option<DomainLabel>,
/// Hardware source type
pub hardware_type: HardwareType,
}

pub struct TrainingConfig {
/// Number of training epochs
pub epochs: usize,
/// Mini-batch size
pub batch_size: usize,
/// Initial learning rate
pub learning_rate: f64, // 1e-3
/// Learning rate schedule: step decay at these epochs
pub lr_decay_epochs: Vec<usize>, // [40, 80]
/// Learning rate decay factor
pub lr_decay_factor: f64, // 0.1
/// Loss component weights
pub loss_weights: LossWeights,
/// Optimizer (Adam)
pub optimizer: OptimizerConfig,
/// Validation subject IDs (MM-Fi: 33-40)
pub val_subjects: Vec<u32>,
/// Random seed for reproducibility
pub seed: u64,
/// Enable MERIDIAN domain-adversarial training
pub meridian_enabled: bool,
/// Enable AETHER contrastive learning
pub aether_enabled: bool,
}
pub struct LossWeights {
/// Keypoint heatmap MSE weight
pub keypoint_mse: f32, // 1.0
/// DensePose body part cross-entropy weight
pub densepose_part_ce: f32, // 0.5
/// DensePose UV Smooth L1 weight
pub uv_smooth_l1: f32, // 0.5
/// Teacher-student transfer MSE weight
pub transfer_mse: f32, // 0.2
/// AETHER contrastive loss weight (ADR-024)
pub contrastive: f32, // 0.1
/// MERIDIAN domain adversarial weight (ADR-027)
pub domain_adversarial: f32, // annealed 0.0 -> 1.0
}

pub struct Checkpoint {
/// Epoch at which this checkpoint was saved
pub epoch: usize,
/// SHA-256 hash of serialized model weights
pub weights_hash: String,
/// Validation PCK@0.2 at this epoch
pub validation_pck: f64,
/// Validation OKS at this epoch
pub validation_oks: f64,
/// File path to saved weights
pub path: PathBuf,
/// Timestamp
pub created_at: u64,
}

pub struct ProofResult {
/// Seed used for model initialization
pub model_seed: u64, // MODEL_SEED = 0
/// Seed used for proof data generation
pub proof_seed: u64, // PROOF_SEED = 42
/// Number of training steps in proof
pub steps: usize, // 50
/// Whether loss decreased monotonically
pub loss_decreased: bool,
/// Whether final weights hash matches stored expected hash
pub hash_matches: bool,
/// The computed SHA-256 hash
pub computed_hash: String,
/// The expected SHA-256 hash (from file)
pub expected_hash: String,
}

pub struct LoraAdapter {
/// Low-rank decomposition rank
pub rank: usize, // 4
/// LoRA alpha scaling factor
pub alpha: f32, // 1.0
/// Per-layer weight matrices (A and B for each adapted layer)
pub weights: Vec<LoraLayerWeights>,
/// Source domain this adapter was calibrated for
pub source_domain: DomainLabel,
/// Calibration duration in seconds
pub calibration_duration_secs: f32,
/// Number of calibration frames used
pub calibration_frames: usize,
}
pub struct LoraLayerWeights {
/// Layer name in the model
pub layer_name: String,
/// Down-projection: [d_model x rank]
pub a: Vec<f32>,
/// Up-projection: [rank x d_model]
pub b: Vec<f32>,
}

pub enum DatasetEvent {
/// Dataset loaded and validated
DatasetLoaded {
dataset_type: DatasetType,
n_samples: usize,
n_subjects: u32,
source_subcarriers: usize,
timestamp: u64,
},
/// Subcarrier interpolation completed for a dataset
SubcarrierInterpolationComplete {
dataset_type: DatasetType,
source_count: usize,
target_count: usize,
method: InterpolationMethod,
timestamp: u64,
},
/// Teacher pseudo-labels generated for a batch
PseudoLabelsGenerated {
n_samples: usize,
n_with_uv: usize,
timestamp: u64,
},
}
pub enum DatasetType {
MmFi,
WiPose,
Synthetic,
}
pub enum InterpolationMethod {
/// ruvector-solver NeumannSolver sparse least-squares
SparseNeumannSolver,
/// Fallback linear interpolation
LinearInterpolation,
/// Wi-Pose zero-padding
ZeroPad,
}

pub enum TrainingEvent {
/// One epoch of training completed
EpochCompleted {
epoch: usize,
train_loss: f64,
val_pck: f64,
val_oks: f64,
val_mpjpe_mm: f64,
learning_rate: f64,
grl_lambda: f32,
timestamp: u64,
},
/// New best checkpoint saved
CheckpointSaved {
epoch: usize,
weights_hash: String,
validation_pck: f64,
path: String,
timestamp: u64,
},
/// Deterministic proof verification completed
ProofVerified {
model_seed: u64,
proof_seed: u64,
steps: usize,
loss_decreased: bool,
hash_matches: bool,
timestamp: u64,
},
/// Training run completed or failed
TrainingRunFinished {
run_id: String,
status: RunStatus,
total_epochs: usize,
best_pck: f64,
best_oks: f64,
timestamp: u64,
},
}

pub enum EmbeddingEvent {
/// New AETHER embedding indexed
EmbeddingIndexed {
embedding_id: String,
fingerprint_type: FingerprintType,
nearest_neighbor_distance: f32,
index_size: usize,
timestamp: u64,
},
/// Hard negative pair discovered during mining
HardNegativeFound {
anchor_id: String,
negative_id: String,
similarity: f32,
timestamp: u64,
},
/// Domain adaptation completed for a target environment
DomainAdaptationComplete {
source_domain: String,
target_domain: String,
pck_before: f64,
pck_after: f64,
adaptation_method: String,
timestamp: u64,
},
/// LoRA adapter generated via rapid calibration
LoraAdapterGenerated {
domain: String,
rank: usize,
calibration_frames: usize,
calibration_seconds: f32,
timestamp: u64,
},
}

Invariants:

- MM-Fi samples must be interpolated from 114 to 56 subcarriers before use in training
- Wi-Pose samples must be zero-padded from 30 to 56 subcarriers before use in training
- Wi-Pose keypoints must be mapped from 18 (AlphaPose) to 17 (COCO) by dropping neck index 1
- All CSI amplitudes must be finite and non-negative after loading
- Phase values must be in [-pi, pi] after sanitization
- Validation subjects (MM-Fi: 33-40) must never appear in the training split
- `CompressedCsiBuffer` must preserve signal fidelity within quantization error bounds (hot: <1% error)
- `csi_embed` input dimension must equal the canonical 56 subcarriers
- `keypoint_queries` must have exactly 17 entries (one per COCO keypoint)
- `attn_mincut` seq_len must equal n_antenna_pairs * n_time_frames
- GNN adjacency matrix must encode the human skeleton topology (17 nodes, 16 edges)
- Spatial attention decoder must preserve spatial resolution (no information loss in reshape)
- TrainingRun must have at least 1 dataset loaded before `start()` is called
- Proof verification requires fixed seeds: MODEL_SEED=0, PROOF_SEED=42
- Proof verification uses exactly 50 training steps on deterministic SyntheticDataset
- Loss must decrease over proof steps (otherwise proof fails)
- SHA-256 hash of final weights must match stored expected hash (otherwise proof fails)
- `best_checkpoint` is updated if and only if current val_pck > all previous val_pck values
- Learning rate decays by factor 0.1 at epochs 40 and 80 (step schedule)
- Hungarian assignment for static single-frame matching must use the deterministic implementation (not DynamicMinCut) during proof verification
- AETHER embeddings must be L2-normalized (unit norm) before indexing in HNSW
- InfoNCE temperature must be > 0 (typically 0.07)
- HNSW index ef_search must be >= k for k-NN queries
- MERIDIAN GRL lambda must anneal from 0.0 to 1.0 over the first 20 epochs using the schedule: lambda(p) = 2 / (1 + exp(-10 * p)) - 1, where p = epoch / 20
- GRL lambda must not exceed 1.0 at any epoch
- `DomainFactorizer` output dimensions: h_pose = [17 x 64], h_env = [32]
- `GeometryEncoder` must be permutation-invariant with respect to AP ordering (DeepSets guarantee)
- LoRA adapter rank must be <= d_model / 4 (default rank=4 for d_model=64)
- Rapid adaptation requires at least 200 CSI frames (10 seconds at 20 Hz)
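The GRL annealing invariant can be written directly from the stated schedule. Note that lambda(1) = 2/(1 + e^-10) - 1 ≈ 0.9999, so "anneals to 1.0" holds only asymptotically at the end of warmup; p is clamped so later epochs stay at the asymptote:

```rust
/// GRL lambda schedule from the invariants:
/// lambda(p) = 2 / (1 + exp(-10 * p)) - 1, with p = epoch / warmup_epochs
/// clamped to 1.0 so lambda never exceeds its asymptote after warmup.
fn grl_lambda(epoch: usize, warmup_epochs: usize) -> f32 {
    let p = (epoch as f32 / warmup_epochs as f32).min(1.0);
    2.0 / (1.0 + (-10.0 * p).exp()) - 1.0
}
```

Starting the adversarial weight at 0.0 lets the pose task stabilize before the domain classifier's reversed gradient begins pushing features toward environment invariance.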
Resamples CSI subcarriers from source to target count using physically-motivated sparse interpolation.
pub trait SubcarrierInterpolationService {
/// Sparse interpolation via NeumannSolver (O(sqrt(n)), preferred)
fn interpolate_sparse(
&self,
source: &[f32],
source_count: usize,
target_count: usize,
tolerance: f64,
) -> Result<Vec<f32>, InterpolationError>;
/// Linear interpolation fallback (O(n))
fn interpolate_linear(
&self,
source: &[f32],
source_count: usize,
target_count: usize,
) -> Vec<f32>;
/// Zero-pad for Wi-Pose (30 -> 56)
fn zero_pad(
&self,
source: &[f32],
target_count: usize,
) -> Vec<f32>;
}

Computes the composite training loss from model outputs and ground truth.
pub trait LossComputationService {
/// Compute composite loss with per-component breakdown
fn compute(
&self,
predictions: &ModelOutput,
targets: &GroundTruth,
weights: &LossWeights,
) -> CompositeTrainingLoss;
}
pub struct CompositeTrainingLoss {
/// Total weighted loss (scalar for backprop)
pub total: f64,
/// Keypoint heatmap MSE component
pub keypoint_mse: f64,
/// DensePose body part cross-entropy component
pub densepose_part_ce: f64,
/// DensePose UV Smooth L1 component
pub uv_smooth_l1: f64,
/// Teacher-student transfer MSE component
pub transfer_mse: f64,
/// AETHER contrastive loss (if enabled)
pub contrastive: Option<f64>,
/// MERIDIAN domain adversarial loss (if enabled)
pub domain_adversarial: Option<f64>,
}

Evaluates model accuracy on the validation set using standard pose estimation metrics.
pub trait MetricEvaluationService {
/// PCK@0.2: fraction of keypoints within 20% of torso diameter
fn compute_pck(&self, predictions: &[PosePrediction], targets: &[PoseTarget], threshold: f64) -> PckScore;
/// OKS: Object Keypoint Similarity with per-keypoint sigmas
fn compute_oks(&self, predictions: &[PosePrediction], targets: &[PoseTarget]) -> OksScore;
/// MPJPE: Mean Per Joint Position Error in millimeters
fn compute_mpjpe(&self, predictions: &[PosePrediction], targets: &[PoseTarget]) -> MpjpeScore;
/// Multi-person assignment via Hungarian (static, deterministic)
fn assign_hungarian(&self, pred: &[PosePrediction], gt: &[PoseTarget]) -> Vec<(usize, usize)>;
/// Multi-person assignment via DynamicMinCut (persistent, O(n^1.5 log n))
fn assign_dynamic(&mut self, pred: &[PosePrediction], gt: &[PoseTarget]) -> Vec<(usize, usize)>;
}

Manages the MERIDIAN gradient reversal training regime.
pub trait DomainAdversarialService {
/// Compute GRL lambda for the current epoch
fn grl_lambda(&self, epoch: usize, max_warmup_epochs: usize) -> f32;
/// Forward pass through domain classifier with gradient reversal
fn classify_domain(
&self,
h_pose: &Tensor,
lambda: f32,
) -> Tensor;
/// Compute domain adversarial loss (cross-entropy on domain logits)
fn domain_loss(
&self,
domain_logits: &Tensor,
domain_labels: &Tensor,
) -> f64;
}

+------------------------------------------------------------------+
| Training Pipeline System |
+------------------------------------------------------------------+
| |
| +------------------+ CsiSample +------------------+ |
| | Dataset |-------------->| Training | |
| | Management | | Orchestration | |
| | Context | | Context | |
| +--------+---------+ +--------+-----------+ |
| | | |
| | Publishes | Publishes |
| | DatasetEvent | TrainingEvent |
| v v |
| +------------------------------------------------------+ |
| | Event Bus (Domain Events) | |
| +------------------------------------------------------+ |
| | | |
| v v |
| +------------------+ +------------------+ |
| | Model |<-------------| Embedding & | |
| | Architecture | body_part_ | Transfer | |
| | Context | features | Context | |
| +------------------+ +------------------+ |
| |
+------------------------------------------------------------------+
| UPSTREAM (Conformist) |
| +--------------+ +--------------+ +--------------+ |
| |wifi-densepose| |wifi-densepose| |wifi-densepose| |
| | -signal | | -nn | | -core | |
| | (phase algs,| | (ONNX, | | (CsiFrame, | |
| | SpotFi) | | Candle) | | error) | |
| +--------------+ +--------------+ +--------------+ |
| |
+------------------------------------------------------------------+
| SIBLING (Partnership) |
| +--------------+ +--------------+ +--------------+ |
| | RuvSense | | MAT | | Sensing | |
| | (pose | | (triage, | | Server | |
| | tracker, | | survivor) | | (inference | |
| | field | | | | deployment) | |
| | model) | | | | | |
| +--------------+ +--------------+ +--------------+ |
| |
+------------------------------------------------------------------+
| EXTERNAL (Published Language) |
| +--------------+ +--------------+ +--------------+ |
| | MM-Fi | | Wi-Pose | | Detectron2 | |
| | (NeurIPS | | (NjtechCV | | (teacher | |
| | dataset) | | dataset) | | labels) | |
| +--------------+ +--------------+ +--------------+ |
+------------------------------------------------------------------+
Relationship Types:
- Dataset Management -> Training Orchestration: Customer/Supplier (Dataset produces CsiSamples; Orchestration consumes)
- Model Architecture -> Training Orchestration: Partnership (tight bidirectional coupling: Orchestration drives forward/backward; Architecture defines the computation graph)
- Model Architecture -> Embedding & Transfer: Customer/Supplier (Architecture produces body_part_features; Embedding consumes for contrastive/adversarial heads)
- Embedding & Transfer -> Training Orchestration: Partnership (contrastive and adversarial losses feed into composite loss)
- Training Pipeline -> Upstream crates: Conformist (adapts to wifi-densepose-signal, -nn, -core types)
- Training Pipeline -> RuvSense/MAT/Server: Partnership (trained model weights flow downstream)
- Training Pipeline -> External datasets: Anti-Corruption Layer (dataset loaders translate external formats to domain types)
/// Translates raw MM-Fi numpy files into domain CsiSample values.
/// Handles the 114->56 subcarrier interpolation and 1TX/3RX antenna layout.
pub struct MmFiAdapter {
/// Subcarrier interpolation service
interpolator: Box<dyn SubcarrierInterpolationService>,
/// Phase sanitizer from wifi-densepose-signal
phase_sanitizer: PhaseSanitizer,
/// Hardware normalizer for z-score normalization
normalizer: HardwareNormalizer,
}
impl MmFiAdapter {
/// Load a single MM-Fi sample from .npy tensors and produce a CsiSample.
/// Steps:
/// 1. Read amplitude [3, 114, 10] and phase [3, 114, 10]
/// 2. Interpolate 114 -> 56 subcarriers per antenna pair
/// 3. Sanitize phase (remove linear offset, unwrap)
/// 4. Z-score normalize amplitude per frame
/// 5. Read 17-keypoint COCO annotations
pub fn adapt(&self, raw: &MmFiRawSample) -> Result<CsiSample, AdapterError>;
}

/// Translates Wi-Pose .mat files into domain CsiSample values.
/// Handles 30->56 zero-padding and 18->17 keypoint mapping.
pub struct WiPoseAdapter {
/// Zero-padding service
interpolator: Box<dyn SubcarrierInterpolationService>,
/// Phase sanitizer
phase_sanitizer: PhaseSanitizer,
}
impl WiPoseAdapter {
/// Load a Wi-Pose sample from .mat format and produce a CsiSample.
/// Steps:
/// 1. Read CSI [9, 30] (3x3 antenna pairs, 30 subcarriers)
/// 2. Zero-pad 30 -> 56 subcarriers (high-frequency padding)
/// 3. Sanitize phase
/// 4. Map 18 AlphaPose keypoints -> 17 COCO (drop neck, index 1)
pub fn adapt(&self, raw: &WiPoseRawSample) -> Result<CsiSample, AdapterError>;
}

/// Adapts Detectron2 DensePose outputs into domain DensePoseLabels.
/// Used during teacher-student pseudo-label generation.
pub struct TeacherModelAdapter;
impl TeacherModelAdapter {
/// Run Detectron2 DensePose on an RGB frame and produce pseudo-labels.
/// Output: (part_labels [H x W], u_coords [H x W], v_coords [H x W])
pub fn generate_pseudo_labels(
&self,
rgb_frame: &RgbFrame,
) -> Result<DensePoseLabels, AdapterError>;
}

/// Adapts ruvector-attn-mincut API to the model's tensor format.
/// Handles the Tensor <-> Vec<f32> conversion overhead per batch element.
pub struct AttnMinCutAdapter;
impl AttnMinCutAdapter {
/// Apply min-cut gated attention to antenna-path features.
/// Converts [B, n_ant, n_sc] tensor to flat Vec<f32> per batch element,
/// calls attn_mincut, and reshapes output back to tensor.
pub fn apply(
&self,
features: &Tensor,
n_antenna_paths: usize,
n_subcarriers: usize,
lambda: f32,
) -> Result<Tensor, AdapterError>;
}

/// Persists and retrieves training run state
pub trait TrainingRunRepository {
fn save(&self, run: &TrainingRun) -> Result<(), RepositoryError>;
fn find_by_id(&self, id: &TrainingRunId) -> Result<Option<TrainingRun>, RepositoryError>;
fn find_latest(&self) -> Result<Option<TrainingRun>, RepositoryError>;
fn list_completed(&self) -> Result<Vec<TrainingRun>, RepositoryError>;
}
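For unit-testing the orchestration context, an in-memory variant of this repository is useful. A sketch, simplified to drop the error type and take `&mut self`; the `TrainingRunId` and `TrainingRun` shapes here are illustrative stand-ins for the real aggregate:

```rust
use std::collections::HashMap;

// Simplified stand-ins for the domain types; the real TrainingRun
// aggregate also carries configuration, epochs, and checkpoint history.
#[derive(Clone)]
struct TrainingRunId(String);

#[derive(Clone)]
struct TrainingRun { id: TrainingRunId, completed: bool }

/// In-memory repository sketch for tests; a production implementation
/// would persist to disk or a database behind the same interface.
struct InMemoryTrainingRunRepository {
    runs: HashMap<String, TrainingRun>,
    order: Vec<String>, // insertion order, so find_latest is well-defined
}

impl InMemoryTrainingRunRepository {
    fn new() -> Self { Self { runs: HashMap::new(), order: Vec::new() } }

    fn save(&mut self, run: &TrainingRun) {
        if !self.runs.contains_key(&run.id.0) {
            self.order.push(run.id.0.clone());
        }
        self.runs.insert(run.id.0.clone(), run.clone());
    }

    fn find_by_id(&self, id: &TrainingRunId) -> Option<TrainingRun> {
        self.runs.get(&id.0).cloned()
    }

    fn find_latest(&self) -> Option<TrainingRun> {
        self.order.last().and_then(|k| self.runs.get(k)).cloned()
    }

    fn list_completed(&self) -> Vec<TrainingRun> {
        self.order.iter()
            .filter_map(|k| self.runs.get(k))
            .filter(|r| r.completed)
            .cloned()
            .collect()
    }
}
```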
/// Persists model checkpoints
pub trait CheckpointRepository {
fn save(&self, checkpoint: &Checkpoint) -> Result<(), RepositoryError>;
fn find_best(&self, run_id: &TrainingRunId) -> Result<Option<Checkpoint>, RepositoryError>;
fn find_by_epoch(&self, run_id: &TrainingRunId, epoch: usize) -> Result<Option<Checkpoint>, RepositoryError>;
fn list_all(&self, run_id: &TrainingRunId) -> Result<Vec<Checkpoint>, RepositoryError>;
}
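The `find_best` query encodes a selection policy. One deterministic policy, sketched below, is lowest validation loss with ties broken toward the earlier epoch; the `Checkpoint` fields shown are simplified assumptions, not the real entity:

```rust
/// Minimal checkpoint record; the real Checkpoint entity also carries
/// weight paths and the deterministic proof hash.
struct Checkpoint { epoch: usize, val_loss: f32 }

/// Select the "best" checkpoint: lowest validation loss, ties broken
/// toward the earlier epoch so the result is deterministic.
fn find_best(checkpoints: &[Checkpoint]) -> Option<&Checkpoint> {
    checkpoints.iter().min_by(|a, b| {
        a.val_loss
            .partial_cmp(&b.val_loss)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then(a.epoch.cmp(&b.epoch))
    })
}
```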
/// Persists AETHER embedding index
pub trait EmbeddingRepository {
fn save_index(&self, index: &EmbeddingIndex) -> Result<(), RepositoryError>;
fn load_index(&self) -> Result<Option<EmbeddingIndex>, RepositoryError>;
fn add_entry(&self, entry: &EmbeddingEntry) -> Result<(), RepositoryError>;
fn search_knn(&self, query: &[f32], k: usize) -> Result<Vec<(EmbeddingEntry, f32)>, RepositoryError>;
}
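The semantics of `search_knn` can be stated as a brute-force reference implementation over cosine similarity; the production index replaces the linear scan with HNSW (ADR-024). The `(usize, Vec<f32>)` entry layout here is an illustrative assumption, not the real `EmbeddingEntry`:

```rust
/// Brute-force k-nearest-neighbor search over embedding vectors using
/// cosine similarity. Reference semantics only; HNSW gives the same
/// answers (approximately) in sublinear time.
fn search_knn(entries: &[(usize, Vec<f32>)], query: &[f32], k: usize) -> Vec<(usize, f32)> {
    fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        dot / (na * nb).max(1e-12)
    }
    let mut scored: Vec<(usize, f32)> = entries
        .iter()
        .map(|(id, v)| (*id, cosine(v, query)))
        .collect();
    // Highest similarity first, then keep the top k.
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(k);
    scored
}
```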
/// Persists LoRA adapters for environment-specific fine-tuning
pub trait LoraRepository {
fn save(&self, adapter: &LoraAdapter) -> Result<(), RepositoryError>;
fn find_by_domain(&self, domain: &DomainLabel) -> Result<Option<LoraAdapter>, RepositoryError>;
fn list_all(&self) -> Result<Vec<LoraAdapter>, RepositoryError>;
}

Related ADRs:
- ADR-015: Public Dataset Strategy (MM-Fi, Wi-Pose, teacher-student training)
- ADR-016: RuVector Integration (5 crate integration points in training pipeline)
- ADR-020: Rust Migration (training pipeline in wifi-densepose-train crate)
- ADR-024: AETHER Contrastive CSI Embeddings (128-dim fingerprints, InfoNCE, HNSW)
- ADR-027: MERIDIAN Cross-Environment Domain Generalization (GRL, FiLM, LoRA)
References:
- Yang et al., "MM-Fi: Multi-Modal Non-Intrusive 4D Human Dataset" (NeurIPS 2023)
- NjtechCVLab, "Wi-Pose Dataset" (CSI-Former, MDPI Entropy 2023)
- Geng et al., "DensePose From WiFi" (CMU, arXiv:2301.00250, 2023)
- Ganin et al., "Domain-Adversarial Training of Neural Networks" (JMLR 2016)
- Perez et al., "FiLM: Visual Reasoning with a General Conditioning Layer" (AAAI 2018)