-
Notifications
You must be signed in to change notification settings - Fork 173
Open
Description
Colab
# @title Complete DTLN QAT Training Pipeline (L4 Optimized & Resilient)
# --- 0. Auto-Install Dependencies ---
import subprocess
import sys
def install_packages():
    """Install any missing pip packages this notebook needs.

    Each package is probed with ``__import__`` first (hyphens mapped to
    underscores for the import name) and pip is invoked only on
    ImportError, so repeated runs are fast no-ops.
    """
    packages = ["onnx"]
    for package in packages:
        try:
            __import__(package.replace("-", "_"))
        except ImportError:
            print(f"Installing {package}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])


install_packages()
# --- 1. Imports & Configuration ---
import torch
import torch.nn as nn
import torch.quantization
import torchaudio
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import glob
import tarfile
import copy
from tqdm import tqdm
from google.colab import drive
# Mount Drive (idempotent: skipped when the mount point already exists,
# e.g. after a session restart).
if not os.path.exists('/content/drive'):
    print("Mounting Google Drive...")
    drive.mount('/content/drive')

# Paths & Config
DRIVE_ROOT = '/content/drive/MyDrive/Start_Smart_ASR/dataset'
LOCAL_ROOT = '/content/dataset'  # Fast local storage
CHECKPOINT_DIR = '/content/drive/MyDrive/Start_Smart_ASR/Checkpoints'

# Hyperparameters
BATCH_SIZE = 128  # Optimized for L4
NUM_EPOCHS = 100
LEARNING_RATE = 1e-3

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# --- 2. Robust Data Extraction ---
def setup_data_split(split_name):
    """Extract the tar shards of one dataset split from Drive to local disk.

    Looks for ``*.tar`` shards under ``DRIVE_ROOT/split_name`` and extracts
    them into ``LOCAL_ROOT/split_name``. A ``.extraction_complete`` flag file
    is written only after all shards extracted successfully, so a restarted
    Colab session skips re-extraction while a partial extraction is retried.
    A missing Drive folder is reported and skipped rather than raising.
    """
    drive_path = os.path.join(DRIVE_ROOT, split_name)
    local_path = os.path.join(LOCAL_ROOT, split_name)
    flag_file = os.path.join(local_path, '.extraction_complete')

    if not os.path.exists(drive_path):
        print(f"Warning: {drive_path} does not exist. Skipping.")
        return

    os.makedirs(local_path, exist_ok=True)

    # Check for completion flag to skip re-extraction on restarts.
    if os.path.exists(flag_file):
        print(f"[{split_name}] Found completion flag. Skipping extraction.")
        return

    print(f"[{split_name}] Copying and Extracting shards from Drive...")
    shards = sorted(glob.glob(os.path.join(drive_path, '*.tar')))
    try:
        for shard in tqdm(shards, desc=f"Extracting {split_name}"):
            with tarfile.open(shard, 'r') as tar:
                # NOTE(review): extractall on an untrusted archive allows path
                # traversal; these shards are self-owned, but consider
                # tar.extractall(local_path, filter='data') on Python >= 3.12.
                tar.extractall(local_path)
        # Write the flag only upon full success.
        with open(flag_file, 'w') as f:
            f.write("done")
    except Exception as e:
        # Best-effort: log and continue so a bad shard doesn't kill the run.
        print(f"Error extracting {split_name}: {e}")


setup_data_split('train')
setup_data_split('val')
# --- 3. Optimized Dataset (No Resampler) ---
class DTLNDataset(Dataset):
    """Paired noisy/clean speech dataset for DTLN training.

    Expects ``*.noisy.flac`` / ``*.clean.flac`` file pairs (assumed to
    already be 16 kHz — no resampling is performed) inside ``data_path``.
    Each item is randomly cropped or zero-padded to ``target_len_samples``
    so batches stack to a fixed shape.
    """

    def __init__(self, data_path, target_len_samples=16000 * 10):
        # Default target length: 10 seconds at 16 kHz.
        self.data_path = data_path
        self.noisy_files = sorted(glob.glob(os.path.join(data_path, '*.noisy.flac')))
        self.target_len = target_len_samples
        if len(self.noisy_files) == 0:
            print(f"WARNING: No data found in {data_path}")

    def __len__(self):
        return len(self.noisy_files)

    def __getitem__(self, idx):
        noisy_path = self.noisy_files[idx]
        # '.noisy.flac' is 11 characters; swap the suffix to get the clean pair.
        clean_path = noisy_path[:-11] + '.clean.flac'

        # Load (assuming 16 kHz files to save CPU).
        noisy_sig, _ = torchaudio.load(noisy_path)
        clean_sig, _ = torchaudio.load(clean_path)

        # Random crop (same offset for both signals) or zero-pad at the end.
        sig_len = noisy_sig.shape[1]
        if sig_len > self.target_len:
            # +1 because torch.randint's upper bound is exclusive; without it
            # the final valid crop offset could never be selected.
            start = torch.randint(0, sig_len - self.target_len + 1, (1,)).item()
            noisy_sig = noisy_sig[:, start:start + self.target_len]
            clean_sig = clean_sig[:, start:start + self.target_len]
        elif sig_len < self.target_len:
            padding = self.target_len - sig_len
            noisy_sig = torch.nn.functional.pad(noisy_sig, (0, padding))
            clean_sig = torch.nn.functional.pad(clean_sig, (0, padding))

        return noisy_sig.squeeze(), clean_sig.squeeze()
# Loaders
num_workers = 4  # Matches typical Colab vCPU count
print(f"Using {num_workers} workers | Batch Size {BATCH_SIZE}")

train_loader = DataLoader(
    DTLNDataset(os.path.join(LOCAL_ROOT, 'train')),
    batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers,
    # pin_memory speeds host->GPU copies; persistent_workers avoids
    # re-forking worker processes every epoch.
    pin_memory=True, persistent_workers=True, prefetch_factor=4
)
val_loader = DataLoader(
    DTLNDataset(os.path.join(LOCAL_ROOT, 'val')),
    batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers,
    pin_memory=True, persistent_workers=True
)
# --- 4. QAT-Safe Model Architecture ---
class InstantLayerNormalization(nn.Module):
    """Frame-wise layer normalization over the last (feature) dimension.

    Each frame is normalized to zero mean and unit variance (biased
    variance, floored by ``epsilon``), then scaled/shifted by learnable
    per-feature ``gamma``/``beta``. Runs in FP32 — it sits before the
    QuantStub so the statistics computation is never fake-quantized.
    """

    def __init__(self, num_features, epsilon=1e-7):
        super(InstantLayerNormalization, self).__init__()
        self.epsilon = epsilon  # numerical floor inside the sqrt
        self.gamma = nn.Parameter(torch.ones(num_features))  # per-feature scale
        self.beta = nn.Parameter(torch.zeros(num_features))  # per-feature shift

    def forward(self, x):
        # FP32 normalization logic.
        mean = torch.mean(x, dim=-1, keepdim=True)
        variance = torch.var(x, dim=-1, unbiased=False, keepdim=True)
        std = torch.sqrt(variance + self.epsilon)
        outputs = (x - mean) / std
        return outputs * self.gamma + self.beta
class DTLN_Part1(nn.Module):
    """First DTLN separation core: mask estimation on the magnitude STFT.

    Quantization boundary layout: the instant layer norm runs in FP32,
    the QuantStub marks where INT8 simulation begins (LSTM -> dense ->
    sigmoid), and the DeQuantStub returns the mask to FP32.
    """

    def __init__(self, input_dim=257, hidden_dim=128):
        super(DTLN_Part1, self).__init__()
        self.norm = InstantLayerNormalization(input_dim)
        self.quant = torch.quantization.QuantStub()  # quantize AFTER norm (safe)
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim,
                            num_layers=2, batch_first=True)
        self.dense = nn.Linear(hidden_dim, input_dim)
        self.sigmoid = nn.Sigmoid()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, mag_input):
        x = self.norm(mag_input)
        x = self.quant(x)
        x, _ = self.lstm(x)
        mask = self.sigmoid(self.dense(x))  # per-bin mask values in (0, 1)
        return self.dequant(mask)
