diff --git a/README.md b/README.md index aabde47..55907d3 100644 --- a/README.md +++ b/README.md @@ -80,12 +80,62 @@ run it ./runReconClass.sh ``` +## Command Line Options + +The classifier supports several command line options for training configuration: + +### Training Parameters +- `--learningRate`: Learning rate for training (default: 1e-5) +- `--batchSize`: Batch size for training (default: 1) +- `--epochs`: Number of training epochs (default: 2000) +- `--minTrainingLoss`: Minimum reduction in training loss in orders of magnitude (default: 3, set to 0 to disable the check) + +### Data Configuration +- `--trainFrameFirst`: First frame number for training data (default: 1) +- `--trainFrameLast`: Last frame number (exclusive) for training data (default: 140) +- `--validationFrameFirst`: First frame number for validation data (default: 141) +- `--validationFrameLast`: Last frame number (exclusive) for validation data (default: 150) +- `--paramFile`: Path to the parameter txt file containing gkyl input data +- `--xptCacheDir`: Path to directory for caching X-point finder outputs + +### Training Optimization +- `--use-amp`: Enable automatic mixed precision for faster training on modern GPUs +- `--amp-dtype`: Data type for mixed precision (`float16` or `bfloat16`, default: `bfloat16`) +- `--patience`: Epochs without validation improvement before early stopping (default: 15) + +### Output and Monitoring +- `--plot`: Enable creation of figures showing ground truth and model-identified X-points +- `--plotDir`: Directory where figures are written (default: `./plots`) +- `--checkPointFrequency`: Number of epochs between model checkpoints (default: 10) + +### Testing +- `--smoke-test`: Run minimal smoke test for CI (overrides other parameters for quick validation) + +### Example with Advanced Options + +For faster training with mixed precision and early stopping: + +```bash +python -u ${rcRoot}/reconClassifier/XPointMLTest.py \ +--paramFile=/path/to/params.txt \ +--xptCacheDir=/path/to/cache \ +--epochs 200 \ +--learningRate 1e-4 \ +--batchSize 16 \ +--use-amp \ +--amp-dtype bfloat16 \ +--patience 20 \ +--plot \ +--trainFrameLast 100 \ +--validationFrameLast 120 +``` + ## Resuming Development Work -The following commands should be run on `checkers` **every time you create a new shell** to resume work in the existing virtual environment. +The following commands should be run on `checkers` **every time you create a new shell** to resume work in the existing virtual environment. 
``` cd nsfCssiMlClassifier source envPyTorch.sh source pgkyl/bin/activate -``` +``` \ No newline at end of file diff --git a/XPointMLTest.py b/XPointMLTest.py index b0c2b5d..8404d5c 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -20,6 +20,9 @@ from timeit import default_timer as timer +# Import mixed precision training components +from torch.amp import autocast, GradScaler + from ci_tests import SyntheticXPointDataset, test_checkpoint_functionality def expand_xpoints_mask(binary_mask, kernel_size=9): @@ -134,32 +137,32 @@ def getPgkylData(paramFile, frameNumber, verbosity): def cachedPgkylDataExists(cacheDir, frameNumber, fieldName): if cacheDir == None: - return False + return False else: - cachedFrame = cacheDir / f"{frameNumber}_{fieldName}.npy" - return cachedFrame.exists(); + cachedFrame = cacheDir / f"{frameNumber}_{fieldName}.npy" + return cachedFrame.exists() def loadPgkylDataFromCache(cacheDir, frameNumber, fields): outFields = {} if cacheDir != None: - for name in fields.keys(): + for name in fields.keys(): if name == "fileName": - with open(cacheDir / f"{frameNumber}_{name}.txt", "r") as file: - outFields[name] = file.read().rstrip() + with open(cacheDir / f"{frameNumber}_{name}.txt", "r") as file: + outFields[name] = file.read().rstrip() else: - outFields[name] = np.load(cacheDir / f"{frameNumber}_{name}.npy") - return outFields + outFields[name] = np.load(cacheDir / f"{frameNumber}_{name}.npy") + return outFields else: - return None + return None def writePgkylDataToCache(cacheDir, frameNumber, fields): if cacheDir != None: - for name, field in fields.items(): + for name, field in fields.items(): if name == "fileName": - with open(cacheDir / f"{frameNumber}_{name}.txt", "w") as text_file: - text_file.write(f"{field}") + with open(cacheDir / f"{frameNumber}_{name}.txt", "w") as text_file: - text_file.write(f"{field}") + text_file.write(f"{field}") else: - np.save(cacheDir / f"{frameNumber}_{name}.npy",field) + np.save(cacheDir / f"{frameNumber}_{name}.npy", field) # DATASET DEFINITION class XPointDataset(Dataset): @@ -171,7 +174,7 @@ class XPointDataset(Dataset): - Returns (psiTensor, maskTensor) as a PyTorch (float) pair. """ def __init__(self, paramFile, fnumList, xptCacheDir=None, - rotateAndReflect=False, verbosity=0): + rotateAndReflect=False, verbosity=0): """ paramFile: Path to parameter file (string). fnumList: List of frames to iterate. @@ -204,11 +207,11 @@ def __init__(self, paramFile, fnumList, xptCacheDir=None, frameData = self.load(fnum) self.data.append(frameData) if rotateAndReflect: - self.data.append(rotate(frameData,90)) - self.data.append(rotate(frameData,180)) - self.data.append(rotate(frameData,270)) - self.data.append(reflect(frameData,0)) - self.data.append(reflect(frameData,1)) + self.data.append(rotate(frameData,90)) + self.data.append(rotate(frameData,180)) + self.data.append(rotate(frameData,270)) + self.data.append(reflect(frameData,0)) + self.data.append(reflect(frameData,1)) def __len__(self): return len(self.data) @@ -222,7 +225,7 @@ def load(self, fnum): # check if cache exists if self.xptCacheDir != None: if not self.xptCacheDir.is_dir(): - print(f"Xpoint cache directory {self.xptCacheDir} does not exist... exiting") + print(f"Xpoint cache directory {self.xptCacheDir} does not exist... 
exiting") sys.exit() t2 = timer() @@ -259,11 +262,17 @@ def load(self, fnum): binaryMap = expand_xpoints_mask(binaryMap, kernel_size=9) + # Normalize input features for better training stability + psi_norm = (fields["psi"] - fields["psi"].mean()) / (fields["psi"].std() + 1e-8) + bx_norm = (fields["Bx"] - fields["Bx"].mean()) / (fields["Bx"].std() + 1e-8) + by_norm = (fields["By"] - fields["By"].mean()) / (fields["By"].std() + 1e-8) + jz_norm = (fields["Jz"] - fields["Jz"].mean()) / (fields["Jz"].std() + 1e-8) + # -------------- 6) Convert to Torch Tensors -------------- - psi_torch = torch.from_numpy(fields["psi"]).float().unsqueeze(0) # [1, Nx, Ny] - bx_torch = torch.from_numpy(fields["Bx"]).float().unsqueeze(0) - by_torch = torch.from_numpy(fields["By"]).float().unsqueeze(0) - jz_torch = torch.from_numpy(fields["Jz"]).float().unsqueeze(0) + psi_torch = torch.from_numpy(psi_norm).float().unsqueeze(0) # [1, Nx, Ny] + bx_torch = torch.from_numpy(bx_norm).float().unsqueeze(0) + by_torch = torch.from_numpy(by_norm).float().unsqueeze(0) + jz_torch = torch.from_numpy(jz_norm).float().unsqueeze(0) all_torch = torch.cat((psi_torch,bx_torch,by_torch,jz_torch)) # [4, Nx, Ny] mask_torch = torch.from_numpy(binaryMap).float().unsqueeze(0) # [1, Nx, Ny] @@ -274,159 +283,169 @@ def load(self, fnum): "fnum": fnum, "rotation": 0, "reflectionAxis": -1, # no reflection - "psi": psi_torch, # shape [1, Nx, Ny] - "all": all_torch, # shape [4, Nx, Ny] - "mask": mask_torch, # shape [1, Nx, Ny] + "psi": psi_torch, # shape [1, Nx, Ny] + "all": all_torch, # Normalized for training + "mask": mask_torch, # shape [1, Nx, Ny] "x": fields["coords"][0], "y": fields["coords"][1], "filenameBase": fields["fileName"], "params": dict(self.params) # copy of the params for local plotting } +class XPointPatchDataset(Dataset): + """On‑the‑fly square crops, balancing positive / background patches.""" + def __init__(self, base_ds, patch=64, pos_ratio=0.5, retries=30): + self.base_ds = base_ds + self.patch = patch + self.pos_ratio = pos_ratio + self.retries = retries + self.rng = np.random.default_rng() + def __len__(self): + # give each full frame K random crops per epoch (K=32 for more samples) + return len(self.base_ds) * 32 + + def _crop(self, arr, top, left): + return arr[..., top:top+self.patch, left:left+self.patch] + + def __getitem__(self, _): + frame = self.base_ds[self.rng.integers(len(self.base_ds))] + H, W = frame["mask"].shape[-2:] + + # Ensure we have enough space for cropping + if H < self.patch or W < self.patch: + # Return padded version if image is too small + return { + "all": F.pad(frame["all"], (0, max(0, self.patch - W), 0, max(0, self.patch - H))), + "mask": F.pad(frame["mask"], (0, max(0, self.patch - W), 0, max(0, self.patch - H))) + } + + for attempt in range(self.retries): + y0 = self.rng.integers(0, H - self.patch + 1) + x0 = self.rng.integers(0, W - self.patch + 1) + crop_mask = self._crop(frame["mask"], y0, x0) + has_pos = crop_mask.sum() > 0 + want_pos = (attempt / self.retries) < self.pos_ratio + + if has_pos == want_pos or attempt == self.retries - 1: + return { + "all" : self._crop(frame["all"], y0, x0), + "mask": crop_mask + } + + +# Improved the U-Net architecture with residual connections +# Links to understand the residual blocks: +# https://code.likeagirl.io/u-net-vs-residual-u-net-vs-attention-u-net-vs-attention-residual-u-net-899b57c5698 +# https://notes.kvfrans.com/3-building-blocks/residual-networks.html +class ResidualBlock(nn.Module): + """Residual block with two convolutions and skip 
connection""" + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2d(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + # Skip connection if dimensions don't match + self.skip = nn.Identity() + if in_channels != out_channels: + self.skip = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1), + nn.BatchNorm2d(out_channels) + ) + + def forward(self, x): + residual = self.skip(x) + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += residual + out = self.relu(out) + + return out -# 2) U-NET ARCHITECTURE class UNet(nn.Module): """ - A simplified U-Net for binary segmentation: - in: (N, 1, H, W) ++++ BX, BY, JZ - out: (N, 1, H, W) + Improved U-Net with residual blocks and better normalization """ - def __init__(self, input_channels=1, base_channels=16): + def __init__(self, input_channels=4, base_channels=32): super().__init__() - self.enc1 = self.double_conv(input_channels, base_channels) # -> base_channels - self.enc2 = self.double_conv(base_channels, base_channels*2) # -> 2*base_channels - self.enc3 = self.double_conv(base_channels*2, base_channels*4) # -> 4*base_channels + + # Encoder + self.enc1 = ResidualBlock(input_channels, base_channels) + self.enc2 = ResidualBlock(base_channels, base_channels*2) + self.enc3 = ResidualBlock(base_channels*2, base_channels*4) + self.enc4 = ResidualBlock(base_channels*4, base_channels*8) self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.dropout = nn.Dropout2d(0.1) - self.bottleneck = self.double_conv(base_channels*4, base_channels*8) + # Bottleneck + self.bottleneck = ResidualBlock(base_channels*8, base_channels*16) # Decoder + self.up4 = nn.ConvTranspose2d(base_channels*16, base_channels*8, kernel_size=2, stride=2) + self.dec4 = ResidualBlock(base_channels*16, base_channels*8) + self.up3 = nn.ConvTranspose2d(base_channels*8, base_channels*4, kernel_size=2, stride=2) - self.dec3 = self.double_conv(base_channels*8, base_channels*4) + self.dec3 = ResidualBlock(base_channels*8, base_channels*4) self.up2 = nn.ConvTranspose2d(base_channels*4, base_channels*2, kernel_size=2, stride=2) - self.dec2 = self.double_conv(base_channels*4, base_channels*2) + self.dec2 = ResidualBlock(base_channels*4, base_channels*2) self.up1 = nn.ConvTranspose2d(base_channels*2, base_channels, kernel_size=2, stride=2) - self.dec1 = self.double_conv(base_channels*2, base_channels) + self.dec1 = ResidualBlock(base_channels*2, base_channels) self.out_conv = nn.Conv2d(base_channels, 1, kernel_size=1) - def double_conv(self, in_ch, out_ch): - return nn.Sequential( - nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), - nn.ReLU(inplace=True) - ) - def forward(self, x): # Encoder - e1 = self.enc1(x) # shape: [N, base_channels, H, W] - p1 = self.pool(e1) # half spatial dims - - e2 = self.enc2(p1) # [N, 2*base_channels, H/2, W/2] + e1 = self.enc1(x) + p1 = self.pool(e1) + + e2 = self.enc2(p1) p2 = self.pool(e2) - - e3 = self.enc3(p2) # [N, 4*base_channels, H/4, W/4] - p3 = self.pool(e3) # [N, 4*base_channels, H/8, W/8] + p2 = self.dropout(p2) + + e3 = self.enc3(p2) + p3 = self.pool(e3) + p3 = self.dropout(p3) + + e4 = self.enc4(p3) + p4 = self.pool(e4) + p4 = self.dropout(p4) # 
Bottleneck - b = self.bottleneck(p3) # [N, 8*base_channels, H/8, W/8] + b = self.bottleneck(p4) # Decoder - u3 = self.up3(b) # -> shape ~ e3 - cat3 = torch.cat([u3, e3], dim=1) # skip connection + u4 = self.up4(b) + cat4 = torch.cat([u4, e4], dim=1) + d4 = self.dec4(cat4) + + u3 = self.up3(d4) + cat3 = torch.cat([u3, e3], dim=1) d3 = self.dec3(cat3) - u2 = self.up2(d3) # -> shape ~ e2 + u2 = self.up2(d3) cat2 = torch.cat([u2, e2], dim=1) d2 = self.dec2(cat2) - u1 = self.up1(d2) # -> shape ~ e1 + u1 = self.up1(d2) cat1 = torch.cat([u1, e1], dim=1) d1 = self.dec1(cat1) out = self.out_conv(d1) - return out # We'll apply sigmoid in the loss or after - - -# TRAIN & VALIDATION UTILS -def train_one_epoch(model, loader, criterion, optimizer, device): - model.train() - running_loss = 0.0 - for batch in loader: - all, mask = batch["all"].to(device), batch["mask"].to(device) - pred = model(all) - - loss = criterion(pred, mask) - - optimizer.zero_grad() - loss.backward() - optimizer.step() - running_loss += loss.item() - return running_loss / len(loader) - -def validate_one_epoch(model, loader, criterion, device): - model.eval() - val_loss = 0.0 - with torch.no_grad(): - for batch in loader: - all, mask = batch["all"].to(device), batch["mask"].to(device) - pred = model(all) - loss = criterion(pred, mask) - val_loss += loss.item() - return val_loss / len(loader) - - -class FocalLoss(nn.Module): - def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'): - """ - Focal Loss implementation - - Parameters: - alpha (float): Weighting factor for the rare class (X-points), default=0.25 - gamma (float): Focusing parameter that reduces the loss for well-classified examples, default=2.0 - reduction (str): 'mean' or 'sum', how to reduce the loss over the batch - """ - super().__init__() - self.alpha = alpha - self.gamma = gamma - self.reduction = reduction - - def forward(self, inputs, targets): - """ - inputs: Model predictions (logits, before sigmoid), shape [N, 1, H, W] - targets: Ground truth binary masks, shape [N, 1, H, W] - """ - # Apply sigmoid to get probabilities - bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') - - # Get probabilities for positive class - probs = torch.sigmoid(inputs) - # For targets=1 (X-points), pt = p; for targets=0 (non-X-points), pt = 1-p - pt = torch.where(targets == 1, probs, 1 - probs) - - # Apply focusing parameter - focal_weight = (1 - pt) ** self.gamma - - # Apply alpha weighting: alpha for X-points, (1-alpha) for non-X-points - alpha_weight = torch.where(targets == 1, self.alpha, 1 - self.alpha) - - # Combine all factors - focal_loss = alpha_weight * focal_weight * bce_loss - - # Apply reduction - if self.reduction == 'mean': - return focal_loss.mean() - elif self.reduction == 'sum': - return focal_loss.sum() - else: - return focal_loss - + return out +# DICE LOSS class DiceLoss(nn.Module): def __init__(self, smooth=1.0, eps=1e-7): """ @@ -461,9 +480,60 @@ def forward(self, inputs, targets): # Return Dice loss (1 - Dice coefficient) return 1.0 - dice -# PLOTTING FUNCTION +# TRAIN & VALIDATION UTILS +def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp, amp_dtype): + model.train() + running_loss = 0.0 + + for batch in loader: + all_data, mask = batch["all"].to(device), batch["mask"].to(device) + + with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp): + pred = model(all_data) + loss = criterion(pred, mask) + + if not torch.isfinite(loss): + print(f"Warning: Non-finite loss detected (loss = 
{loss.item()}). Skipping batch.") + continue + + optimizer.zero_grad() + + if use_amp and scaler is not None: # float16 path + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + scaler.step(optimizer) + scaler.update() + elif use_amp: # bfloat16 path (no scaler) + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + else: # Standard float32 path + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + + running_loss += loss.item() + + return running_loss / len(loader) if len(loader) > 0 else 0.0 + +def validate_one_epoch(model, loader, criterion, device, use_amp, amp_dtype): + model.eval() + val_loss = 0.0 + with torch.no_grad(): + for batch in loader: + all_data, mask = batch["all"].to(device), batch["mask"].to(device) + + with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp): + pred = model(all_data) + loss = criterion(pred, mask) + + val_loss += loss.item() + return val_loss / len(loader) if len(loader) > 0 else 0.0 + +# PLOTTING FUNCTIONS def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, - reflectionAxis, filenameBase, interpFac, + reflectionAxis, filenameBase, interpFac, xpoint_mask=None, titleExtra="", outDir="plots", @@ -505,7 +575,7 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, 'xk' ) - # Save the figure if needed (could be removed as we save anyway) + # Save the figure if needed if saveFig: basename = os.path.basename(filenameBase) saveFilename = os.path.join( @@ -518,7 +588,7 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, plt.close() def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, filenameBase, - outDir="plots", saveFig=True): + outDir="plots", saveFig=True): """ Visualize model performance comparing predictions with ground truth: - True Positives (green) @@ -636,44 +706,50 @@ def plot_training_history(train_losses, val_losses, save_path='plots/training_hi def parseCommandLineArgs(): parser = argparse.ArgumentParser(description='ML-based reconnection classifier') parser.add_argument('--learningRate', type=float, default=1e-5, - help='specify the learning rate') + help='specify the learning rate') parser.add_argument('--batchSize', type=int, default=1, - help='specify the batch size') + help='specify the batch size') parser.add_argument('--epochs', type=int, default=2000, - help='specify the number of epochs') + help='specify the number of epochs') parser.add_argument('--trainFrameFirst', type=int, default=1, - help='specify the number of the first frame used for training') + help='specify the number of the first frame used for training') parser.add_argument('--trainFrameLast', type=int, default=140, - help='specify the number of the last frame (exclusive) used for training') + help='specify the number of the last frame (exclusive) used for training') parser.add_argument('--validationFrameFirst', type=int, default=141, - help='specify the number of the first frame used for validation') + help='specify the number of the first frame used for validation') parser.add_argument('--validationFrameLast', type=int, default=150, - help='specify the number of the last frame (exclusive) used for validation') + help='specify the number of the last frame (exclusive) used for validation') parser.add_argument('--minTrainingLoss', type=int, default=3, - help=''' - minimum reduction in training loss in 
orders of magnitude, - set to 0 to disable the check - ''') + help=''' + minimum reduction in training loss in orders of magnitude, + set to 0 to disable the check (default: 3) + ''') parser.add_argument('--checkPointFrequency', type=int, default=10, - help='number of epochs between checkpoints') + help='number of epochs between checkpoints') parser.add_argument('--paramFile', type=Path, default=None, - help=''' - specify the path to the parameter txt file, the parent - directory of that file must contain the gkyl input training data - ''') + help=''' + specify the path to the parameter txt file, the parent + directory of that file must contain the gkyl input training data + ''') parser.add_argument('--xptCacheDir', type=Path, default=None, - help=''' - specify the path to a directory that will be used to cache - the outputs of the analytic Xpoint finder - ''') + help=''' + specify the path to a directory that will be used to cache + the outputs of the analytic Xpoint finder + ''') parser.add_argument('--plot', action=argparse.BooleanOptionalAction, - help='create figures of the ground truth X-points and model identified X-points') + help='create figures of the ground truth X-points and model identified X-points') parser.add_argument('--plotDir', type=Path, default="./plots", - help='directory where figures are written') + help='directory where figures are written') + parser.add_argument('--use-amp', action='store_true', + help='use automatic mixed precision training') + parser.add_argument('--amp-dtype', type=str, default='bfloat16', + choices=['float16', 'bfloat16'], help='data type for mixed precision (bfloat16 recommended)') + parser.add_argument('--patience', type=int, default=15, + help='patience for early stopping (default: 15)') # CI TEST: Add smoke test flag parser.add_argument('--smoke-test', action='store_true', - help='Run a minimal smoke test for CI (overrides other parameters)') + help='Run a minimal smoke test for CI (overrides other parameters)') args = parser.parse_args() return args @@ -734,7 +810,7 @@ def printCommandLineArgs(args): print("}") # Function to save model checkpoint -def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpoint_dir="checkpoints"): +def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpoint_dir="checkpoints", scaler=None, best_val_loss=None): """ Save model checkpoint including model state, optimizer state, and training metrics @@ -745,6 +821,8 @@ def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpo val_loss: List of validation losses epoch: Current epoch number checkpoint_dir: Directory to save checkpoints + scaler: GradScaler instance if using AMP + best_val_loss: Best validation loss so far """ os.makedirs(checkpoint_dir, exist_ok=True) @@ -756,9 +834,14 @@ def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpo 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'train_loss': train_loss, - 'val_loss': val_loss + 'val_loss': val_loss, + 'best_val_loss': best_val_loss } + # Save scaler state if using AMP + if scaler is not None: + checkpoint['scaler_state_dict'] = scaler.state_dict() + # Save checkpoint torch.save(checkpoint, checkpoint_path) print(f"Model checkpoint saved at epoch {epoch} to {checkpoint_path}") @@ -776,7 +859,7 @@ def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpo raise e # Function to load model checkpoint -def load_model_checkpoint(model, optimizer, 
checkpoint_path): +def load_model_checkpoint(model, optimizer, checkpoint_path, scaler=None): """ Load model checkpoint @@ -784,6 +867,7 @@ def load_model_checkpoint(model, optimizer, checkpoint_path): model: The neural network model to load weights into optimizer: The optimizer to load state into checkpoint_path: Path to the checkpoint file + scaler: GradScaler instance if using AMP Returns: model: Updated model with loaded weights @@ -791,13 +875,15 @@ def load_model_checkpoint(model, optimizer, checkpoint_path): epoch: Last saved epoch number train_loss: List of training losses val_loss: List of validation losses + scaler: Updated scaler if using AMP + best_val_loss: Best validation loss from checkpoint """ if not os.path.exists(checkpoint_path): print(f"No checkpoint found at {checkpoint_path}") - return model, optimizer, 0, [], [] + return model, optimizer, 0, [], [], scaler, float('inf') print(f"Loading checkpoint from {checkpoint_path}") - checkpoint = torch.load(checkpoint_path) + checkpoint = torch.load(checkpoint_path, weights_only=False) # Need False for optimizer state model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) @@ -805,9 +891,14 @@ def load_model_checkpoint(model, optimizer, checkpoint_path): epoch = checkpoint['epoch'] train_loss = checkpoint['train_loss'] val_loss = checkpoint['val_loss'] + best_val_loss = checkpoint.get('best_val_loss', float('inf')) + + # Load scaler state if available + if scaler is not None and 'scaler_state_dict' in checkpoint: + scaler.load_state_dict(checkpoint['scaler_state_dict']) print(f"Loaded checkpoint from epoch {epoch}") - return model, optimizer, epoch, train_loss, val_loss + return model, optimizer, epoch, train_loss, val_loss, scaler, best_val_loss def main(): @@ -857,25 +948,60 @@ def main(): # Original data loading train_fnums = range(args.trainFrameFirst, args.trainFrameLast) val_fnums = range(args.validationFrameFirst, args.validationFrameLast) - + + print(f"Loading training data from frames {args.trainFrameFirst} to {args.trainFrameLast-1}") + print(f"Loading validation data from frames {args.validationFrameFirst} to {args.validationFrameLast-1}") + train_dataset = XPointDataset(args.paramFile, train_fnums, - xptCacheDir=args.xptCacheDir, rotateAndReflect=True) + xptCacheDir=args.xptCacheDir, rotateAndReflect=True) val_dataset = XPointDataset(args.paramFile, val_fnums, - xptCacheDir=args.xptCacheDir) + xptCacheDir=args.xptCacheDir) + + # Use consistent pos_ratio for both training and validation + train_crop = XPointPatchDataset(train_dataset, patch=64, pos_ratio=0.5, retries=30) + val_crop = XPointPatchDataset(val_dataset, patch=64, pos_ratio=0.5, retries=30) t1 = timer() print("time (s) to create gkyl data loader: " + str(t1-t0)) print(f"number of training frames (original + augmented): {len(train_dataset)}") print(f"number of validation frames: {len(val_dataset)}") + print(f"number of training patches per epoch: {len(train_crop)}") + print(f"number of validation patches per epoch: {len(val_crop)}") - train_loader = DataLoader(train_dataset, batch_size=args.batchSize, shuffle=False) - val_loader = DataLoader(val_dataset, batch_size=args.batchSize, shuffle=False) + train_loader = DataLoader(train_crop, batch_size=args.batchSize, shuffle=True, num_workers=0) + val_loader = DataLoader(val_crop, batch_size=args.batchSize, shuffle=False, num_workers=0) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = UNet(input_channels=4, 
base_channels=64).to(device) + print(f"Using device: {device}") + + # Use the improved model + model = UNet(input_channels=4, base_channels=32).to(device) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Total parameters: {total_params:,}") + print(f"Trainable parameters: {trainable_params:,}") criterion = DiceLoss(smooth=1.0) - optimizer = optim.Adam(model.parameters(), lr=args.learningRate) + + # Use AdamW optimizer with weight decay for better generalization + optimizer = optim.AdamW(model.parameters(), lr=args.learningRate, weight_decay=1e-5) + + # Learning rate scheduler with cosine annealing + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6) + + # --- AMP Setup (bfloat16 aware) --- + use_amp = args.use_amp and torch.cuda.is_available() + amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' and torch.cuda.is_bf16_supported() else torch.float16 + + # GradScaler is ONLY needed for float16, not bfloat16 + scaler = GradScaler(enabled=(use_amp and amp_dtype == torch.float16)) + + if use_amp: + if args.amp_dtype == 'bfloat16' and not torch.cuda.is_bf16_supported(): + print("Warning: bfloat16 not supported on this GPU. Falling back to float16.") + print(f"Using Automatic Mixed Precision (AMP) with dtype: {amp_dtype}") checkpoint_dir = "checkpoints" os.makedirs(checkpoint_dir, exist_ok=True) @@ -883,27 +1009,55 @@ def main(): start_epoch = 0 train_loss = [] val_loss = [] + best_val_loss = float('inf') if os.path.exists(latest_checkpoint_path) and not args.smoke_test: - model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint( - model, optimizer, latest_checkpoint_path - ) + model, optimizer, start_epoch, train_loss, val_loss, scaler, best_val_loss = load_model_checkpoint( + model, optimizer, latest_checkpoint_path, scaler + ) print(f"Resuming training from epoch {start_epoch+1}") + print(f"Best validation loss so far: {best_val_loss:.6f}") else: print("Starting training from scratch") t2 = timer() print("time (s) to prepare model: " + str(t2-t1)) + # Early stopping setup + patience_counter = 0 + num_epochs = args.epochs for epoch in range(start_epoch, num_epochs): - train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device)) - val_loss.append(validate_one_epoch(model, val_loader, criterion, device)) - print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} ValLoss={val_loss[-1]}") + train_loss_epoch = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler, use_amp, amp_dtype) + val_loss_epoch = validate_one_epoch(model, val_loader, criterion, device, use_amp, amp_dtype) + + train_loss.append(train_loss_epoch) + val_loss.append(val_loss_epoch) - # Save model checkpoint after each epoch - if epoch % args.checkPointFrequency == 0: - save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch+1, checkpoint_dir) + current_lr = optimizer.param_groups[0]['lr'] + print(f"[Epoch {epoch+1}/{num_epochs}] LR={current_lr:.2e} TrainLoss={train_loss[-1]:.6f} ValLoss={val_loss[-1]:.6f}") + + # Learning rate scheduling + scheduler.step() + + # Check for improvement + if val_loss[-1] < best_val_loss: + best_val_loss = val_loss[-1] + patience_counter = 0 + print(f" New best validation loss: {best_val_loss:.6f}") + # Save best model + torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_model.pt")) + else: + patience_counter += 1 + + # Save checkpoint periodically + if (epoch+1) % args.checkPointFrequency == 0: + save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch+1, checkpoint_dir, 
scaler, best_val_loss) + + # Early stopping + if patience_counter >= args.patience: + print(f"Early stopping triggered after {epoch+1} epochs (patience={args.patience})") + break plot_training_history(train_loss, val_loss) print("time (s) to train model: " + str(timer()-t2)) @@ -967,11 +1121,18 @@ def main(): else: return 0 - requiredLossDecreaseMagnitude = args.minTrainingLoss - if np.log10(abs(train_loss[0]/train_loss[-1])) < requiredLossDecreaseMagnitude: - print(f"TrainLoss reduced by less than {requiredLossDecreaseMagnitude} orders of magnitude: " - f"initial {train_loss[0]} final {train_loss[-1]} ... exiting") - return 1; + # Check training progress + if len(train_loss) > 1 and train_loss[-1] > 0 and train_loss[0] > 0: + loss_reduction = np.log10(abs(train_loss[0]/train_loss[-1])) + print(f"Training loss reduced by {loss_reduction:.2f} orders of magnitude") + if args.minTrainingLoss > 0 and loss_reduction < args.minTrainingLoss: + print(f"Warning: TrainLoss reduced by less than {args.minTrainingLoss} orders of magnitude") + + # Load best model for evaluation + best_model_path = os.path.join(checkpoint_dir, "best_model.pt") + if os.path.exists(best_model_path): + print("Loading best model for evaluation...") + model.load_state_dict(torch.load(best_model_path, weights_only=True)) # (D) Plotting after training model.eval() # switch to inference mode @@ -989,64 +1150,59 @@ def main(): t4 = timer() with torch.no_grad(): - for set in full_dataset: - for item in set: - # item is a dict with keys: fnum, psi, mask, psi_np, mask_np, x, y, tmp, params - fnum = item["fnum"] - rotation = item["rotation"] - reflectionAxis = item["reflectionAxis"] - psi_np = np.array(item["psi"])[0] - mask_gt = np.array(item["mask"])[0] - x = item["x"] - y = item["y"] - filenameBase = item["filenameBase"] - params = item["params"] - - # Get CNN prediction - all_torch = item["all"].unsqueeze(0).to(device) - pred_mask = model(all_torch) - pred_mask_np = pred_mask[0,0].cpu().numpy() - # Binarize - pred_bin = (pred_mask_np > 0.5).astype(np.float32) - - pred_prob = torch.sigmoid(pred_mask) - pred_prob_np = (pred_prob > 0.5).float().cpu().numpy() - - pred_mask_bin = (pred_prob_np > 0.5).astype(np.float32) # Thresholding at 0.5, can be fine tune - - print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:") - print(f"psi shape: {psi_np.shape}, min: {psi_np.min()}, max: {psi_np.max()}") - print(f"pred_bin shape: {pred_bin.shape}, min: {pred_bin.min()}, max: {pred_bin.max()}") - print(f" Logits - min: {pred_mask_np.min():.5f}, max: {pred_mask_np.max():.5f}, mean: {pred_mask_np.mean():.5f}") - print(f" Probabilities (after sigmoid) - min: {pred_prob_np.min():.5f}, max: {pred_prob_np.max():.5f}, mean: {pred_prob_np.mean():.5f}") - print(f" Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels") - print(f" Binary Mask (X_points) - shape: {pred_mask_bin.shape}, min: {pred_mask_bin.min()}, max: {pred_mask_bin.max()}") - - if args.plot : - # Plot GROUND TRUTH - plot_psi_contours_and_xpoints( - psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, - xpoint_mask=mask_gt, - titleExtra="GTXpoints", - outDir=outDir, - saveFig=True - ) - - # Plot CNN PREDICTIONS - plot_psi_contours_and_xpoints( - psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, - xpoint_mask=np.squeeze(pred_mask_bin), - titleExtra="CNNXpoints", - outDir=outDir, - saveFig=True - ) - - pred_prob_np_full = pred_prob.cpu().numpy() - plot_model_performance( - 
psi_np, pred_prob_np_full, mask_gt, x, y, params, fnum, filenameBase, - outDir=outDir, - saveFig=True - ) + for dataset in full_dataset: + for item in dataset: + fnum = item["fnum"] + rotation = item["rotation"] + reflectionAxis = item["reflectionAxis"] + psi_np = np.array(item["psi"])[0] + mask_gt = np.array(item["mask"])[0] + x = item["x"] + y = item["y"] + filenameBase = item["filenameBase"] + params = item["params"] + + # Get CNN prediction + all_torch = item["all"].unsqueeze(0).to(device) + + with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp): + pred_mask = model(all_torch) + pred_prob = torch.sigmoid(pred_mask) + + # Convert to float32 before numpy conversion (fixes BFloat16 error) + pred_mask_np = pred_mask[0,0].float().cpu().numpy() + pred_prob_np = pred_prob.float().cpu().numpy() + + pred_mask_bin = (pred_prob_np[0,0] > 0.5).astype(np.float32) + + print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:") + print(f" Probabilities - min: {pred_prob_np.min():.5f}, max: {pred_prob_np.max():.5f}, mean: {pred_prob_np.mean():.5f}") + print(f" Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels") + + if args.plot: + # Plot GROUND TRUTH + plot_psi_contours_and_xpoints( + psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, + xpoint_mask=mask_gt, + titleExtra="GTXpoints", + outDir=outDir, + saveFig=True + ) + + # Plot CNN PREDICTIONS + plot_psi_contours_and_xpoints( + psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, + xpoint_mask=pred_mask_bin, + titleExtra="CNNXpoints", + outDir=outDir, + saveFig=True + ) + + plot_model_performance( + psi_np, pred_prob_np, mask_gt, x, y, params, fnum, filenameBase, + outDir=outDir, + saveFig=True + ) t5 = timer() print("time (s) to apply model: " + str(t5-t4)) diff --git a/ci_tests.py b/ci_tests.py index eb31555..96d1012 100644 --- a/ci_tests.py +++ b/ci_tests.py @@ -3,6 +3,10 @@ from torch.utils.data import Dataset, DataLoader import torch.optim as optim import os +import sys +# validate_one_epoch is imported locally inside test_checkpoint_functionality +# to avoid a circular import with XPointMLTest: 
+# from XPointMLTest import validate_one_epoch class SyntheticXPointDataset(Dataset): """ @@ -50,7 +54,7 @@ def _generate_frame(self, idx): #create synthetic current (Laplacian of psi) jz = -(np.gradient(np.gradient(psi, axis=0), axis=0) + - np.gradient(np.gradient(psi, axis=1), axis=1)) + np.gradient(np.gradient(psi, axis=1), axis=1)) # create X-point mask mask = np.zeros((H, W), dtype=np.float32) @@ -97,18 +101,22 @@ def test_checkpoint_functionality(model, optimizer, device, val_loader, criterio """ # Import locally to prevent circular dependency - from XPointMLTest import validate_one_epoch + from XPointMLTest import validate_one_epoch, autocast print("\n" + "="*60) print("TESTING CHECKPOINT SAVE/LOAD FUNCTIONALITY") print("="*60) - #get initial validation loss + # Infer the AMP settings from the scaler passed in (only the float16 AMP path enables it) + use_amp = scaler is not None and scaler.is_enabled() + amp_dtype = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16 + + # Get initial validation loss model.eval() - initial_loss = validate_one_epoch(model, val_loader, criterion, device) + initial_loss = validate_one_epoch(model, val_loader, criterion, device, use_amp, amp_dtype) print(f"Initial validation loss: {initial_loss:.6f}") - #saves checkpoint + # Save checkpoint with the correct AMP components test_checkpoint_path = "smoke_test_checkpoint.pt" checkpoint = { 'model_state_dict': model.state_dict(), @@ -117,24 +125,36 @@ def test_checkpoint_functionality(model, optimizer, device, val_loader, criterio 'optimizer_state_dict': optimizer.state_dict(), 'test_value': 42 } + if scaler is not None: + checkpoint['scaler_state_dict'] = scaler.state_dict() + torch.save(checkpoint, test_checkpoint_path) print(f"Saved checkpoint to {test_checkpoint_path}") # create new model and optimizer - model2 = UNet(input_channels=4, base_channels=64).to(device) + # NOTE: base_channels must match the model built in main() (32); a mismatch + # here would make load_state_dict fail with a size error below. + model2 = UNet(input_channels=4, base_channels=32).to(device) optimizer2 = Adam(model2.parameters(), lr=1e-5) # load checkpoint loaded_checkpoint = torch.load(test_checkpoint_path, map_location=device, weights_only=False) model2.load_state_dict(loaded_checkpoint['model_state_dict']) optimizer2.load_state_dict(loaded_checkpoint['optimizer_state_dict']) + + # Load the scaler state if present + scaler2 = None + if 'scaler_state_dict' in loaded_checkpoint: + scaler2 = torch.amp.GradScaler("cuda") + scaler2.load_state_dict(loaded_checkpoint['scaler_state_dict']) assert loaded_checkpoint['test_value'] == 42, "Test value mismatch!" 
print("Checkpoint test value verified") #get loaded model validation loss model2.eval() - loaded_loss = validate_one_epoch(model2, val_loader, criterion, device) + # Now pass the AMP arguments to validate_one_epoch + loaded_loss = validate_one_epoch(model2, val_loader, criterion, device, use_amp, amp_dtype) print(f"Loaded model validation loss: {loaded_loss:.6f}") # check if losses match diff --git a/test_xpoint_ml.py b/test_xpoint_ml.py index 842a48d..5aba775 100644 --- a/test_xpoint_ml.py +++ b/test_xpoint_ml.py @@ -5,6 +5,7 @@ import os import pytest +# Make sure all required functions are imported from the main file from XPointMLTest import UNet, DiceLoss, expand_xpoints_mask, validate_one_epoch from ci_tests import SyntheticXPointDataset @@ -54,7 +55,7 @@ def test_dice_loss_no_match(dice_loss): def test_synthetic_dataset_integrity(synthetic_dataset): assert len(synthetic_dataset) == 2 item = synthetic_dataset[0] - expected_keys = ["fnum", "all", "mask", "psi", "x", "y"] + expected_keys = ["fnum", "all", "mask", "psi", "x", "y", "rotation", "reflectionAxis", "filenameBase", "params"] assert all(key in item for key in expected_keys) assert item['all'].shape == (4, 32, 32) assert item['mask'].shape == (1, 32, 32) @@ -88,13 +89,14 @@ def test_checkpoint_save_load(unet_model, synthetic_dataset): optimizer = optim.Adam(model.parameters(), lr=1e-5) criterion = DiceLoss() - #create a simple dataloader + # Create a simple dataloader val_loader = DataLoader(synthetic_dataset, batch_size=1, shuffle=False) - #get initial loss - initial_loss = validate_one_epoch(model, val_loader, criterion, device) + # get initial loss, passing the required AMP arguments + # we can assume no AMP for this CPU-based unit test + initial_loss = validate_one_epoch(model, val_loader, criterion, device, use_amp=False, amp_dtype=torch.float32) - #save checkpoint + # Save checkpoint test_checkpoint_path = "test_checkpoint_pytest.pt" checkpoint = { 'model_state_dict': model.state_dict(), @@ -104,7 +106,7 @@ def test_checkpoint_save_load(unet_model, synthetic_dataset): } torch.save(checkpoint, test_checkpoint_path) - #create new model and load + # Create new model and load model2 = UNet(input_channels=4, base_channels=16).to(device) optimizer2 = optim.Adam(model2.parameters(), lr=1e-5) @@ -114,14 +116,14 @@ def test_checkpoint_save_load(unet_model, synthetic_dataset): assert loaded_checkpoint['test_value'] == 42 - #get loaded model loss - loaded_loss = validate_one_epoch(model2, val_loader, criterion, device) + # Get loaded model loss, again passing the AMP arguments + loaded_loss = validate_one_epoch(model2, val_loader, criterion, device, use_amp=False, amp_dtype=torch.float32) - #check if losses match + # Check if losses match loss_diff = abs(initial_loss - loaded_loss) assert loss_diff < 1e-6, f"Loss difference too large: {loss_diff}" - #cleanup + # Cleanup if os.path.exists(test_checkpoint_path): os.remove(test_checkpoint_path)