Releases: nhsengland/NHSSynth
Stable VAE GAN comparison
Implementation Summary: Synthetic Data Fidelity Improvements
Date: 2026-01-15 (Updated: 2026-01-16, 2026-01-19)
Dataset: support.csv
Initial Issues: 50k+ constraint violations, 89% clipping, artificial peaks in distributions
Note: Critical bugs and posterior collapse issues were debugged, diagnosed, and documented using Claude Code (claude-sonnet-4-5).
Critical Update (2026-01-19): Posterior Collapse Fix
MAJOR FIX: Resolved complete posterior collapse in VAE where KLD collapsed to 0, preventing the encoder from learning meaningful latent representations.
Problem Identified
After initial fixes, VAE training showed catastrophic posterior collapse:
- KLD collapsed from 348 → 0 in first 10 epochs
- Encoder learned to output z ≈ N(0,1) (posterior = prior, no information)
- Decoder ignored latent codes - reconstructed from random noise
- Result: x11 had extreme variance collapse, x14 lost multimodal structure
Root Cause
Even with KL annealing (beta: 0.0 → 1.0), the model could achieve low loss without using the latent space:
- The encoder output an uninformative z ≈ N(0,1), minimizing the KLD
- The decoder learned to reconstruct from random z ~ N(0,1)
- Neither network had any incentive to encode meaningful information
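This failure mode follows directly from the closed-form Gaussian KL term in the ELBO: a posterior identical to the prior contributes exactly zero KLD, and therefore zero pressure on the encoder. A minimal numeric sketch (plain Python, independent of the repository code):

```python
import math

def gaussian_kld_per_dim(mu: float, logvar: float) -> float:
    """KL( N(mu, exp(logvar)) || N(0, 1) ) for a single latent dimension."""
    return 0.5 * (mu**2 + math.exp(logvar) - logvar - 1.0)

# Collapsed posterior: q(z|x) = N(0,1) = prior, so KLD is 0 and z carries no information
print(gaussian_kld_per_dim(0.0, 0.0))  # → 0.0

# Informative posterior (shifted mean, shrunk variance) pays a positive KLD cost
print(gaussian_kld_per_dim(1.5, math.log(0.25)))
```

Minimizing loss therefore pulls every dimension toward the first case unless something (like free bits, below) makes that unprofitable.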
Solution: Free Bits
Implemented free bits constraint that forces encoder to use latent capacity:
- Only penalize KLD above a threshold per latent dimension
- Default: 2.0 nats per dimension (e.g., 256 total for 128D latent)
- Prevents encoder from collapsing to uninformative prior
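The effect of the threshold can be sketched in a few lines of NumPy (a standalone illustration, not the vae.py implementation, which operates on torch tensors):

```python
import numpy as np

FREE_BITS = 2.0  # nats per latent dimension, the default described above

def apply_free_bits(kld_per_dim: np.ndarray, free_bits: float = FREE_BITS) -> np.ndarray:
    """Floor each dimension's KLD at `free_bits` before it enters the loss."""
    return np.maximum(kld_per_dim, free_bits)

kld = np.array([0.0, 0.5, 3.0, 7.2])  # two collapsed dims, two active dims
floored = apply_free_bits(kld)        # → [2.0, 2.0, 3.0, 7.2]
```

Below the threshold the loss is constant in the collapsed dimensions (the max has zero gradient there), so gradient descent gains nothing by pushing their KLD further toward zero; only genuinely informative dimensions are penalized.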
Implementation Details (vae.py)
Lines 391-405: Added free bits to loss function
```python
# Clamp logsigma_z for numerical stability
logsigma_z = torch.clamp(logsigma_z, min=-10, max=2)

# Compute per-dimension KLD for free bits
kld_per_dim = torch.distributions.kl_divergence(q, p)

# Apply free bits: only penalize KLD above threshold
free_bits = getattr(self, '_free_bits', 0.0)
if free_bits > 0:
    kld_per_dim = torch.maximum(kld_per_dim, torch.tensor(free_bits))
```
Lines 466-510: Updated train() signature with free_bits parameter
```python
def train(
    self,
    num_epochs: int = 100,
    patience: int = 5,
    displayed_metrics: list[str] = ["ELBO"],
    notebook_run: bool = False,
    beta_start: float = 0.0,
    beta_end: float = 1.0,
    beta_anneal_epochs: int = None,  # Default: 50% of num_epochs
    free_bits: float = 2.0,  # NEW: Force latent capacity usage
) -> tuple[int, dict[str, list[float]]]:
```
CRITICAL FINDING: Annealing schedule affects component selection quality:
- 50% annealing (beta=1.0 by epoch 100 of 200): 100 epochs trained with the full KL penalty → GOOD
- 90% annealing (beta=1.0 by epoch 180 of 200): only 20 epochs with the full KL penalty → POOR
- Decoder needs sufficient time with full KL penalty to learn correct component distributions
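Assuming the schedule is linear (which the beta_start/beta_end/beta_anneal_epochs parameters suggest, though the exact shape in vae.py may differ), the trade-off can be sketched as:

```python
def beta_at_epoch(epoch: int, num_epochs: int, beta_start: float = 0.0,
                  beta_end: float = 1.0, beta_anneal_epochs: int = None) -> float:
    """Linear KL-annealing schedule; annealing length defaults to 50% of num_epochs."""
    if beta_anneal_epochs is None:
        beta_anneal_epochs = num_epochs // 2
    if epoch >= beta_anneal_epochs:
        return beta_end
    return beta_start + (beta_end - beta_start) * epoch / beta_anneal_epochs

# 200-epoch run with the 50% default: beta reaches 1.0 at epoch 100,
# leaving 100 epochs at the full KL penalty (the GOOD regime above)
full_penalty_epochs = sum(1 for e in range(200) if beta_at_epoch(e, 200) == 1.0)
```

Stretching the annealing window (e.g. to 90%) shrinks `full_penalty_epochs` and starves the decoder of training time under the final objective.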
Results After Free Bits
Training metrics:
- KLD stabilized at ~500 (healthy level, was 0.00)
- No posterior collapse throughout 200 epochs
- Beta annealing worked correctly
- Weighted KLD tracked properly
Distribution quality:
- ✓ x13: Excellent match (right-skewed preserved)
- ✓ x14: Significantly improved (trimodal structure captured)
- ✓ x7, x8: Excellent smooth distributions
- ✓ x6: Perfect (categorical)
- ⚠️ x12: Improved but still has component selection bias (see Known Issues)
Critical Update (2026-01-16)
IMPORTANT: After initial implementation, we discovered several critical bugs that caused systematic mean shifts in generated data. These have been fixed as detailed below.
Issues Found & Fixed
- Kurtosis Detection Bug (vae.py): The kurtosis attribute lookup checked the wrong object in the transformer hierarchy
- Duplicate Component Temperature (continuous.py): The component temperature was applied twice (8x total instead of 2x)
- Clipping Bounds (continuous.py): Percentile-based bounds were too aggressive, clipping legitimate tail values
- Constraint Validation: Incorrect constraints can cause massive distribution corruption
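To illustrate the duplicate-temperature bug: dividing logits by a temperature twice multiplies the effective temperature (2.0 × 4.0 = 8.0), flattening the component-selection softmax toward uniform and over-sampling low-weight components relative to their fitted frequencies (which, on this dataset, biased selection toward higher-mean components). A sketch with made-up component weights, not values from support.csv:

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    exp = np.exp(logits - logits.max())  # shift for numerical stability
    return exp / exp.sum()

# Hypothetical fitted GMM component weights
weights = np.array([0.70, 0.25, 0.05])
logits = np.log(weights)

intended = softmax(logits / 2.0)     # the single 2.0x temperature kept in vae.py
buggy = softmax(logits / 2.0 / 4.0)  # duplicate division: effective temperature 8.0

# Under the bug, the rare third component's selection probability balloons
# and the distribution is much closer to uniform than the fitted weights.
```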
All fixes are documented in the updated file modifications below.
Files Modified
1. src/nhssynth/modules/dataloader/transformers/continuous.py
Line 38: Increased n_components from 5 to 10
```python
n_components: int = 10  # Was 5
```
Line 53: Changed weight_concentration_prior from 1.0 to 1e-3
```python
weight_concentration_prior=1e-3,  # Was 1.0
```
Line 59: Reduced std_multiplier from 3 to 1
```python
self._std_multiplier = 1  # Was 4 → 3 → 1
```
Lines 149-164: Added kurtosis detection and component diagnostics
```python
# Diagnostic: Show effective number of components and their characteristics
weights = self._transformer.weights_
effective_components = (weights > 0.01).sum()
tqdm.write(f"[{col_name}] BGM fitted {effective_components}/{len(weights)} components")
tqdm.write(f"[{col_name}] Component means: {means.round(2)}")
tqdm.write(f"[{col_name}] Component stds: {stds.round(2)}")
tqdm.write(f"[{col_name}] GMM expected mean: {expected_mean:.2f}, actual: {actual_mean:.2f}")

# Calculate kurtosis to detect heavily-peaked distributions
from scipy import stats
if data.size > 10:
    self._kurtosis = float(stats.kurtosis(data.flatten(), fisher=True))
    if self._kurtosis > 5:
        tqdm.write(f"[{col_name}] High kurtosis detected: {self._kurtosis:.2f}")
else:
    self._kurtosis = 0.0
```
Lines 159-169: Fixed clipping bounds (CRITICAL BUG FIX)
```python
# FIXED: Use actual min instead of 0.5th percentile to preserve lower bounds
# FIXED: Use 99.9th percentile instead of 99.5th for upper bound
# FIXED: 0% margin on min to avoid negative values when natural bound is 0
# FIXED: 15% margin on max instead of 5% to allow proper tail extrapolation
self._data_min = float(np.min(data))  # Was: np.percentile(data, 0.5)
self._data_max = float(np.percentile(data, 99.9))  # Was: np.percentile(data, 99.5)
data_range = self._data_max - self._data_min
self._safe_min = self._data_min  # Was: self._data_min - 0.05 * data_range
self._safe_max = self._data_max + 0.15 * data_range  # Was: + 0.05 * data_range
```
Lines 361-364: Removed duplicate component temperature (CRITICAL BUG FIX)
```python
# REMOVED: component_temperature = 4.0
# REMOVED: logits = logits / component_temperature
# This was causing 8x total temperature (2.0 from VAE × 4.0 here)
# which biased component selection toward higher-mean components
```
Lines 388-396: Added component selection diagnostics
```python
# Debug: show component selection frequencies vs fitted weights
unique_k, counts_k = np.unique(k_idx, return_counts=True)
selection_freq = np.zeros(len(means))
selection_freq[unique_k] = counts_k / len(k_idx)
tqdm.write(f"[revert:{base}] Component selection frequencies: {selection_freq.round(3)}")
selected_mean = np.mean(means[k_idx])
tqdm.write(f"[revert:{base}] Mean from selected components: {selected_mean:.2f}")
```
2. src/nhssynth/modules/dataloader/metadata.py
Lines 154-162: Added datetime special handling
```python
# Datetime columns should use single Gaussian (typically smooth temporal distributions)
if self.dtype.kind == "M":
    datetime_config = self.transformer_config.copy()
    # Force exactly 1 component for smooth temporal distributions
    datetime_config['n_components'] = 1
    transformer = ClusterContinuousTransformer(**datetime_config)
    transformer = DatetimeTransformer(transformer)
```
3. src/nhssynth/modules/model/models/vae.py
Lines 229-290: Implemented adaptive temperature system
```python
# Adaptive temperature scaling based on variable characteristics
base_temperature = 3.0  # Default for smooth distributions
peaked_temperature = 1.5  # Lower for peaked distributions (high kurtosis)
datetime_boost = 5.0  # Additional boost for datetime

# Categorize continuous columns by characteristics
datetime_indices = []  # Get 15.0x (3.0 × 5.0)
peaked_indices = []  # Get 1.5x
normal_indices = []  # Get 3.0x

# Apply adaptive temperature based on kurtosis and dtype
```
Lines 253-260: Fixed kurtosis detection (CRITICAL BUG FIX)
```python
# FIXED: Check outer transformer for _kurtosis directly
# Previously checked inner _transformer (sklearn BGM) which doesn't have _kurtosis
elif hasattr(col_meta.transformer, '_kurtosis') and col_meta.transformer._kurtosis > 5:
    peaked_indices.append(idx)
    tqdm.write(f"  [{base_name}] idx={idx} → peaked (kurtosis={col_meta.transformer._kurtosis:.2f})")
else:
    normal_indices.append(idx)
    tqdm.write(f"  [{base_name}] idx={idx} → normal")
```
Lines 313-318: Added GMM component temperature
```python
# Apply temperature to GMM component logits to encourage mixing
gmm_temperature = 2.0  # Moderate temperature to preserve multimodal structure
for gmm_idxs in gmm_component_groups:
    x_gen_[:, gmm_idxs] = x_gen[:, gmm_idxs] / gmm_temperature
# NOTE: This is the ONLY component temperature - duplicate in continuous.py was removed
```
4. src/nhssynth/modules/dataloader/metatransformer.py
Lines 624-647: Added post-generation Gaussian smoothing
```python
# Add Gaussian smoothing to continuous variables to blur GMM component peaks
smoothing_std = 0.03  # 3% of column std as smoothing noise
continuous_cols = []
for col_meta in self._metadata:
    # Skip datetime columns, categorical columns, and missingness indicators
    if (hasattr(col_meta, 'transformer') and
            col_meta.name in dataset.columns and
            not col_meta.name.endswith('_missing') and
            dataset[col_meta.name].dtype in ['float64', 'float32', 'int64', 'int32']):
        continuous_cols.append(col_meta.name)
if continuous_cols:
    for col in continuous_cols:
        col_std = dataset[col].std()
        if col_std > 0:
            ...
```
v0.10.2
Full Changelog: v0.10.1...v0.10.2
v0.10.1
Full Changelog: v0.10.0...v0.10.1
v0.10.0
Full Changelog: v0.9.1...v0.10.0
v0.9.1
v0.9.0
Full Changelog: v0.8.0...v0.9.0
v0.8.0
Full Changelog: v0.7.3...v0.8.0
v0.7.3
What's Changed
- Fix the bugs with the continuous transformer and suppress some unnecessary warnings by @HarrisonWilde in #109
Full Changelog: v0.7.2...v0.7.3
v0.7.2
What's Changed
- Improving handover of evaluations and experiments by @HarrisonWilde in #104
- Working plots for the dashboard by @HarrisonWilde in #106
Full Changelog: v0.7.1...v0.7.2
v0.7.1
What's Changed
- Small fixes and cleaning up by @HarrisonWilde in #103
Full Changelog: v0.7.0...v0.7.1