class DTLN_Part2(nn.Module):
    """Second DTLN separation core: learned-domain masking of time frames.

    Frames are projected into a learned encoder space, masked there
    (LSTM -> dense -> sigmoid under INT8 simulation), and projected back
    to the time domain. The mask multiplies the *pre-norm* encoded
    representation, then the decoder reconstructs the frame.
    """

    def __init__(self, block_len=512, hidden_dim=128, encoder_size=256):
        super(DTLN_Part2, self).__init__()
        self.encoder = nn.Linear(block_len, encoder_size, bias=False)
        self.norm = InstantLayerNormalization(encoder_size)
        self.quant = torch.quantization.QuantStub()  # quantize AFTER norm (safe)
        self.lstm = nn.LSTM(input_size=encoder_size, hidden_size=hidden_dim,
                            num_layers=2, batch_first=True)
        self.dense = nn.Linear(hidden_dim, encoder_size)
        self.sigmoid = nn.Sigmoid()
        self.decoder = nn.Linear(encoder_size, block_len, bias=False)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, frame_input):
        encoded = self.encoder(frame_input)
        encoded_norm = self.norm(encoded)
        x = self.quant(encoded_norm)
        x_lstm, _ = self.lstm(x)
        mask = self.sigmoid(self.dense(x_lstm))
        mask_float = self.dequant(mask)
        # Mask the un-normalized encoding, then decode back to frames.
        estimated = encoded * mask_float
        return self.decoder(estimated)
