
Training Pipeline Domain Model

The Training & ML Pipeline is the subsystem of WiFi-DensePose that turns raw public CSI datasets into a trained pose estimation model and its downstream derivatives: contrastive embeddings, domain-generalized weights, and deterministic proof bundles. It is the bridge between research data and deployable inference.

This document defines the system using Domain-Driven Design (DDD): bounded contexts that own their data and rules, aggregate roots that enforce invariants, value objects that carry meaning, and domain events that connect everything. The goal is to make the pipeline's structure match the physics and mathematics it implements -- so that anyone reading the code (or an AI agent modifying it) understands why each piece exists, not just what it does.

Bounded Contexts:

| # | Context | Responsibility | Key ADRs | Code |
|---|---------|----------------|----------|------|
| 1 | Dataset Management | Load, validate, normalize, and preprocess training data from MM-Fi and Wi-Pose | ADR-015 | train/src/dataset.rs, train/src/subcarrier.rs |
| 2 | Model Architecture | Define the neural network, forward pass, attention mechanisms, and spatial decoding | ADR-016, ADR-020 | train/src/model.rs, train/src/graph_transformer.rs |
| 3 | Training Orchestration | Run the training loop, compute composite loss, checkpoint, and verify deterministic proofs | ADR-015, ADR-016 | train/src/trainer.rs, train/src/losses.rs, train/src/metrics.rs, train/src/proof.rs |
| 4 | Embedding & Transfer | Produce AETHER contrastive embeddings, MERIDIAN domain-generalized features, and LoRA adapters | ADR-024, ADR-027 | train/src/embedding.rs, train/src/domain.rs, train/src/sona.rs |

All code paths shown are relative to rust-port/wifi-densepose-rs/crates/wifi-densepose- unless otherwise noted.


Domain-Driven Design Specification

Ubiquitous Language

| Term | Definition |
|------|------------|
| Training Run | A complete training session: configuration, epoch loop, checkpoint history, and final model weights |
| Epoch | One full pass through the training dataset; produces train loss and validation metrics |
| Checkpoint | A snapshot of model weights at a given epoch, identified by SHA-256 hash and validation PCK |
| CSI Sample | A single observation: amplitude + phase tensors, ground-truth keypoints, and visibility flags |
| Subcarrier Interpolation | Resampling CSI from the source subcarrier count to the canonical 56 (114->56 for MM-Fi, 30->56 for Wi-Pose) |
| Teacher-Student | Training regime where a camera-based RGB model generates pseudo-labels; at inference the camera is removed |
| Pseudo-Label | DensePose UV surface coordinates generated by Detectron2 from paired RGB frames |
| PCK@0.2 | Percentage of Correct Keypoints within 20% of torso diameter; primary accuracy metric |
| OKS | Object Keypoint Similarity; per-keypoint Gaussian-weighted distance used in COCO evaluation |
| MPJPE | Mean Per Joint Position Error in millimeters; 3D accuracy metric |
| Hungarian Assignment | Bipartite matching of predicted persons to ground truth using min-cost assignment |
| Dynamic Min-Cut | Person-to-GT assignment maintained across frames in O(n^1.5 log n) time |
| Compressed CSI Buffer | Tiered-quantization temporal window: hot frames at 8-bit, warm at 5/7-bit, cold at 3-bit |
| Proof Verification | Deterministic check: fixed seed -> N training steps -> loss decreases AND SHA-256 hash matches |
| AETHER Embedding | 128-dim L2-normalized contrastive vector from the CsiToPoseTransformer backbone |
| InfoNCE Loss | Contrastive loss that pushes same-identity embeddings together and different-identity embeddings apart |
| HNSW Index | Hierarchical Navigable Small World graph for approximate nearest-neighbor embedding search |
| Domain Factorizer | Splits latent features into pose-invariant (h_pose) and environment-specific (h_env) components |
| Gradient Reversal Layer | Identity in the forward pass; multiplies the gradient by -lambda in the backward pass to force domain invariance |
| GRL Lambda | Adversarial weight annealed from 0.0 to 1.0 over the first 20 epochs |
| FiLM Conditioning | Feature-wise Linear Modulation: gamma * features + beta, conditioned on a geometry encoding |
| Hardware Normalizer | Resamples CSI from any chipset to the canonical 56 subcarriers with z-score amplitude normalization |
| LoRA Adapter | Low-Rank Adaptation weights (rank r, alpha) for few-shot environment-specific fine-tuning |
| Rapid Adaptation | 10-second unlabeled calibration producing a per-room LoRA adapter via contrastive test-time training |
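The PCK@0.2 definition above reduces to a short computation. A minimal sketch in Rust, using 2-D keypoints for brevity (the pipeline's keypoints are 3-D, but the metric is analogous); the `pck` function name is illustrative:

```rust
/// Percentage of Correct Keypoints: the fraction of predicted keypoints
/// whose Euclidean distance to ground truth is within `threshold` (0.2)
/// times the torso diameter.
fn pck(pred: &[[f32; 2]], gt: &[[f32; 2]], torso_diameter: f32, threshold: f32) -> f32 {
    let correct = pred
        .iter()
        .zip(gt)
        .filter(|(p, g)| {
            let dx = p[0] - g[0];
            let dy = p[1] - g[1];
            // A keypoint is "correct" if it lands within the scaled radius.
            (dx * dx + dy * dy).sqrt() <= threshold * torso_diameter
        })
        .count();
    correct as f32 / pred.len() as f32
}
```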

Bounded Contexts

1. Dataset Management Context

Responsibility: Load raw CSI data from public datasets (MM-Fi, Wi-Pose), validate structural invariants, resample subcarriers to the canonical 56, apply phase sanitization, and present typed samples to the training loop, with memory efficiency achieved through tiered temporal compression.

+----------------------------------------------------------+
|              Dataset Management Context                    |
+----------------------------------------------------------+
|                                                            |
|  +---------------+    +---------------+                    |
|  |  MM-Fi Loader |    |  Wi-Pose      |                    |
|  |  (.npy files, |    |  Loader       |                    |
|  |   114 sub,    |    |  (.mat files, |                    |
|  |   40 subjects)|    |   30 sub,     |                    |
|  +-------+-------+    |   12 subjects)|                    |
|          |            +-------+-------+                    |
|          |                    |                             |
|          +--------+-----------+                            |
|                   v                                        |
|          +----------------+                                |
|          | Subcarrier     |                                |
|          | Interpolator   |                                |
|          | (114->56 or    |                                |
|          |  30->56)       |                                |
|          +--------+-------+                                |
|                   v                                        |
|          +----------------+                                |
|          | Phase          |                                |
|          | Sanitizer      |                                |
|          | (SOTA algs     |                                |
|          |  from signal)  |                                |
|          +--------+-------+                                |
|                   v                                        |
|          +----------------+                                |
|          | Compressed CSI |--> CsiSample                   |
|          | Buffer         |                                |
|          | (tiered quant) |                                |
|          +----------------+                                |
|                                                            |
+----------------------------------------------------------+

Aggregates:

  • MmFiDataset (Aggregate Root) -- Manages the MM-Fi data lifecycle
  • WiPoseDataset (Aggregate Root) -- Manages the Wi-Pose data lifecycle

Value Objects:

  • CsiSample -- Single observation with amplitude, phase, keypoints, visibility
  • SubcarrierConfig -- Source count, target count, interpolation method
  • DatasetSplit -- Train / Validation / Test subject partitioning
  • CompressedCsiBuffer -- Tiered temporal window backed by TemporalTensorCompressor

Domain Services:

  • SubcarrierInterpolationService -- Resamples subcarriers via sparse least-squares or linear fallback
  • PhaseSanitizationService -- Applies SpotFi / MUSIC phase correction from wifi-densepose-signal
  • TeacherLabelService -- Runs Detectron2 on paired RGB frames to produce DensePose UV pseudo-labels
  • HardwareNormalizerService -- Z-score normalization + chipset-invariant phase sanitization

RuVector Integration:

  • ruvector-solver -> NeumannSolver for sparse O(sqrt(n)) subcarrier interpolation (114->56)
  • ruvector-temporal-tensor -> TemporalTensorCompressor for 50-75% memory reduction in CSI windows
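The linear fallback path can be sketched directly. This is an illustrative O(n) resampler only (not the preferred sparse NeumannSolver route through ruvector-solver), and it assumes `target_count >= 2`:

```rust
/// Linearly resample one subcarrier vector (e.g. 114 values) down to the
/// canonical count (56) by interpolating along the source axis.
fn interpolate_linear(source: &[f32], target_count: usize) -> Vec<f32> {
    let n = source.len();
    (0..target_count)
        .map(|i| {
            // Map target index i onto the continuous source axis [0, n-1].
            let pos = i as f32 * (n - 1) as f32 / (target_count - 1) as f32;
            let lo = pos.floor() as usize;
            let hi = (lo + 1).min(n - 1);
            let frac = pos - lo as f32;
            source[lo] * (1.0 - frac) + source[hi] * frac
        })
        .collect()
}
```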

2. Model Architecture Context

Responsibility: Define the WiFiDensePoseModel: CSI embedding, cross-attention between keypoint queries and CSI features, GNN message passing, attention-gated modality fusion, and spatial decoding heads for keypoints and DensePose UV.

+----------------------------------------------------------+
|              Model Architecture Context                    |
+----------------------------------------------------------+
|                                                            |
|  +---------------+    +---------------+                    |
|  | CSI Embed     |    | Keypoint      |                    |
|  | (Linear       |    | Queries       |                    |
|  |  56 -> d)     |    | (17 learned   |                    |
|  +-------+-------+    |  embeddings)  |                    |
|          |            +-------+-------+                    |
|          |                    |                             |
|          +--------+-----------+                            |
|                   v                                        |
|          +----------------+                                |
|          | Cross-Attention|                                |
|          | (Q=queries,    |                                |
|          |  K,V=csi)      |                                |
|          +--------+-------+                                |
|                   v                                        |
|          +----------------+                                |
|          | GNN Stack      |                                |
|          | (2-layer GCN   |                                |
|          |  skeleton      |                                |
|          |  adjacency)    |                                |
|          +--------+-------+                                |
|                   v                                        |
|     body_part_features [17 x d_model]                      |
|          |                                                 |
|          +-------+--------+--------+                       |
|          v       v        v        v                       |
|   +----------+ +------+ +-----+ +-------+                 |
|   | Modality | | xyz  | | UV  | |Spatial|                  |
|   | Transl.  | | Head | | Head| |Attn   |                  |
|   | (attn    | |      | |     | |Decoder|                  |
|   |  mincut) | |      | |     | |       |                  |
|   +----------+ +------+ +-----+ +-------+                 |
|                                                            |
+----------------------------------------------------------+

Aggregates:

  • WiFiDensePoseModel (Aggregate Root) -- The complete model graph

Entities:

  • ModalityTranslator -- Attention-gated CSI fusion using min-cut
  • CsiToPoseTransformer -- Cross-attention + GNN backbone
  • KeypointHead -- Regresses 17 x (x, y, z, confidence) from body_part_features
  • DensePoseHead -- Predicts body part labels and UV surface coordinates

Value Objects:

  • ModelConfig -- Architecture hyperparameters (d_model, n_heads, n_gnn_layers)
  • AttentionOutput -- Attended values + gating result from min-cut attention
  • BodyPartFeatures -- [17 x d_model] intermediate representation

Domain Services:

  • AttentionGatingService -- Applies attn_mincut to prune irrelevant antenna paths
  • SpatialDecodingService -- Graph-based spatial attention among feature map locations

RuVector Integration:

  • ruvector-attn-mincut -> attn_mincut for antenna-path gating in ModalityTranslator
  • ruvector-attention -> ScaledDotProductAttention for spatial decoder long-range dependencies

3. Training Orchestration Context

Responsibility: Run the training loop across epochs, compute the composite loss (keypoint MSE + DensePose part CE + UV Smooth L1 + transfer MSE), evaluate validation metrics (PCK@0.2, OKS, MPJPE), manage checkpoints, and verify deterministic proof correctness.

+----------------------------------------------------------+
|           Training Orchestration Context                   |
+----------------------------------------------------------+
|                                                            |
|  +---------------+    +---------------+                    |
|  | Training Loop |    | Loss Computer |                    |
|  | (epoch iter,  |    | (composite:   |                    |
|  |  batch fwd/   |    |  kp_mse +     |                    |
|  |  bwd, optim)  |    |  part_ce +    |                    |
|  +-------+-------+    |  uv_l1 +     |                    |
|          |            |  transfer)    |                    |
|          |            +-------+-------+                    |
|          +--------+-----------+                            |
|                   v                                        |
|          +----------------+                                |
|          | Metric         |                                |
|          | Evaluator      |                                |
|          | (PCK, OKS,     |                                |
|          |  MPJPE,        |                                |
|          |  Hungarian)    |                                |
|          +--------+-------+                                |
|                   v                                        |
|     +-------------+-------------+                          |
|     v                           v                          |
|  +----------------+    +----------------+                  |
|  | Checkpoint     |    | Proof Verifier |                  |
|  | Manager        |    | (fixed seed,   |                  |
|  | (best-by-PCK,  |    |  50 steps,     |                  |
|  |  SHA-256 hash) |    |  loss + hash)  |                  |
|  +----------------+    +----------------+                  |
|                                                            |
+----------------------------------------------------------+

Aggregates:

  • TrainingRun (Aggregate Root) -- The complete training session

Entities:

  • CheckpointManager -- Persists and selects model snapshots
  • ProofVerifier -- Deterministic verification against stored hashes

Value Objects:

  • TrainingConfig -- Epochs, batch_size, learning_rate, loss_weights, optimizer params
  • Checkpoint -- Epoch number, model weights SHA-256, validation PCK at that epoch
  • LossWeights -- Relative weights for each loss component
  • CompositeTrainingLoss -- Combined scalar loss with per-component breakdown
  • OksScore -- Per-keypoint Object Keypoint Similarity with sigma values
  • PckScore -- Percentage of Correct Keypoints at threshold 0.2
  • MpjpeScore -- Mean Per Joint Position Error in millimeters
  • ProofResult -- Seed, steps, loss_decreased flag, hash_matches flag

Domain Services:

  • LossComputationService -- Computes composite loss from model outputs and ground truth
  • MetricEvaluationService -- Computes PCK, OKS, MPJPE over validation set
  • HungarianAssignmentService -- Bipartite matching for multi-person evaluation
  • DynamicPersonMatcherService -- Frame-persistent assignment via ruvector-mincut
  • ProofVerificationService -- Fixed-seed training + SHA-256 verification

RuVector Integration:

  • ruvector-mincut -> DynamicMinCut for O(n^1.5 log n) multi-person assignment in metrics
  • Original hungarian_assignment kept for single-frame static matching in proof verification

4. Embedding & Transfer Context

Responsibility: Produce AETHER contrastive embeddings from the model backbone, train domain-adversarial features via MERIDIAN, manage the HNSW embedding index for re-ID and fingerprinting, and generate LoRA adapters for few-shot environment adaptation.

+----------------------------------------------------------+
|           Embedding & Transfer Context                     |
+----------------------------------------------------------+
|                                                            |
|  body_part_features [17 x d_model]                         |
|          |                                                 |
|          +--------+-----------+                            |
|          v                    v                            |
|  +---------------+    +---------------+                    |
|  | AETHER        |    | MERIDIAN      |                    |
|  | Projection    |    | Domain        |                    |
|  | Head          |    | Factorizer    |                    |
|  | (MeanPool ->  |    | (PoseEncoder  |                    |
|  |  fc -> 128d)  |    |  + EnvEncoder)|                    |
|  +-------+-------+    +-------+-------+                    |
|          |                    |                             |
|          v                    v                             |
|  +---------------+    +---------------+                    |
|  | InfoNCE Loss  |    | Gradient      |                    |
|  | + Hard Neg    |    | Reversal      |                    |
|  | Mining (HNSW) |    | Layer (GRL)   |                    |
|  +-------+-------+    +-------+-------+                    |
|          |                    |                             |
|          v                    v                             |
|  +---------------+    +---------------+                    |
|  | Embedding     |    | Geometry      |                    |
|  | Index (HNSW)  |    | Encoder +     |                    |
|  | (fingerprint  |    | FiLM Cond.    |                    |
|  |  store)       |    | (zero-shot)   |                    |
|  +---------------+    +-------+-------+                    |
|                               |                            |
|                               v                            |
|                       +---------------+                    |
|                       | Rapid Adapt.  |                    |
|                       | (LoRA + TTT,  |                    |
|                       |  10-sec cal.) |                    |
|                       +---------------+                    |
|                                                            |
+----------------------------------------------------------+

Aggregates:

  • EmbeddingIndex (Aggregate Root) -- HNSW-indexed store of AETHER fingerprints
  • DomainAdaptationState (Aggregate Root) -- Tracks GRL lambda, domain classifier accuracy, factorization quality

Entities:

  • ProjectionHead -- MLP mapping body_part_features to 128-dim embedding space
  • DomainFactorizer -- Splits features into h_pose and h_env
  • DomainClassifier -- Classifies domain from h_pose (trained adversarially via GRL)
  • GeometryEncoder -- Fourier positional encoding + DeepSets for AP positions
  • LoraAdapter -- Low-rank adaptation weights for environment-specific fine-tuning

Value Objects:

  • AetherEmbedding -- 128-dim L2-normalized contrastive vector
  • FingerprintType -- ReIdentification / RoomFingerprint / PersonFingerprint
  • DomainLabel -- Environment identifier for adversarial training
  • GrlSchedule -- Lambda annealing parameters (max_lambda, warmup_epochs)
  • GeometryInput -- AP positions in meters relative to room origin
  • FilmParameters -- Gamma (scale) and beta (shift) vectors from geometry conditioning
  • LoraConfig -- Rank, alpha, target layers
  • AdaptationLoss -- ContrastiveTTT / EntropyMin / Combined

Domain Services:

  • ContrastiveLossService -- Computes InfoNCE loss with temperature scaling
  • HardNegativeMiningService -- HNSW k-NN search for difficult negative pairs
  • DomainAdversarialService -- Manages GRL annealing and domain classification
  • GeometryConditioningService -- Encodes AP layout and produces FiLM parameters
  • VirtualDomainAugmentationService -- Generates synthetic environment shifts for training diversity
  • RapidAdaptationService -- Produces LoRA adapter from 10-second unlabeled calibration
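As a concrete anchor for ContrastiveLossService, here is a minimal single-anchor InfoNCE sketch. It assumes inputs are already L2-normalized (so a dot product is cosine similarity); function names are illustrative:

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// InfoNCE for a single anchor:
///   L = -log( exp(sim(a,pos)/t) / (exp(sim(a,pos)/t) + sum_k exp(sim(a,neg_k)/t)) )
/// Same-identity pairs are pulled together; different-identity pushed apart.
fn info_nce(anchor: &[f32], positive: &[f32], negatives: &[&[f32]], temperature: f32) -> f32 {
    let pos = (dot(anchor, positive) / temperature).exp();
    let denom = pos
        + negatives
            .iter()
            .map(|n| (dot(anchor, n) / temperature).exp())
            .sum::<f32>();
    -(pos / denom).ln()
}
```

With the typical temperature of 0.07, an easy negative (opposite direction) yields near-zero loss, while a negative identical to the positive yields ln 2.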

Core Domain Entities

TrainingRun (Aggregate Root)

pub struct TrainingRun {
    /// Unique run identifier
    pub id: TrainingRunId,
    /// Full training configuration
    pub config: TrainingConfig,
    /// Datasets loaded for this run
    pub datasets: Vec<DatasetHandle>,
    /// Ordered history of per-epoch metrics
    pub epoch_history: Vec<EpochRecord>,
    /// Best checkpoint by validation PCK
    pub best_checkpoint: Option<Checkpoint>,
    /// Current epoch (0-indexed)
    pub current_epoch: usize,
    /// Run status
    pub status: RunStatus,
    /// Proof verification result (if run)
    pub proof_result: Option<ProofResult>,
}

pub enum RunStatus {
    Initializing,
    Training,
    Completed,
    Failed { reason: String },
    ProofVerified,
}

Invariants:

  • Must have at least 1 dataset loaded before transitioning to Training
  • best_checkpoint is updated only when a new epoch's validation PCK exceeds all prior epochs
  • proof_result can only be set once and is immutable after verification
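The best-checkpoint invariant above amounts to a strict-improvement check; a tiny sketch with an illustrative name:

```rust
/// best_checkpoint updates only on a strict improvement in validation PCK
/// over every prior epoch.
fn is_new_best(current_pck: f64, best_so_far: Option<f64>) -> bool {
    best_so_far.map_or(true, |best| current_pck > best)
}
```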

MmFiDataset (Aggregate Root)

pub struct MmFiDataset {
    /// Root directory containing .npy files
    pub data_root: PathBuf,
    /// Subject IDs in this split
    pub subject_ids: Vec<u32>,
    /// Number of action classes
    pub n_actions: usize,  // 27
    /// Source subcarrier count
    pub source_subcarriers: usize,  // 114
    /// Target subcarrier count after interpolation
    pub target_subcarriers: usize,  // 56
    /// Antenna configuration: 1 TX x 3 RX
    pub antenna_pairs: usize,  // 3
    /// Sampling rate in Hz
    pub sample_rate_hz: f32,  // 100.0
    /// Temporal window size (frames per sample)
    pub window_frames: usize,  // 10
    /// Compressed buffer for memory-efficient storage
    pub buffer: CompressedCsiBuffer,
    /// Total loaded samples
    pub n_samples: usize,
}

WiPoseDataset (Aggregate Root)

pub struct WiPoseDataset {
    /// Root directory containing .mat files
    pub data_root: PathBuf,
    /// Subject IDs in this split
    pub subject_ids: Vec<u32>,
    /// Source subcarrier count
    pub source_subcarriers: usize,  // 30
    /// Target subcarrier count after zero-padding
    pub target_subcarriers: usize,  // 56
    /// Antenna configuration: 3 TX x 3 RX
    pub antenna_pairs: usize,  // 9
    /// Keypoint count (18 AlphaPose, mapped to 17 COCO)
    pub source_keypoints: usize,  // 18
    /// Compressed buffer
    pub buffer: CompressedCsiBuffer,
    /// Total loaded samples
    pub n_samples: usize,
}

WiFiDensePoseModel (Aggregate Root)

pub struct WiFiDensePoseModel {
    /// CSI embedding layer: Linear(56, d_model)
    pub csi_embed: Linear,
    /// Learned keypoint query embeddings [17 x d_model]
    pub keypoint_queries: Tensor,
    /// Cross-attention: Q=queries, K,V=csi_embed
    pub cross_attention: MultiHeadAttention,
    /// GNN message passing on skeleton graph
    pub gnn_stack: GnnStack,
    /// Modality translator with attention-gated fusion
    pub modality_translator: ModalityTranslator,
    /// Keypoint regression head
    pub keypoint_head: KeypointHead,
    /// DensePose UV prediction head
    pub densepose_head: DensePoseHead,
    /// Spatial attention decoder
    pub spatial_decoder: SpatialAttentionDecoder,
    /// Model dimensionality
    pub d_model: usize,  // 64
}

EmbeddingIndex (Aggregate Root)

pub struct EmbeddingIndex {
    /// HNSW graph for approximate nearest-neighbor search
    pub hnsw: HnswIndex,
    /// Stored embeddings with metadata
    pub entries: Vec<EmbeddingEntry>,
    /// Embedding dimensionality
    pub dim: usize,  // 128
    /// Number of indexed embeddings
    pub count: usize,
    /// HNSW construction parameters
    pub ef_construction: usize,  // 200
    pub m_connections: usize,    // 16
}

pub struct EmbeddingEntry {
    pub id: EmbeddingId,
    pub embedding: Vec<f32>,  // [128], L2-normalized
    pub fingerprint_type: FingerprintType,
    pub source_domain: Option<DomainLabel>,
    pub created_at: u64,
}

pub enum FingerprintType {
    ReIdentification,
    RoomFingerprint,
    PersonFingerprint,
}

Value Objects

CsiSample

pub struct CsiSample {
    /// Amplitude tensor [n_antenna_pairs x n_subcarriers x n_time_frames]
    pub amplitude: Vec<f32>,
    /// Phase tensor [n_antenna_pairs x n_subcarriers x n_time_frames]
    pub phase: Vec<f32>,
    /// Ground-truth 3D keypoints [17 x 3] (x, y, z in meters)
    pub keypoints: [[f32; 3]; 17],
    /// Per-keypoint visibility flags
    pub visibility: [f32; 17],
    /// DensePose UV pseudo-labels (optional, from teacher model)
    pub densepose_uv: Option<DensePoseLabels>,
    /// Domain label for adversarial training
    pub domain_label: Option<DomainLabel>,
    /// Hardware source type
    pub hardware_type: HardwareType,
}

TrainingConfig

pub struct TrainingConfig {
    /// Number of training epochs
    pub epochs: usize,
    /// Mini-batch size
    pub batch_size: usize,
    /// Initial learning rate
    pub learning_rate: f64,  // 1e-3
    /// Learning rate schedule: step decay at these epochs
    pub lr_decay_epochs: Vec<usize>,  // [40, 80]
    /// Learning rate decay factor
    pub lr_decay_factor: f64,  // 0.1
    /// Loss component weights
    pub loss_weights: LossWeights,
    /// Optimizer (Adam)
    pub optimizer: OptimizerConfig,
    /// Validation subject IDs (MM-Fi: 33-40)
    pub val_subjects: Vec<u32>,
    /// Random seed for reproducibility
    pub seed: u64,
    /// Enable MERIDIAN domain-adversarial training
    pub meridian_enabled: bool,
    /// Enable AETHER contrastive learning
    pub aether_enabled: bool,
}

pub struct LossWeights {
    /// Keypoint heatmap MSE weight
    pub keypoint_mse: f32,      // 1.0
    /// DensePose body part cross-entropy weight
    pub densepose_part_ce: f32, // 0.5
    /// DensePose UV Smooth L1 weight
    pub uv_smooth_l1: f32,     // 0.5
    /// Teacher-student transfer MSE weight
    pub transfer_mse: f32,     // 0.2
    /// AETHER contrastive loss weight (ADR-024)
    pub contrastive: f32,      // 0.1
    /// MERIDIAN domain adversarial weight (ADR-027)
    pub domain_adversarial: f32, // annealed 0.0 -> 1.0
}
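Taken together, the weights above combine the components into one scalar by a plain weighted sum. A sketch with an illustrative `Components` struct holding already-computed per-component values:

```rust
/// Per-component loss values (assumed already computed by the loss heads).
struct Components {
    kp_mse: f32,
    part_ce: f32,
    uv_l1: f32,
    transfer: f32,
}

/// Composite loss using the default weights listed in LossWeights.
fn composite_loss(c: &Components) -> f32 {
    1.0 * c.kp_mse + 0.5 * c.part_ce + 0.5 * c.uv_l1 + 0.2 * c.transfer
}
```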

Checkpoint

pub struct Checkpoint {
    /// Epoch at which this checkpoint was saved
    pub epoch: usize,
    /// SHA-256 hash of serialized model weights
    pub weights_hash: String,
    /// Validation PCK@0.2 at this epoch
    pub validation_pck: f64,
    /// Validation OKS at this epoch
    pub validation_oks: f64,
    /// File path to saved weights
    pub path: PathBuf,
    /// Timestamp
    pub created_at: u64,
}

ProofResult

pub struct ProofResult {
    /// Seed used for model initialization
    pub model_seed: u64,  // MODEL_SEED = 0
    /// Seed used for proof data generation
    pub proof_seed: u64,  // PROOF_SEED = 42
    /// Number of training steps in proof
    pub steps: usize,     // 50
    /// Whether loss decreased monotonically
    pub loss_decreased: bool,
    /// Whether final weights hash matches stored expected hash
    pub hash_matches: bool,
    /// The computed SHA-256 hash
    pub computed_hash: String,
    /// The expected SHA-256 hash (from file)
    pub expected_hash: String,
}
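The pass/fail decision that ProofResult encodes can be sketched as: the per-step loss trajectory must decrease monotonically, and the recomputed weight hash must equal the stored one. Function and parameter names here are illustrative:

```rust
/// A proof passes only when both conditions hold.
fn proof_passes(step_losses: &[f64], computed_hash: &str, expected_hash: &str) -> bool {
    // Strictly monotonic decrease over at least two recorded steps.
    let monotone_decrease =
        step_losses.len() >= 2 && step_losses.windows(2).all(|w| w[1] < w[0]);
    monotone_decrease && computed_hash == expected_hash
}
```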

LoraAdapter

pub struct LoraAdapter {
    /// Low-rank decomposition rank
    pub rank: usize,  // 4
    /// LoRA alpha scaling factor
    pub alpha: f32,   // 1.0
    /// Per-layer weight matrices (A and B for each adapted layer)
    pub weights: Vec<LoraLayerWeights>,
    /// Source domain this adapter was calibrated for
    pub source_domain: DomainLabel,
    /// Calibration duration in seconds
    pub calibration_duration_secs: f32,
    /// Number of calibration frames used
    pub calibration_frames: usize,
}

pub struct LoraLayerWeights {
    /// Layer name in the model
    pub layer_name: String,
    /// Down-projection: [d_model x rank]
    pub a: Vec<f32>,
    /// Up-projection: [rank x d_model]
    pub b: Vec<f32>,
}
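For illustration, applying an adapter to one activation vector follows the usual LoRA form delta = (alpha / rank) * B(Ax). In this sketch A acts as [rank x d_model] (down) and B as [d_model x rank] (up) on column vectors, independent of the storage layout noted in LoraLayerWeights; names are illustrative:

```rust
/// Row-major matrix-vector product.
fn matvec(m: &[f32], rows: usize, cols: usize, x: &[f32]) -> Vec<f32> {
    (0..rows)
        .map(|r| (0..cols).map(|c| m[r * cols + c] * x[c]).sum::<f32>())
        .collect()
}

/// LoRA delta for one layer: (alpha / rank) * B (A x).
fn lora_delta(a: &[f32], b: &[f32], rank: usize, d_model: usize, alpha: f32, x: &[f32]) -> Vec<f32> {
    let ax = matvec(a, rank, d_model, x); // down-project to [rank]
    let bax = matvec(b, d_model, rank, &ax); // up-project back to [d_model]
    let scale = alpha / rank as f32;
    bax.iter().map(|v| v * scale).collect()
}
```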

Domain Events

Dataset Events

pub enum DatasetEvent {
    /// Dataset loaded and validated
    DatasetLoaded {
        dataset_type: DatasetType,
        n_samples: usize,
        n_subjects: u32,
        source_subcarriers: usize,
        timestamp: u64,
    },

    /// Subcarrier interpolation completed for a dataset
    SubcarrierInterpolationComplete {
        dataset_type: DatasetType,
        source_count: usize,
        target_count: usize,
        method: InterpolationMethod,
        timestamp: u64,
    },

    /// Teacher pseudo-labels generated for a batch
    PseudoLabelsGenerated {
        n_samples: usize,
        n_with_uv: usize,
        timestamp: u64,
    },
}

pub enum DatasetType {
    MmFi,
    WiPose,
    Synthetic,
}

pub enum InterpolationMethod {
    /// ruvector-solver NeumannSolver sparse least-squares
    SparseNeumannSolver,
    /// Fallback linear interpolation
    LinearInterpolation,
    /// Wi-Pose zero-padding
    ZeroPad,
}

Training Events

pub enum TrainingEvent {
    /// One epoch of training completed
    EpochCompleted {
        epoch: usize,
        train_loss: f64,
        val_pck: f64,
        val_oks: f64,
        val_mpjpe_mm: f64,
        learning_rate: f64,
        grl_lambda: f32,
        timestamp: u64,
    },

    /// New best checkpoint saved
    CheckpointSaved {
        epoch: usize,
        weights_hash: String,
        validation_pck: f64,
        path: String,
        timestamp: u64,
    },

    /// Deterministic proof verification completed
    ProofVerified {
        model_seed: u64,
        proof_seed: u64,
        steps: usize,
        loss_decreased: bool,
        hash_matches: bool,
        timestamp: u64,
    },

    /// Training run completed or failed
    TrainingRunFinished {
        run_id: String,
        status: RunStatus,
        total_epochs: usize,
        best_pck: f64,
        best_oks: f64,
        timestamp: u64,
    },
}

Embedding Events

pub enum EmbeddingEvent {
    /// New AETHER embedding indexed
    EmbeddingIndexed {
        embedding_id: String,
        fingerprint_type: FingerprintType,
        nearest_neighbor_distance: f32,
        index_size: usize,
        timestamp: u64,
    },

    /// Hard negative pair discovered during mining
    HardNegativeFound {
        anchor_id: String,
        negative_id: String,
        similarity: f32,
        timestamp: u64,
    },

    /// Domain adaptation completed for a target environment
    DomainAdaptationComplete {
        source_domain: String,
        target_domain: String,
        pck_before: f64,
        pck_after: f64,
        adaptation_method: String,
        timestamp: u64,
    },

    /// LoRA adapter generated via rapid calibration
    LoraAdapterGenerated {
        domain: String,
        rank: usize,
        calibration_frames: usize,
        calibration_seconds: f32,
        timestamp: u64,
    },
}

Invariants

Dataset Management

  • MM-Fi samples must be interpolated from 114 to 56 subcarriers before use in training
  • Wi-Pose samples must be zero-padded from 30 to 56 subcarriers before use in training
  • Wi-Pose keypoints must be mapped from 18 (AlphaPose) to 17 (COCO) by dropping neck index 1
  • All CSI amplitudes must be finite and non-negative after loading
  • Phase values must be in [-pi, pi] after sanitization
  • Validation subjects (MM-Fi: 33-40) must never appear in the training split
  • CompressedCsiBuffer must preserve signal fidelity within quantization error bounds (hot: <1% error)
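The phase-range invariant is easy to enforce with a wrap helper; a sketch with an illustrative name:

```rust
/// Wrap an arbitrary phase into [-pi, pi], as required after sanitization.
fn wrap_phase(theta: f32) -> f32 {
    use std::f32::consts::PI;
    let two_pi = 2.0 * PI;
    let mut t = theta % two_pi; // now in (-2*pi, 2*pi)
    if t > PI {
        t -= two_pi;
    } else if t < -PI {
        t += two_pi;
    }
    t
}
```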

Model Architecture

  • csi_embed input dimension must equal the canonical 56 subcarriers
  • keypoint_queries must have exactly 17 entries (one per COCO keypoint)
  • attn_mincut seq_len must equal n_antenna_pairs * n_time_frames
  • GNN adjacency matrix must encode the human skeleton topology (17 nodes, 16 edges)
  • Spatial attention decoder must preserve spatial resolution (no information loss in reshape)

Training Orchestration

  • TrainingRun must have at least 1 dataset loaded before start() is called
  • Proof verification requires fixed seeds: MODEL_SEED=0, PROOF_SEED=42
  • Proof verification uses exactly 50 training steps on deterministic SyntheticDataset
  • Loss must decrease over proof steps (otherwise proof fails)
  • SHA-256 hash of final weights must match stored expected hash (otherwise proof fails)
  • best_checkpoint is updated if and only if current val_pck > all previous val_pck values
  • Learning rate decays by factor 0.1 at epochs 40 and 80 (step schedule)
  • Hungarian assignment for static single-frame matching must use the deterministic implementation (not DynamicMinCut) during proof verification
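The proof-verification gate above can be sketched as a single predicate. Hashing is abstracted as a caller-supplied closure (the real pipeline uses SHA-256 over the final weights), and "loss must decrease" is simplified here to final-less-than-initial:

```rust
/// Minimal sketch of the deterministic proof gate (hypothetical signature).
fn verify_proof<F>(
    step_losses: &[f64],
    final_weights: &[f32],
    expected_hash: &str,
    hash_fn: F,
) -> Result<(), String>
where
    F: Fn(&[f32]) -> String,
{
    // Invariant: exactly 50 steps on the deterministic SyntheticDataset.
    if step_losses.len() != 50 {
        return Err("proof requires exactly 50 training steps".into());
    }
    // Invariant: loss must decrease over proof steps (simplified check).
    if step_losses[49] >= step_losses[0] {
        return Err("loss did not decrease over proof steps".into());
    }
    // Invariant: weight hash must match the stored expected hash.
    if hash_fn(final_weights) != expected_hash {
        return Err("weight hash mismatch: training was not deterministic".into());
    }
    Ok(())
}
```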

Embedding & Transfer

  • AETHER embeddings must be L2-normalized (unit norm) before indexing in HNSW
  • InfoNCE temperature must be > 0 (typically 0.07)
  • HNSW index ef_search must be >= k for k-NN queries
  • MERIDIAN GRL lambda must anneal from 0.0 to 1.0 over the first 20 epochs using the schedule: lambda(p) = 2 / (1 + exp(-10 * p)) - 1, where p = epoch / 20
  • GRL lambda must not exceed 1.0 at any epoch
  • DomainFactorizer output dimensions: h_pose = [17 x 64], h_env = [32]
  • GeometryEncoder must be permutation-invariant with respect to AP ordering (DeepSets guarantee)
  • LoRA adapter rank must be <= d_model / 4 (default rank=4 for d_model=64)
  • Rapid adaptation requires at least 200 CSI frames (10 seconds at 20 Hz)
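The unit-norm invariant for AETHER embeddings amounts to an L2 projection before indexing. A standalone sketch:

```rust
/// Project an embedding onto the unit sphere before HNSW indexing.
fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    // Leave all-zero vectors untouched rather than dividing by zero.
    if norm > 0.0 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
```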

Domain Services

SubcarrierInterpolationService

Resamples CSI subcarriers from a source count to a target count using physically motivated sparse interpolation.

pub trait SubcarrierInterpolationService {
    /// Sparse interpolation via NeumannSolver (O(sqrt(n)), preferred)
    fn interpolate_sparse(
        &self,
        source: &[f32],
        source_count: usize,
        target_count: usize,
        tolerance: f64,
    ) -> Result<Vec<f32>, InterpolationError>;

    /// Linear interpolation fallback (O(n))
    fn interpolate_linear(
        &self,
        source: &[f32],
        source_count: usize,
        target_count: usize,
    ) -> Vec<f32>;

    /// Zero-pad for Wi-Pose (30 -> 56)
    fn zero_pad(
        &self,
        source: &[f32],
        target_count: usize,
    ) -> Vec<f32>;
}
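The two fallback paths of the trait are simple enough to sketch concretely; the sparse NeumannSolver path is omitted. Appending zeros at the tail here stands in for the high-frequency padding described for Wi-Pose:

```rust
/// Linear interpolation fallback: resample `source` onto `target_count`
/// evenly spaced positions (O(n)).
fn interpolate_linear(source: &[f32], target_count: usize) -> Vec<f32> {
    let n = source.len();
    if n == 0 || target_count == 0 {
        return Vec::new();
    }
    if n == 1 || target_count == 1 {
        return vec![source[0]; target_count];
    }
    (0..target_count)
        .map(|i| {
            let pos = i as f32 * (n - 1) as f32 / (target_count - 1) as f32;
            let lo = pos.floor() as usize;
            let hi = (lo + 1).min(n - 1);
            let frac = pos - lo as f32;
            source[lo] * (1.0 - frac) + source[hi] * frac
        })
        .collect()
}

/// Zero-pad for Wi-Pose (30 -> 56): append zeros at the tail.
fn zero_pad(source: &[f32], target_count: usize) -> Vec<f32> {
    let mut out = source.to_vec();
    out.resize(target_count.max(source.len()), 0.0);
    out
}
```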

LossComputationService

Computes the composite training loss from model outputs and ground truth.

pub trait LossComputationService {
    /// Compute composite loss with per-component breakdown
    fn compute(
        &self,
        predictions: &ModelOutput,
        targets: &GroundTruth,
        weights: &LossWeights,
    ) -> CompositeTrainingLoss;
}

pub struct CompositeTrainingLoss {
    /// Total weighted loss (scalar for backprop)
    pub total: f64,
    /// Keypoint heatmap MSE component
    pub keypoint_mse: f64,
    /// DensePose body part cross-entropy component
    pub densepose_part_ce: f64,
    /// DensePose UV Smooth L1 component
    pub uv_smooth_l1: f64,
    /// Teacher-student transfer MSE component
    pub transfer_mse: f64,
    /// AETHER contrastive loss (if enabled)
    pub contrastive: Option<f64>,
    /// MERIDIAN domain adversarial loss (if enabled)
    pub domain_adversarial: Option<f64>,
}
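How the scalar `total` is assembled from the per-component breakdown can be sketched as a weighted sum. The weight struct and its field names below are hypothetical; the actual values live in `LossWeights` in the trainer configuration:

```rust
/// Hypothetical per-component weights (illustrative names only).
struct Weights {
    keypoint: f64,
    part: f64,
    uv: f64,
    transfer: f64,
    contrastive: f64,
    domain: f64,
}

/// Weighted sum over the components of CompositeTrainingLoss; optional
/// components contribute zero when disabled.
fn total_loss(
    keypoint_mse: f64,
    densepose_part_ce: f64,
    uv_smooth_l1: f64,
    transfer_mse: f64,
    contrastive: Option<f64>,
    domain_adversarial: Option<f64>,
    w: &Weights,
) -> f64 {
    w.keypoint * keypoint_mse
        + w.part * densepose_part_ce
        + w.uv * uv_smooth_l1
        + w.transfer * transfer_mse
        + contrastive.map_or(0.0, |c| w.contrastive * c)
        + domain_adversarial.map_or(0.0, |d| w.domain * d)
}
```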

MetricEvaluationService

Evaluates model accuracy on the validation set using standard pose estimation metrics.

pub trait MetricEvaluationService {
    /// PCK@0.2: fraction of keypoints within 20% of torso diameter
    fn compute_pck(&self, predictions: &[PosePrediction], targets: &[PoseTarget], threshold: f64) -> PckScore;

    /// OKS: Object Keypoint Similarity with per-keypoint sigmas
    fn compute_oks(&self, predictions: &[PosePrediction], targets: &[PoseTarget]) -> OksScore;

    /// MPJPE: Mean Per Joint Position Error in millimeters
    fn compute_mpjpe(&self, predictions: &[PosePrediction], targets: &[PoseTarget]) -> MpjpeScore;

    /// Multi-person assignment via Hungarian (static, deterministic)
    fn assign_hungarian(&self, pred: &[PosePrediction], gt: &[PoseTarget]) -> Vec<(usize, usize)>;

    /// Multi-person assignment via DynamicMinCut (persistent, O(n^1.5 log n))
    fn assign_dynamic(&mut self, pred: &[PosePrediction], gt: &[PoseTarget]) -> Vec<(usize, usize)>;
}
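The PCK definition in the trait doc comment can be made concrete for a single person. A sketch assuming 2D keypoints as (x, y) pairs:

```rust
/// PCK@threshold for one person: a keypoint counts as correct when its
/// Euclidean error is within `threshold * torso_diameter` (threshold = 0.2
/// for the PCK@0.2 reported by this pipeline).
fn compute_pck(
    pred: &[(f32, f32)],
    gt: &[(f32, f32)],
    torso_diameter: f32,
    threshold: f32,
) -> f64 {
    assert_eq!(pred.len(), gt.len());
    let tol = threshold * torso_diameter;
    let correct = pred
        .iter()
        .zip(gt)
        .filter(|(p, g)| {
            let (dx, dy) = (p.0 - g.0, p.1 - g.1);
            (dx * dx + dy * dy).sqrt() <= tol
        })
        .count();
    correct as f64 / pred.len() as f64
}
```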

DomainAdversarialService

Manages the MERIDIAN gradient reversal training regime.

pub trait DomainAdversarialService {
    /// Compute GRL lambda for the current epoch
    fn grl_lambda(&self, epoch: usize, max_warmup_epochs: usize) -> f32;

    /// Forward pass through domain classifier with gradient reversal
    fn classify_domain(
        &self,
        h_pose: &Tensor,
        lambda: f32,
    ) -> Tensor;

    /// Compute domain adversarial loss (cross-entropy on domain logits)
    fn domain_loss(
        &self,
        domain_logits: &Tensor,
        domain_labels: &Tensor,
    ) -> f64;
}
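The `grl_lambda` schedule follows directly from the invariant stated earlier: lambda(p) = 2 / (1 + exp(-10 * p)) - 1 with p = epoch / warmup, clamped so lambda never exceeds 1.0. A free-function sketch:

```rust
/// GRL lambda annealing: 0.0 at epoch 0, approaching 1.0 by the end of
/// warmup, and held there afterwards.
fn grl_lambda(epoch: usize, max_warmup_epochs: usize) -> f32 {
    let p = (epoch as f32 / max_warmup_epochs as f32).min(1.0);
    let lambda = 2.0 / (1.0 + (-10.0 * p).exp()) - 1.0;
    lambda.min(1.0)
}
```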

Context Map

+----------------------------------------------------------------------+
|                       Training Pipeline System                       |
+----------------------------------------------------------------------+
|                                                                      |
|  +------------------+   CsiSample    +------------------+            |
|  |   Dataset        |--------------->|   Training       |            |
|  |   Management     |                |   Orchestration  |            |
|  |   Context        |                |   Context        |            |
|  +--------+---------+                +--------+---------+            |
|           |                                   |                      |
|           | Publishes                         | Publishes            |
|           | DatasetEvent                      | TrainingEvent        |
|           v                                   v                      |
|  +------------------------------------------------------+            |
|  |              Event Bus (Domain Events)               |            |
|  +------------------------------------------------------+            |
|           |                                   |                      |
|           v                                   v                      |
|  +------------------+                +------------------+            |
|  |   Model          |<---------------|   Embedding &    |            |
|  |   Architecture   |  body_part_    |   Transfer       |            |
|  |   Context        |  features      |   Context        |            |
|  +------------------+                +------------------+            |
|                                                                      |
+----------------------------------------------------------------------+
|                        UPSTREAM (Conformist)                         |
|  +--------------+  +--------------+  +--------------+                |
|  |wifi-densepose|  |wifi-densepose|  |wifi-densepose|                |
|  |   -signal    |  |     -nn      |  |    -core     |                |
|  |  (phase algs,|  |  (ONNX,      |  |  (CsiFrame,  |                |
|  |   SpotFi)    |  |   Candle)    |  |   error)     |                |
|  +--------------+  +--------------+  +--------------+                |
|                                                                      |
+----------------------------------------------------------------------+
|                        SIBLING (Partnership)                         |
|  +--------------+  +--------------+  +--------------+                |
|  |  RuvSense    |  |  MAT         |  | Sensing      |                |
|  |  (pose       |  |  (triage,    |  | Server       |                |
|  |   tracker,   |  |  survivor)   |  | (inference   |                |
|  |   field      |  |              |  |  deployment) |                |
|  |   model)     |  |              |  |              |                |
|  +--------------+  +--------------+  +--------------+                |
|                                                                      |
+----------------------------------------------------------------------+
|                    EXTERNAL (Published Language)                     |
|  +--------------+  +--------------+  +--------------+                |
|  |  MM-Fi       |  |  Wi-Pose     |  |  Detectron2  |                |
|  |  (NeurIPS    |  |  (NjtechCV   |  |  (teacher    |                |
|  |   dataset)   |  |   dataset)   |  |   labels)    |                |
|  +--------------+  +--------------+  +--------------+                |
+----------------------------------------------------------------------+

Relationship Types:

  • Dataset Management -> Training Orchestration: Customer/Supplier (Dataset produces CsiSamples; Orchestration consumes)
  • Model Architecture -> Training Orchestration: Partnership (tight bidirectional coupling: Orchestration drives forward/backward; Architecture defines the computation graph)
  • Model Architecture -> Embedding & Transfer: Customer/Supplier (Architecture produces body_part_features; Embedding consumes for contrastive/adversarial heads)
  • Embedding & Transfer -> Training Orchestration: Partnership (contrastive and adversarial losses feed into composite loss)
  • Training Pipeline -> Upstream crates: Conformist (adapts to wifi-densepose-signal, -nn, -core types)
  • Training Pipeline -> RuvSense/MAT/Server: Partnership (trained model weights flow downstream)
  • Training Pipeline -> External datasets: Anti-Corruption Layer (dataset loaders translate external formats to domain types)

Anti-Corruption Layer

MM-Fi Adapter (Dataset Management -> External MM-Fi format)

/// Translates raw MM-Fi numpy files into domain CsiSample values.
/// Handles the 114->56 subcarrier interpolation and 1TX/3RX antenna layout.
pub struct MmFiAdapter {
    /// Subcarrier interpolation service
    interpolator: Box<dyn SubcarrierInterpolationService>,
    /// Phase sanitizer from wifi-densepose-signal
    phase_sanitizer: PhaseSanitizer,
    /// Hardware normalizer for z-score normalization
    normalizer: HardwareNormalizer,
}

impl MmFiAdapter {
    /// Load a single MM-Fi sample from .npy tensors and produce a CsiSample.
    /// Steps:
    ///   1. Read amplitude [3, 114, 10] and phase [3, 114, 10]
    ///   2. Interpolate 114 -> 56 subcarriers per antenna pair
    ///   3. Sanitize phase (remove linear offset, unwrap)
    ///   4. Z-score normalize amplitude per frame
    ///   5. Read 17-keypoint COCO annotations
    pub fn adapt(&self, raw: &MmFiRawSample) -> Result<CsiSample, AdapterError>;
}

Wi-Pose Adapter (Dataset Management -> External Wi-Pose format)

/// Translates Wi-Pose .mat files into domain CsiSample values.
/// Handles 30->56 zero-padding and 18->17 keypoint mapping.
pub struct WiPoseAdapter {
    /// Zero-padding service
    interpolator: Box<dyn SubcarrierInterpolationService>,
    /// Phase sanitizer
    phase_sanitizer: PhaseSanitizer,
}

impl WiPoseAdapter {
    /// Load a Wi-Pose sample from .mat format and produce a CsiSample.
    /// Steps:
    ///   1. Read CSI [9, 30] (3x3 antenna pairs, 30 subcarriers)
    ///   2. Zero-pad 30 -> 56 subcarriers (high-frequency padding)
    ///   3. Sanitize phase
    ///   4. Map 18 AlphaPose keypoints -> 17 COCO (drop neck, index 1)
    pub fn adapt(&self, raw: &WiPoseRawSample) -> Result<CsiSample, AdapterError>;
}
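Step 4 of the adapter (18 AlphaPose keypoints to 17 COCO keypoints by dropping the neck at index 1) can be sketched directly. The (x, y, confidence) triple layout is an assumption for illustration:

```rust
/// Drop the AlphaPose neck (index 1) to obtain the 17 COCO keypoints.
/// Keypoints are assumed to be (x, y, confidence) triples.
fn alphapose18_to_coco17(kps: &[(f32, f32, f32); 18]) -> Vec<(f32, f32, f32)> {
    kps.iter()
        .enumerate()
        .filter(|(i, _)| *i != 1) // drop the neck
        .map(|(_, kp)| *kp)
        .collect()
}
```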

Teacher Model Adapter (Dataset Management -> Detectron2)

/// Adapts Detectron2 DensePose outputs into domain DensePoseLabels.
/// Used during teacher-student pseudo-label generation.
pub struct TeacherModelAdapter;

impl TeacherModelAdapter {
    /// Run Detectron2 DensePose on an RGB frame and produce pseudo-labels.
    /// Output: (part_labels [H x W], u_coords [H x W], v_coords [H x W])
    pub fn generate_pseudo_labels(
        &self,
        rgb_frame: &RgbFrame,
    ) -> Result<DensePoseLabels, AdapterError>;
}

RuVector Adapter (Model Architecture -> ruvector crates)

/// Adapts ruvector-attn-mincut API to the model's tensor format.
/// Handles the Tensor <-> Vec<f32> conversion overhead per batch element.
pub struct AttnMinCutAdapter;

impl AttnMinCutAdapter {
    /// Apply min-cut gated attention to antenna-path features.
    /// Converts [B, n_ant, n_sc] tensor to flat Vec<f32> per batch element,
    /// calls attn_mincut, and reshapes output back to tensor.
    pub fn apply(
        &self,
        features: &Tensor,
        n_antenna_paths: usize,
        n_subcarriers: usize,
        lambda: f32,
    ) -> Result<Tensor, AdapterError>;
}
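The conversion overhead the adapter doc comment mentions is a row-major flatten and reshape per batch element. A sketch with nested `Vec`s standing in for tensors (the attn_mincut call itself is elided):

```rust
/// Flatten a [n_ant, n_sc] feature slab into a row-major Vec<f32>.
fn flatten_row_major(features: &[Vec<f32>]) -> Vec<f32> {
    features.iter().flat_map(|row| row.iter().copied()).collect()
}

/// Restore the [n_ant, n_sc] shape after the flat attention call returns.
fn reshape_row_major(flat: &[f32], n_ant: usize, n_sc: usize) -> Vec<Vec<f32>> {
    assert_eq!(flat.len(), n_ant * n_sc);
    flat.chunks(n_sc).map(|c| c.to_vec()).collect()
}
```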

Repository Interfaces

/// Persists and retrieves training run state
pub trait TrainingRunRepository {
    fn save(&self, run: &TrainingRun) -> Result<(), RepositoryError>;
    fn find_by_id(&self, id: &TrainingRunId) -> Result<Option<TrainingRun>, RepositoryError>;
    fn find_latest(&self) -> Result<Option<TrainingRun>, RepositoryError>;
    fn list_completed(&self) -> Result<Vec<TrainingRun>, RepositoryError>;
}

/// Persists model checkpoints
pub trait CheckpointRepository {
    fn save(&self, checkpoint: &Checkpoint) -> Result<(), RepositoryError>;
    fn find_best(&self, run_id: &TrainingRunId) -> Result<Option<Checkpoint>, RepositoryError>;
    fn find_by_epoch(&self, run_id: &TrainingRunId, epoch: usize) -> Result<Option<Checkpoint>, RepositoryError>;
    fn list_all(&self, run_id: &TrainingRunId) -> Result<Vec<Checkpoint>, RepositoryError>;
}

/// Persists AETHER embedding index
pub trait EmbeddingRepository {
    fn save_index(&self, index: &EmbeddingIndex) -> Result<(), RepositoryError>;
    fn load_index(&self) -> Result<Option<EmbeddingIndex>, RepositoryError>;
    fn add_entry(&self, entry: &EmbeddingEntry) -> Result<(), RepositoryError>;
    fn search_knn(&self, query: &[f32], k: usize) -> Result<Vec<(EmbeddingEntry, f32)>, RepositoryError>;
}

/// Persists LoRA adapters for environment-specific fine-tuning
pub trait LoraRepository {
    fn save(&self, adapter: &LoraAdapter) -> Result<(), RepositoryError>;
    fn find_by_domain(&self, domain: &DomainLabel) -> Result<Option<LoraAdapter>, RepositoryError>;
    fn list_all(&self) -> Result<Vec<LoraAdapter>, RepositoryError>;
}

References

  • ADR-015: Public Dataset Strategy (MM-Fi, Wi-Pose, teacher-student training)
  • ADR-016: RuVector Integration (5 crate integration points in training pipeline)
  • ADR-020: Rust Migration (training pipeline in wifi-densepose-train crate)
  • ADR-024: AETHER Contrastive CSI Embeddings (128-dim fingerprints, InfoNCE, HNSW)
  • ADR-027: MERIDIAN Cross-Environment Domain Generalization (GRL, FiLM, LoRA)
  • Yang et al., "MM-Fi: Multi-Modal Non-Intrusive 4D Human Dataset" (NeurIPS 2023)
  • NjtechCVLab, "Wi-Pose Dataset" (CSI-Former, MDPI Entropy 2023)
  • Geng et al., "DensePose From WiFi" (CMU, arXiv:2301.00250, 2023)
  • Ganin et al., "Domain-Adversarial Training of Neural Networks" (JMLR 2016)
  • Perez et al., "FiLM: Visual Reasoning with a General Conditioning Layer" (AAAI 2018)