
Releases: nhsengland/NHSSynth

Stable VAE GAN comparison

25 Feb 18:51


Implementation Summary: Synthetic Data Fidelity Improvements

Date: 2026-01-15 (Updated: 2026-01-16, 2026-01-19)
Dataset: support.csv
Initial Issues: 50k+ constraint violations, 89% clipping, artificial peaks in distributions

Note: Critical bugs and posterior collapse issues were debugged, diagnosed, and documented using Claude Code (claude-sonnet-4-5).


Critical Update (2026-01-19): Posterior Collapse Fix

MAJOR FIX: Resolved complete posterior collapse in VAE where KLD collapsed to 0, preventing the encoder from learning meaningful latent representations.

Problem Identified

After initial fixes, VAE training showed catastrophic posterior collapse:

  • KLD collapsed from 348 → 0 in first 10 epochs
  • Encoder learned to output z ≈ N(0,1) (posterior = prior, no information)
  • Decoder ignored the latent codes, reconstructing from random noise
  • Result: x11 had extreme variance collapse, x14 lost multimodal structure

Root Cause

Even with KL annealing (beta: 0.0 → 1.0), the model could reach a low loss degenerately:

  1. The encoder outputs an uninformative z ≈ N(0,1), minimising KLD
  2. The decoder learns to reconstruct from random z ~ N(0,1)
  3. Neither network retains any incentive to encode meaningful information in z

Solution: Free Bits

Implemented free bits constraint that forces encoder to use latent capacity:

  • Only penalize KLD above a threshold per latent dimension
  • Default: 2.0 nats per dimension (2.0 × 128 = 256 nats total for a 128D latent)
  • Prevents encoder from collapsing to uninformative prior

Implementation Details (vae.py)

Lines 391-405: Added free bits to loss function

# Clamp logsigma_z for numerical stability
logsigma_z = torch.clamp(logsigma_z, min=-10, max=2)

# Posterior q(z|x) and standard-normal prior p(z)
q = torch.distributions.Normal(mu_z, torch.exp(logsigma_z))
p = torch.distributions.Normal(torch.zeros_like(mu_z), torch.ones_like(mu_z))

# Compute per-dimension KLD for free bits
kld_per_dim = torch.distributions.kl_divergence(q, p)

# Apply free bits: only penalize KLD above the threshold
free_bits = getattr(self, '_free_bits', 0.0)
if free_bits > 0:
    kld_per_dim = torch.maximum(kld_per_dim, torch.tensor(free_bits))
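The floor can be verified in isolation. Below is a minimal self-contained sketch (the helper name `free_bits_kld` is illustrative, not the repo's): a fully collapsed posterior now costs the full 256-nat floor instead of zero, removing the degenerate optimum.

```python
import torch
from torch.distributions import Normal, kl_divergence

def free_bits_kld(mu, logsigma, free_bits=2.0):
    """Per-dimension KL to N(0,1), floored at `free_bits` nats per dimension."""
    logsigma = torch.clamp(logsigma, min=-10, max=2)
    q = Normal(mu, torch.exp(logsigma))
    p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
    kld = kl_divergence(q, p)                                  # (batch, latent_dim)
    kld = torch.maximum(kld, torch.full_like(kld, free_bits))  # apply the floor
    return kld.sum(dim=-1).mean()                              # total nats per sample

# A fully collapsed posterior (q == p, so KLD == 0 everywhere) still pays
# the floor: 2.0 nats x 128 dims = 256 nats, so collapsing no longer helps
loss = free_bits_kld(torch.zeros(8, 128), torch.zeros(8, 128))
```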

Lines 466-510: Updated train() signature with free_bits parameter

def train(
    self,
    num_epochs: int = 100,
    patience: int = 5,
    displayed_metrics: list[str] = ["ELBO"],
    notebook_run: bool = False,
    beta_start: float = 0.0,
    beta_end: float = 1.0,
    beta_anneal_epochs: int | None = None,  # Default: 50% of num_epochs
    free_bits: float = 2.0,  # NEW: Force latent capacity usage
) -> tuple[int, dict[str, list[float]]]:

CRITICAL FINDING: Annealing schedule affects component selection quality:

  • 50% annealing (beta=1.0 by epoch 100 of 200): 100 epochs trained at the full KL penalty → GOOD
  • 90% annealing (beta=1.0 by epoch 180 of 200): only 20 epochs at the full KL penalty → POOR
  • Decoder needs sufficient time with full KL penalty to learn correct component distributions
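The schedule implied above can be sketched as a simple linear ramp (hypothetical helper; parameter names mirror the `train()` signature, `anneal_frac` is an assumed knob standing in for `beta_anneal_epochs`):

```python
def beta_at(epoch, num_epochs, beta_start=0.0, beta_end=1.0, anneal_frac=0.5):
    """Linear KL annealing over the first `anneal_frac` of training."""
    anneal_epochs = max(1, int(num_epochs * anneal_frac))
    if epoch >= anneal_epochs:
        return beta_end
    return beta_start + (beta_end - beta_start) * epoch / anneal_epochs

# 50% annealing of a 200-epoch run: beta hits 1.0 at epoch 100, leaving
# 100 epochs at the full KL penalty for the decoder to adapt; at 90% the
# decoder would get only the final 20 epochs
b_mid = beta_at(50, 200)    # halfway up the ramp
b_full = beta_at(100, 200)  # ramp finished
```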

Results After Free Bits

Training metrics:

  • KLD stabilized at ~500 (healthy level, was 0.00)
  • No posterior collapse throughout 200 epochs
  • Beta annealing worked correctly
  • Weighted KLD tracked properly

Distribution quality:

  • x13: Excellent match (right-skewed preserved)
  • x14: Significantly improved (trimodal structure captured)
  • x7, x8: Excellent smooth distributions
  • x6: Perfect (categorical)
  • ⚠️ x12: Improved but still has component selection bias (see Known Issues)

Critical Update (2026-01-16)

IMPORTANT: After initial implementation, we discovered several critical bugs that caused systematic mean shifts in generated data. These have been fixed as detailed below.

Issues Found & Fixed

  1. Kurtosis Detection Bug (vae.py): Kurtosis attribute lookup was checking wrong object hierarchy
  2. Duplicate Component Temperature (continuous.py): Component temperature applied twice (8x total instead of 2x)
  3. Clipping Bounds (continuous.py): Percentile-based bounds were too aggressive
  4. Constraint Validation: Incorrect constraints can cause massive distribution corruption

All fixes are documented in the updated file modifications below.


Files Modified

1. src/nhssynth/modules/dataloader/transformers/continuous.py

Line 38: Increased n_components from 5 to 10

n_components: int = 10  # Was 5

Line 53: Changed weight_concentration_prior from 1.0 to 1e-3

weight_concentration_prior=1e-3,  # Was 1.0

Line 59: Reduced std_multiplier from 3 to 1

self._std_multiplier = 1  # Was 4 → 3 → 1
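The first two changes can be exercised together on synthetic data to see why the low concentration prior matters (a sketch on made-up data, not the repo's fitting code; `std_multiplier` is NHSSynth-internal and omitted here):

```python
import numpy as np
from sklearn.mixture import BayesianGaussianMixture

rng = np.random.default_rng(0)
# Two well-separated clusters standing in for a multimodal column
data = np.concatenate([rng.normal(0, 1, 500), rng.normal(8, 1, 500)]).reshape(-1, 1)

# n_components=10 gives headroom; weight_concentration_prior=1e-3 lets the
# Dirichlet process prune unused components instead of splitting mass evenly
bgm = BayesianGaussianMixture(
    n_components=10, weight_concentration_prior=1e-3, max_iter=500, random_state=0
).fit(data)
effective = int((bgm.weights_ > 0.01).sum())  # components that carry real weight
```

With a concentration prior of 1.0, weight tends to spread across many components; at 1e-3 the fit collapses onto roughly as many components as the data supports.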

Lines 149-164: Added kurtosis detection and component diagnostics

# Diagnostic: Show effective number of components and their characteristics
# (means, stds, expected_mean, actual_mean are derived from the fitted BGM;
#  their computation is elided in this excerpt)
weights = self._transformer.weights_
effective_components = (weights > 0.01).sum()
tqdm.write(f"[{col_name}] BGM fitted {effective_components}/{len(weights)} components")
tqdm.write(f"[{col_name}] Component means: {means.round(2)}")
tqdm.write(f"[{col_name}] Component stds: {stds.round(2)}")
tqdm.write(f"[{col_name}] GMM expected mean: {expected_mean:.2f}, actual: {actual_mean:.2f}")

# Calculate kurtosis to detect heavily-peaked distributions
from scipy import stats
if data.size > 10:
    self._kurtosis = float(stats.kurtosis(data.flatten(), fisher=True))
    if self._kurtosis > 5:
        tqdm.write(f"[{col_name}] High kurtosis detected: {self._kurtosis:.2f}")
else:
    self._kurtosis = 0.0
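The kurtosis > 5 threshold separates smooth columns from sharply peaked ones. A quick standalone check (synthetic columns, assumed distributions) shows a Gaussian sitting near 0 while a heavy-tailed column blows past the threshold:

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(42)
smooth_col = rng.normal(size=5_000)          # Fisher kurtosis ~ 0 for a Gaussian
peaked_col = rng.laplace(size=5_000) ** 3    # sharply peaked, very heavy tails

k_smooth = float(stats.kurtosis(smooth_col, fisher=True))
k_peaked = float(stats.kurtosis(peaked_col, fisher=True))
# Only the peaked column crosses the > 5 threshold used in the excerpt above
```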

Lines 159-169: Fixed clipping bounds (CRITICAL BUG FIX)

# FIXED: Use actual min instead of 0.5th percentile to preserve lower bounds
# FIXED: Use 99.9th percentile instead of 99.5th for upper bound
# FIXED: 0% margin on min to avoid negative values when natural bound is 0
# FIXED: 15% margin on max instead of 5% to allow proper tail extrapolation
self._data_min = float(np.min(data))  # Was: np.percentile(data, 0.5)
self._data_max = float(np.percentile(data, 99.9))  # Was: np.percentile(data, 99.5)
data_range = self._data_max - self._data_min
self._safe_min = self._data_min  # Was: self._data_min - 0.05 * data_range
self._safe_max = self._data_max + 0.15 * data_range  # Was: + 0.05 * data_range
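The rationale for the asymmetric margins can be seen on a hypothetical non-negative, right-skewed column (synthetic data; the column is invented for illustration): the exact min keeps a natural 0 bound at 0, while the 15% upper margin leaves room for tail extrapolation.

```python
import numpy as np

rng = np.random.default_rng(0)
col = rng.exponential(scale=20.0, size=10_000)  # non-negative, right-skewed

safe_min = float(np.min(col))                        # exact min: a natural 0 bound stays ~0
data_max = float(np.percentile(col, 99.9))
safe_max = data_max + 0.15 * (data_max - safe_min)   # 15% headroom for the upper tail
```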

Lines 361-364: Removed duplicate component temperature (CRITICAL BUG FIX)

# REMOVED: component_temperature = 4.0
# REMOVED: logits = logits / component_temperature
# This was causing 8x total temperature (2.0 from VAE × 4.0 here)
# which biased component selection toward higher-mean components

Lines 388-396: Added component selection diagnostics

# Debug: show component selection frequencies vs fitted weights
unique_k, counts_k = np.unique(k_idx, return_counts=True)
selection_freq = np.zeros(len(means))
selection_freq[unique_k] = counts_k / len(k_idx)
tqdm.write(f"[revert:{base}] Component selection frequencies: {selection_freq.round(3)}")
selected_mean = np.mean(means[k_idx])
tqdm.write(f"[revert:{base}] Mean from selected components: {selected_mean:.2f}")
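What this diagnostic catches can be simulated directly (synthetic picks; in the real code `k_idx` comes from the decoder's temperature-scaled component logits):

```python
import numpy as np

rng = np.random.default_rng(1)
means = np.array([0.0, 5.0, 10.0])
fitted_weights = np.array([0.5, 0.3, 0.2])
# Simulated per-sample component picks drawn at the fitted weights
k_idx = rng.choice(len(means), size=1_000, p=fitted_weights)

unique_k, counts_k = np.unique(k_idx, return_counts=True)
selection_freq = np.zeros(len(means))
selection_freq[unique_k] = counts_k / len(k_idx)
selected_mean = float(np.mean(means[k_idx]))
# Any drift of selection_freq away from fitted_weights shifts selected_mean,
# which is exactly the systematic mean shift these diagnostics are hunting
```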

2. src/nhssynth/modules/dataloader/metadata.py

Lines 154-162: Added datetime special handling

# Datetime columns should use single Gaussian (typically smooth temporal distributions)
if self.dtype.kind == "M":
    datetime_config = self.transformer_config.copy()
    # Force exactly 1 component for smooth temporal distributions
    datetime_config['n_components'] = 1
    transformer = ClusterContinuousTransformer(**datetime_config)
    transformer = DatetimeTransformer(transformer)

3. src/nhssynth/modules/model/models/vae.py

Lines 229-290: Implemented adaptive temperature system

# Adaptive temperature scaling based on variable characteristics
base_temperature = 3.0  # Default for smooth distributions
peaked_temperature = 1.5  # Lower for peaked distributions (high kurtosis)
datetime_boost = 5.0  # Additional boost for datetime

# Categorize continuous columns by characteristics
datetime_indices = []  # Get 15.0x (3.0 × 5.0)
peaked_indices = []    # Get 1.5x
normal_indices = []    # Get 3.0x

# Apply adaptive temperature based on kurtosis and dtype
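The categorisation rules above reduce to a small decision function (a hypothetical helper, not the repo's code; the repo works with index lists rather than per-column lookups):

```python
def column_temperature(is_datetime: bool, kurtosis: float) -> float:
    """Map a column's characteristics to its generation temperature."""
    base, peaked, datetime_boost = 3.0, 1.5, 5.0
    if is_datetime:
        return base * datetime_boost   # 15.0x: aggressively smooth temporal columns
    if kurtosis > 5:
        return peaked                  # 1.5x: preserve sharp peaks
    return base                        # 3.0x: default smoothing
```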

Lines 253-260: Fixed kurtosis detection (CRITICAL BUG FIX)

# FIXED: Check outer transformer for _kurtosis directly
# Previously checked inner _transformer (sklearn BGM) which doesn't have _kurtosis
elif hasattr(col_meta.transformer, '_kurtosis') and col_meta.transformer._kurtosis > 5:
    peaked_indices.append(idx)
    tqdm.write(f"  [{base_name}] idx={idx} → peaked (kurtosis={col_meta.transformer._kurtosis:.2f})")
else:
    normal_indices.append(idx)
    tqdm.write(f"  [{base_name}] idx={idx} → normal")

Lines 313-318: Added GMM component temperature

# Apply temperature to GMM component logits to encourage mixing
gmm_temperature = 2.0  # Moderate temperature to preserve multimodal structure
for gmm_idxs in gmm_component_groups:
    x_gen_[:, gmm_idxs] = x_gen[:, gmm_idxs] / gmm_temperature
# NOTE: This is the ONLY component temperature - duplicate in continuous.py was removed
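Why the compounded 8x temperature was harmful is easy to see on toy logits (hypothetical values; only the temperature arithmetic matches the fix above): dividing by 8 pushes the softmax towards uniform, so rarely-fitted components get heavily over-selected at generation time.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 0.5, -1.0])   # hypothetical GMM component logits
p_t2 = softmax(logits / 2.0)          # intended 2.0x: mild flattening
p_t8 = softmax(logits / 8.0)          # accidental 8.0x (2.0 x 4.0): near-uniform
```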

4. src/nhssynth/modules/dataloader/metatransformer.py

Lines 624-647: Added post-generation Gaussian smoothing

# Add Gaussian smoothing to continuous variables to blur GMM component peaks
smoothing_std = 0.03  # 3% of column std as smoothing noise
continuous_cols = []
for col_meta in self._metadata:
    # Skip datetime columns, categorical columns, and missingness indicators
    if (hasattr(col_meta, 'transformer') and
        col_meta.name in dataset.columns and
        not col_meta.name.endswith('_missing') and
        dataset[col_meta.name].dtype in ['float64', 'float32', 'int64', 'int32']):
        continuous_cols.append(col_meta.name)

if continuous_cols:
    for col in continuous_cols:
        col_std = dataset[col].std()
        if col_std > 0:
     ...
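The truncated smoothing step can be sketched end-to-end as follows (hypothetical helper name and toy column; the 3% noise scale and the `col_std > 0` guard come from the snippet above):

```python
import numpy as np
import pandas as pd

def smooth_continuous(df: pd.DataFrame, cols, smoothing_std: float = 0.03, seed: int = 0):
    """Add Gaussian noise at 3% of each column's std to blur GMM component peaks."""
    rng = np.random.default_rng(seed)
    out = df.copy()
    for col in cols:
        col_std = out[col].std()
        if col_std > 0:  # skip constant columns
            out[col] = out[col] + rng.normal(0.0, smoothing_std * col_std, size=len(out))
    return out

df = pd.DataFrame({"x7": np.random.default_rng(2).normal(50, 10, 1_000)})
smoothed = smooth_continuous(df, ["x7"])
```

The noise is small enough to leave column means essentially unchanged while breaking up the artificial spikes at GMM component centres.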

v0.10.2

04 Oct 11:54


Full Changelog: v0.10.1...v0.10.2

v0.10.1

28 Feb 15:51


Full Changelog: v0.10.0...v0.10.1

v0.10.0

03 Nov 15:12


Full Changelog: v0.9.1...v0.10.0

v0.9.1

19 Oct 08:08


What's Changed

Full Changelog: v0.9.0...v0.9.1

v0.9.0

18 Oct 10:43


Full Changelog: v0.8.0...v0.9.0

v0.8.0

16 Oct 09:24


Full Changelog: v0.7.3...v0.8.0

v0.7.3

04 Sep 16:34


What's Changed

  • Fix the bugs with the continuous transformer and suppress some unnecessary warnings by @HarrisonWilde in #109

Full Changelog: v0.7.2...v0.7.3

v0.7.2

04 Sep 09:51


What's Changed

Full Changelog: v0.7.1...v0.7.2

v0.7.1

30 Aug 14:58


What's Changed

Full Changelog: v0.7.0...v0.7.1