class DTLN_Full_Train(nn.Module):
    """End-to-end DTLN training wrapper: STFT -> part1 -> iFFT -> part2 -> OLA.

    The STFT analysis, phase reconstruction, and overlap-add synthesis run
    in FP32; only ``part1`` and ``part2`` contain the fake-quantized
    (INT8-simulated) regions. Input is a batch of waveforms ``(B, N)``;
    output is the enhanced waveform batch.
    """

    def __init__(self, part1, part2, block_len=512, block_shift=128):
        super(DTLN_Full_Train, self).__init__()
        self.part1 = part1
        self.part2 = part2
        self.block_len = block_len
        self.block_shift = block_shift
        # Registered as a buffer so it moves with .to(device) / state_dict.
        self.register_buffer('window', torch.hann_window(block_len))

    def stft(self, x):
        # center=False keeps frame timing causal-style (no padding).
        return torch.stft(x, n_fft=self.block_len, hop_length=self.block_shift,
                          win_length=self.block_len, window=self.window,
                          return_complex=True, center=False)

    def forward(self, x):
        # STFT analysis (FP32); permute once to (B, T, F) for both mag & phase.
        stft_data = self.stft(x)
        mag = torch.abs(stft_data).permute(0, 2, 1)
        phase = torch.angle(stft_data).permute(0, 2, 1)

        # Part 1 (INT8 sim): mask on log-magnitude.
        mag_norm = torch.log10(mag + 1e-7)
        mask1 = self.part1(mag_norm)

        # Intermediate (FP32): re-attach original phase, back to time frames.
        estimated_mag = mag * mask1
        est_stft = torch.complex(estimated_mag * torch.cos(phase),
                                 estimated_mag * torch.sin(phase))
        estimated_frames = torch.fft.irfft(est_stft, n=self.block_len, dim=-1)

        # Part 2 (INT8 sim): learned-domain refinement of the frames.
        predicted_frames = self.part2(estimated_frames)

        # Synthesis (FP32): window each frame, then overlap-add via fold.
        predicted_frames_windowed = predicted_frames * self.window
        b, t, l = predicted_frames_windowed.shape
        output_sig = torch.nn.functional.fold(
            predicted_frames_windowed.permute(0, 2, 1).reshape(b, l, -1),
            output_size=(1, (t - 1) * self.block_shift + self.block_len),
            kernel_size=(1, self.block_len),
            stride=(1, self.block_shift)
        )
        return output_sig.squeeze(1).squeeze(1)
# --- 5. Loss & Setup ---
class SNRLoss(nn.Module):
    """Negative SNR loss in dB (lower is better; perfect match -> very negative).

    Trims both signals to a common length, then drops the first
    ``block_len`` samples to ignore the overlap-add fade-in / algorithmic
    latency before computing signal-to-error power ratio.
    """

    def __init__(self, block_len=512):
        super(SNRLoss, self).__init__()
        self.block_len = block_len

    def forward(self, s_estimate, s_true):
        # Match lengths (model output can differ slightly from the target).
        min_len = min(s_estimate.shape[-1], s_true.shape[-1])
        s_estimate = s_estimate[..., :min_len]
        s_true = s_true[..., :min_len]

        # Latency compensation (skip the first fade-in block).
        if min_len > self.block_len:
            s_estimate = s_estimate[..., self.block_len:]
            s_true = s_true[..., self.block_len:]

        # SNR = signal power / error power; the 1e-7 terms guard against
        # division by zero and log10(0).
        snr = torch.mean(s_true ** 2, dim=-1, keepdim=True) / \
              (torch.mean((s_true - s_estimate) ** 2, dim=-1, keepdim=True) + 1e-7)
        loss = -10 * torch.log10(snr + 1e-7)
        return torch.mean(loss)
# Init Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DTLN_Full_Train(DTLN_Part1(), DTLN_Part2()).to(device)

# QAT Prep: insert observers/fake-quant modules in place.
# 'qnnpack' targets the ARM/mobile deployment backend.
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
torch.ao.quantization.prepare_qat(model, inplace=True)

criterion = SNRLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# --- 6. Resume Logic (Optional) ---
start_epoch = 0
resume_path = os.path.join(CHECKPOINT_DIR, "dtln_qat_resume.pt")
if os.path.exists(resume_path):
    print(f"Resuming from {resume_path}...")
    # map_location lets a GPU-saved checkpoint resume on a CPU-only
    # session (and vice versa) instead of crashing on deserialization.
    checkpoint = torch.load(resume_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    print(f"Resuming at Epoch {start_epoch+1}")
print(f"Starting QAT Training on {device}...")

# --- 7. Training Loop ---
best_val_loss = float('inf')
for epoch in range(start_epoch, NUM_EPOCHS):
    model.train()
    train_loss = 0.0

    # Ensure QAT observers and fake-quant are enabled for training.
    model.apply(torch.ao.quantization.enable_observer)
    model.apply(torch.ao.quantization.enable_fake_quant)

    pbar = tqdm(train_loader, desc=f"Ep {epoch+1}/{NUM_EPOCHS}")
    for noisy, clean in pbar:
        noisy, clean = noisy.to(device), clean.to(device)
        optimizer.zero_grad()
        estimate = model(noisy)
        loss = criterion(estimate, clean)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        pbar.set_postfix({'loss': f"{loss.item():.2f}dB"})
    avg_train_loss = train_loss / len(train_loader)

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for noisy, clean in tqdm(val_loader, desc="Val"):
            noisy, clean = noisy.to(device), clean.to(device)
            estimate = model(noisy)
            loss = criterion(estimate, clean)
            val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)

    print(f"Summary Ep {epoch+1}: Train {avg_train_loss:.4f} | Val {avg_val_loss:.4f}")

    # --- Robust Checkpointing ---
    # 1. Resume checkpoint: save everything needed for crash recovery.
    resume_checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_val_loss,
    }
    torch.save(resume_checkpoint, resume_path)

    # 2. Best checkpoint: clean CPU weights for later export/conversion.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        # Save a CPU copy of the state dict so it loads anywhere.
        best_state = {k: v.cpu() for k, v in model.state_dict().items()}
        torch.save(best_state, os.path.join(CHECKPOINT_DIR, "dtln_qat_best.pt"))
        print(f"--> Saved Best Model ({best_val_loss:.4f}dB)")

print("Training Complete.")
Let me know if you have anything to add.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels