Skip to content

Pytorch version with some 'quality' updates #87

@rolyantrauts

Description

@rolyantrauts

Colab

# @title Complete DTLN QAT Training Pipeline (L4 Optimized & Resilient)

# --- 0. Auto-Install Dependencies ---
import importlib.util
import subprocess
import sys

def install_packages():
    """Install required third-party packages if they are missing.

    Uses importlib.util.find_spec to probe for the module, which — unlike
    the previous ``__import__`` approach — does not actually import (and
    execute) the package just to check availability. Missing packages are
    installed with pip through the current interpreter.
    """
    packages = ["onnx"]
    for package in packages:
        # Distribution names use '-', importable module names use '_'.
        module_name = package.replace("-", "_")
        if importlib.util.find_spec(module_name) is None:
            print(f"Installing {package}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])

install_packages()

# --- 1. Imports & Configuration ---
import torch
import torch.nn as nn
import torch.quantization
import torchaudio
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import glob
import tarfile
import copy
from tqdm import tqdm
from google.colab import drive

# Mount Drive
# Colab VMs are ephemeral; dataset shards and checkpoints live on Google
# Drive. The existence check makes re-running this cell idempotent.
if not os.path.exists('/content/drive'):
    print("Mounting Google Drive...")
    drive.mount('/content/drive')

# Paths & Config
DRIVE_ROOT = '/content/drive/MyDrive/Start_Smart_ASR/dataset'  # tar shards per split
LOCAL_ROOT = '/content/dataset' # Fast local storage
CHECKPOINT_DIR = '/content/drive/MyDrive/Start_Smart_ASR/Checkpoints'  # survives VM resets

# Hyperparameters
BATCH_SIZE = 128        # Optimized for L4
NUM_EPOCHS = 100
LEARNING_RATE = 1e-3

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# --- 2. Robust Data Extraction ---
def setup_data_split(split_name):
    """Copy and extract one dataset split's tar shards from Drive to local disk.

    Parameters
    ----------
    split_name : str
        Sub-directory under DRIVE_ROOT / LOCAL_ROOT ('train' or 'val').

    A '.extraction_complete' flag file marks a finished extraction so that
    restarted sessions skip the slow Drive copy. The flag is only written
    after every shard extracted successfully, so a mid-run crash leads to a
    clean re-extraction on the next attempt.
    """
    drive_path = os.path.join(DRIVE_ROOT, split_name)
    local_path = os.path.join(LOCAL_ROOT, split_name)
    flag_file = os.path.join(local_path, '.extraction_complete')

    if not os.path.exists(drive_path):
        print(f"Warning: {drive_path} does not exist. Skipping.")
        return

    os.makedirs(local_path, exist_ok=True)

    # Check for completion flag to skip re-extraction on restarts
    if os.path.exists(flag_file):
        print(f"[{split_name}] Found completion flag. Skipping extraction.")
        return

    print(f"[{split_name}] Copying and Extracting shards from Drive...")
    shards = sorted(glob.glob(os.path.join(drive_path, '*.tar')))

    try:
        for shard in tqdm(shards, desc=f"Extracting {split_name}"):
            with tarfile.open(shard, 'r') as tar:
                # filter='data' rejects path-traversal members and special
                # files (tar extraction vulnerability, CVE-2007-4559).
                # Older interpreters without the parameter raise TypeError;
                # fall back to the legacy behaviour there.
                try:
                    tar.extractall(local_path, filter='data')
                except TypeError:
                    tar.extractall(local_path)

        # Write flag upon success only, so failures trigger a retry.
        with open(flag_file, 'w') as f:
            f.write("done")

    except Exception as e:
        # Best-effort: report and continue; the missing flag means the
        # next run will re-attempt this split.
        print(f"Error extracting {split_name}: {e}")

setup_data_split('train')
setup_data_split('val')

# --- 3. Optimized Dataset (No Resampler) ---
class DTLNDataset(Dataset):
    """Paired noisy/clean FLAC dataset for DTLN training.

    Expects files named '<id>.noisy.flac' with a sibling '<id>.clean.flac'
    in the same directory. Audio is assumed to already be 16 kHz so the
    CPU-bound workers do no resampling.
    """

    # File-name suffixes used to pair noisy inputs with clean targets.
    NOISY_SUFFIX = '.noisy.flac'
    CLEAN_SUFFIX = '.clean.flac'

    def __init__(self, data_path, target_len_samples=16000*10):
        self.data_path = data_path
        self.noisy_files = sorted(glob.glob(os.path.join(data_path, '*' + self.NOISY_SUFFIX)))
        self.target_len = target_len_samples

        if len(self.noisy_files) == 0:
            print(f"WARNING: No data found in {data_path}")

    def __len__(self):
        return len(self.noisy_files)

    def __getitem__(self, idx):
        noisy_path = self.noisy_files[idx]
        # Derive the clean-target path by swapping the suffix (named
        # constant instead of a magic slice length).
        clean_path = noisy_path[:-len(self.NOISY_SUFFIX)] + self.CLEAN_SUFFIX

        # Load (Assuming 16kHz files to save CPU)
        noisy_sig, _ = torchaudio.load(noisy_path)
        clean_sig, _ = torchaudio.load(clean_path)

        # Random Crop / Pad to a fixed length for batching.
        sig_len = noisy_sig.shape[1]
        if sig_len > self.target_len:
            # +1 because torch.randint's upper bound is exclusive; without
            # it the final valid crop position could never be selected.
            start = torch.randint(0, sig_len - self.target_len + 1, (1,)).item()
            noisy_sig = noisy_sig[:, start:start+self.target_len]
            clean_sig = clean_sig[:, start:start+self.target_len]
        elif sig_len < self.target_len:
            padding = self.target_len - sig_len
            noisy_sig = torch.nn.functional.pad(noisy_sig, (0, padding))
            clean_sig = torch.nn.functional.pad(clean_sig, (0, padding))

        # Same random crop window for both signals keeps the pair aligned.
        return noisy_sig.squeeze(), clean_sig.squeeze()

# Loaders
# Size the worker pool from the actual machine instead of assuming 4 vCPUs;
# clamp at 4 so larger hosts are not oversubscribed with audio decoding.
num_workers = min(4, os.cpu_count() or 2)
print(f"Using {num_workers} workers | Batch Size {BATCH_SIZE}")

train_loader = DataLoader(
    DTLNDataset(os.path.join(LOCAL_ROOT, 'train')),
    batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers,
    # pin_memory speeds up host->GPU copies; persistent_workers avoids
    # re-forking workers every epoch; prefetch keeps the GPU fed.
    pin_memory=True, persistent_workers=True, prefetch_factor=4
)

val_loader = DataLoader(
    DTLNDataset(os.path.join(LOCAL_ROOT, 'val')),
    batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers,
    pin_memory=True, persistent_workers=True
)

# --- 4. QAT-Safe Model Architecture ---

class InstantLayerNormalization(nn.Module):
    """Layer normalization over the last dimension with a learnable affine
    transform (gamma scale, beta shift), evaluated in FP32."""

    def __init__(self, num_features, epsilon=1e-7):
        super(InstantLayerNormalization, self).__init__()
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # Normalize each feature vector to zero mean / unit variance
        # (biased variance), then apply the learned affine transform.
        mu = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, unbiased=False, keepdim=True)
        normalized = (x - mu) / torch.sqrt(var + self.epsilon)
        return normalized * self.gamma + self.beta

class DTLN_Part1(nn.Module):
    """First DTLN core: predicts a sigmoid mask from (log-)magnitude input.

    The instant layer norm runs in FP32; only the LSTM -> dense -> sigmoid
    path sits between the quant/dequant stubs, so QAT fake-quantizes just
    that portion.
    """

    def __init__(self, input_dim=257, hidden_dim=128):
        super(DTLN_Part1, self).__init__()
        self.norm = InstantLayerNormalization(input_dim)
        # Quantization boundary placed AFTER the FP32 norm (QAT-safe).
        self.quant = torch.quantization.QuantStub()
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim,
                            num_layers=2, batch_first=True)
        self.dense = nn.Linear(hidden_dim, input_dim)
        self.sigmoid = nn.Sigmoid()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, mag_input):
        quantized = self.quant(self.norm(mag_input))
        lstm_out, _ = self.lstm(quantized)
        mask = self.sigmoid(self.dense(lstm_out))
        return self.dequant(mask)

class DTLN_Part2(nn.Module):
    """Second DTLN core: learned encoder/decoder around a masking LSTM.

    The encoder, norm, masking multiply, and decoder stay in FP32; only the
    LSTM -> dense -> sigmoid path is fake-quantized between the stubs.
    """

    def __init__(self, block_len=512, hidden_dim=128, encoder_size=256):
        super(DTLN_Part2, self).__init__()
        self.encoder = nn.Linear(block_len, encoder_size, bias=False)
        self.norm = InstantLayerNormalization(encoder_size)
        # Quantization boundary placed AFTER the FP32 norm (QAT-safe).
        self.quant = torch.quantization.QuantStub()
        self.lstm = nn.LSTM(input_size=encoder_size, hidden_size=hidden_dim,
                            num_layers=2, batch_first=True)
        self.dense = nn.Linear(hidden_dim, encoder_size)
        self.sigmoid = nn.Sigmoid()
        self.decoder = nn.Linear(encoder_size, block_len, bias=False)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, frame_input):
        # Learned analysis transform (FP32), normalized copy for the LSTM.
        encoded = self.encoder(frame_input)
        quantized = self.quant(self.norm(encoded))
        lstm_out, _ = self.lstm(quantized)
        mask = self.dequant(self.sigmoid(self.dense(lstm_out)))
        # Mask is applied to the *unnormalized* encoded frames in FP32.
        return self.decoder(encoded * mask)

class DTLN_Full_Train(nn.Module):
    """End-to-end DTLN training wrapper.

    Pipeline: STFT -> part1 magnitude mask -> masked iFFT frames ->
    part2 frame enhancement -> windowed overlap-add synthesis.
    Only part1/part2 are fake-quantized for QAT; the STFT/iFFT and the
    overlap-add stay in FP32.
    """

    def __init__(self, part1, part2, block_len=512, block_shift=128):
        super(DTLN_Full_Train, self).__init__()
        self.part1 = part1
        self.part2 = part2
        self.block_len = block_len
        self.block_shift = block_shift
        # Buffer (not Parameter): follows .to(device) / state_dict but is
        # excluded from the optimizer.
        self.register_buffer('window', torch.hann_window(block_len))

    def stft(self, x):
        # center=False keeps frames aligned with the streaming block layout
        # (no implicit reflection padding at the signal edges).
        return torch.stft(x, n_fft=self.block_len, hop_length=self.block_shift,
                          win_length=self.block_len, window=self.window,
                          return_complex=True, center=False)

    def forward(self, x):
        # STFT (FP32); magnitude permuted to (batch, time, freq) for the LSTM.
        stft_data = self.stft(x)
        mag = torch.abs(stft_data).permute(0, 2, 1)
        phase = torch.angle(stft_data)
        
        # Part 1 (INT8 Sim): mask is predicted from the log-magnitude but
        # applied to the linear magnitude below.
        mag_norm = torch.log10(mag + 1e-7)
        mask1 = self.part1(mag_norm) 

        # Intermediate (FP32): rebuild complex spectrum with the original
        # (noisy) phase, then back to time-domain frames.
        estimated_mag = mag * mask1
        est_stft = torch.complex(estimated_mag * torch.cos(phase.permute(0, 2, 1)),
                                 estimated_mag * torch.sin(phase.permute(0, 2, 1)))
        estimated_frames = torch.fft.irfft(est_stft, n=self.block_len, dim=-1)
        
        # Part 2 (INT8 Sim): learned per-frame enhancement.
        predicted_frames = self.part2(estimated_frames) 
        
        # Synthesis (FP32): window each frame, then overlap-add via fold
        # treating frames as (1 x block_len) patches with block_shift stride.
        predicted_frames_windowed = predicted_frames * self.window
        b, t, l = predicted_frames_windowed.shape
        output_sig = torch.nn.functional.fold(
            predicted_frames_windowed.permute(0, 2, 1).reshape(b, l, -1),
            output_size=(1, (t - 1) * self.block_shift + self.block_len),
            kernel_size=(1, self.block_len),
            stride=(1, self.block_shift)
        )
        # fold emits (batch, 1, 1, samples); collapse to (batch, samples).
        return output_sig.squeeze(1).squeeze(1)

# --- 5. Loss & Setup ---

class SNRLoss(nn.Module):
    """Negative SNR in dB (lower is better), averaged over the batch.

    The first `block_len` samples are dropped before scoring to ignore the
    algorithmic fade-in latency of the block-processing front end.
    """

    def __init__(self, block_len=512):
        super(SNRLoss, self).__init__()
        self.block_len = block_len

    def forward(self, s_estimate, s_true):
        # Trim both signals to their common length.
        usable = min(s_estimate.shape[-1], s_true.shape[-1])
        s_estimate = s_estimate[..., :usable]
        s_true = s_true[..., :usable]

        # Latency compensation: skip the fade-in block when long enough.
        if usable > self.block_len:
            s_estimate = s_estimate[..., self.block_len:]
            s_true = s_true[..., self.block_len:]

        # SNR = signal power / error power; epsilons guard divide/log of 0.
        signal_power = torch.mean(s_true ** 2, dim=-1, keepdim=True)
        error_power = torch.mean((s_true - s_estimate) ** 2, dim=-1, keepdim=True)
        snr = signal_power / (error_power + 1e-7)
        return torch.mean(-10 * torch.log10(snr + 1e-7))

# Init Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DTLN_Full_Train(DTLN_Part1(), DTLN_Part2()).to(device)

# QAT Prep
# 'qnnpack' is the mobile/ARM quantization backend, matching an on-device
# deployment target. prepare_qat inserts observers and fake-quant modules
# between the Quant/DeQuant stubs, in place.
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
torch.ao.quantization.prepare_qat(model, inplace=True)

criterion = SNRLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# --- 6. Resume Logic (Optional) ---
# Restore model/optimizer state from the rolling resume checkpoint, if one
# exists, so a crashed or preempted session continues where it left off.
start_epoch = 0
resume_path = os.path.join(CHECKPOINT_DIR, "dtln_qat_resume.pt")
if os.path.exists(resume_path):
    print(f"Resuming from {resume_path}...")
    # map_location lets a checkpoint written on GPU be restored in a
    # CPU-only session (and vice versa) instead of raising on load.
    checkpoint = torch.load(resume_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']  # stored as "next epoch to run"
    print(f"Resuming at Epoch {start_epoch+1}")

print(f"Starting QAT Training on {device}...")

# --- 7. Training Loop ---
# NOTE(review): best_val_loss restarts at +inf after a resume, so the first
# post-resume epoch always overwrites 'dtln_qat_best.pt' even if it is worse
# than the pre-crash best — consider persisting best_val_loss in the resume
# checkpoint. TODO confirm intended behavior.
best_val_loss = float('inf')

for epoch in range(start_epoch, NUM_EPOCHS):
    model.train()
    train_loss = 0.0
    
    # Ensure QAT observers are active
    # (re-enabled each epoch in case an earlier phase froze them)
    model.apply(torch.ao.quantization.enable_observer)
    model.apply(torch.ao.quantization.enable_fake_quant)

    pbar = tqdm(train_loader, desc=f"Ep {epoch+1}/{NUM_EPOCHS}")
    for noisy, clean in pbar:
        noisy, clean = noisy.to(device), clean.to(device)
        
        # Standard step: forward, negative-SNR loss, backward, update.
        optimizer.zero_grad()
        estimate = model(noisy)
        loss = criterion(estimate, clean)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        pbar.set_postfix({'loss': f"{loss.item():.2f}dB"})

    avg_train_loss = train_loss / len(train_loader)

    # Validation (no_grad disables autograd; eval() switches module modes)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for noisy, clean in tqdm(val_loader, desc="Val"):
            noisy, clean = noisy.to(device), clean.to(device)
            estimate = model(noisy)
            loss = criterion(estimate, clean)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Summary Ep {epoch+1}: Train {avg_train_loss:.4f} | Val {avg_val_loss:.4f}")

    # --- Robust Checkpointing ---

    # 1. Resume Checkpoint (Save Everything for crash recovery)
    # 'epoch' holds the NEXT epoch index so resume continues correctly.
    resume_checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_val_loss,
    }
    torch.save(resume_checkpoint, resume_path)

    # 2. Best Checkpoint (Clean CPU weights for export)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        
        # Save a clean CPU copy of the state dict so downstream export
        # does not depend on CUDA being available.
        best_state = {k: v.cpu() for k, v in model.state_dict().items()}
        torch.save(best_state, os.path.join(CHECKPOINT_DIR, "dtln_qat_best.pt"))
        print(f"--> Saved Best Model ({best_val_loss:.4f}dB)")

print("Training Complete.")

Let me know if you have anything to add.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions