
Commit 88253a5

Merge branch 'main' into cws/bxByJz
2 parents: 39c0125 + c380ee5

2 files changed: 191 additions and 2 deletions

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
name: Check Training Parameters

on:
  pull_request:
    branches: [ main ]
    paths:
      - 'XPointMLTest.py' # Only trigger when this specific file changes

jobs:
  check-training-params:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout PR code
        uses: actions/checkout@v3
        with:
          ref: ${{ github.event.pull_request.head.sha }}
          path: pr-code
          fetch-depth: 1

      - name: Checkout main branch
        uses: actions/checkout@v3
        with:
          ref: main
          path: main-code

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest

      - name: Run parameter check
        run: |
          python - <<EOF
          import re
          import sys

          # Training parameter patterns to check
          params_to_check = {
              'epochs': r'--epochs\', type=int, default=(\d+),',
              'trainFrameFirst': r'--trainFrameFirst\', type=int, default=(\d+),',
              'trainFrameLast': r'--trainFrameLast\', type=int, default=(\d+),',
              'validationFrameFirst': r'--validationFrameFirst\', type=int, default=(\d+),',
              'validationFrameLast': r'--validationFrameLast\', type=int, default=(\d+),',
          }

          # Files to check
          files_to_check = ['main-code/XPointMLTest.py', 'pr-code/XPointMLTest.py']

          main_params = {}
          pr_params = {}

          # Extract parameters from main branch
          with open('main-code/XPointMLTest.py', 'r') as f:
              content = f.read()
              for param, pattern in params_to_check.items():
                  match = re.search(pattern, content)
                  if match:
                      main_params[param] = int(match.group(1))
                  else:
                      print(f"Warning: Could not find parameter '{param}' in main branch code")

          # Extract parameters from PR
          with open('pr-code/XPointMLTest.py', 'r') as f:
              content = f.read()
              for param, pattern in params_to_check.items():
                  match = re.search(pattern, content)
                  if match:
                      pr_params[param] = int(match.group(1))
                  else:
                      print(f"Warning: Could not find parameter '{param}' in PR code")

          # Compare parameters
          mismatch = False
          for param in params_to_check.keys():
              if param in main_params and param in pr_params:
                  if main_params[param] != pr_params[param]:
                      print(f"❌ Parameter '{param}' has changed: {main_params[param]} -> {pr_params[param]}")
                      mismatch = True
                  else:
                      print(f"✅ Parameter '{param}' unchanged: {main_params[param]}")
              else:
                  print(f"⚠️ Could not compare '{param}' - missing from one or both branches")

          # Summary
          print("\n=== Parameter Check Summary ===")
          if mismatch:
              print("❌ Training parameters have been modified!")
              print("Detected changes to training configuration parameters.")
              print("Please verify these changes are intentional and approved.")
              sys.exit(1)
          else:
              print("✅ All training parameters match the main branch!")
              sys.exit(0)
          EOF
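
The regular expressions in the workflow above assume that each guarded parameter is declared in XPointMLTest.py with argparse, on a single line, in the exact order flag, type=int, default=<value>. The sketch below is a hypothetical example of the kind of declaration those patterns would match; the default values and help strings are placeholders, not the repository's actual settings.

import argparse

def parseCommandLineArgs():
    # Hypothetical argparse block: the defaults shown here are placeholders.
    # Each workflow regex looks for the literal text  --<name>', type=int, default=<digits>,
    # so the flag, its type, and its default must stay together on one line for the check to find them.
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=10, help='number of training epochs')
    parser.add_argument('--trainFrameFirst', type=int, default=1, help='first frame used for training')
    parser.add_argument('--trainFrameLast', type=int, default=2, help='last frame used for training')
    parser.add_argument('--validationFrameFirst', type=int, default=3, help='first frame used for validation')
    parser.add_argument('--validationFrameLast', type=int, default=4, help='last frame used for validation')
    return parser.parse_args()

Note that if a pattern fails to match in either branch, the script only prints a warning and the ⚠️ comparison message; mismatch stays False, so the job still exits 0. Only a detected change in a default value fails the check.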

XPointMLTest.py

Lines changed: 92 additions & 2 deletions
@@ -709,6 +709,78 @@ def printCommandLineArgs(args):
         print(f" {arg}: {getattr(args, arg)}")
     print("}")
 
+# Function to save model checkpoint
+def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpoint_dir="checkpoints"):
+    """
+    Save model checkpoint including model state, optimizer state, and training metrics
+
+    Parameters:
+        model: The neural network model
+        optimizer: The optimizer used for training
+        train_loss: List of training losses
+        val_loss: List of validation losses
+        epoch: Current epoch number
+        checkpoint_dir: Directory to save checkpoints
+    """
+    os.makedirs(checkpoint_dir, exist_ok=True)
+
+    checkpoint_path = os.path.join(checkpoint_dir, f"xpoint_model_epoch_{epoch}.pt")
+
+    # Create checkpoint dictionary
+    checkpoint = {
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'train_loss': train_loss,
+        'val_loss': val_loss
+    }
+
+    # Save checkpoint
+    torch.save(checkpoint, checkpoint_path)
+    print(f"Model checkpoint saved at epoch {epoch} to {checkpoint_path}")
+
+    # Save the latest model separately for easy loading
+    latest_path = os.path.join(checkpoint_dir, "xpoint_model_latest.pt")
+    torch.save(checkpoint, latest_path)
+    print(f"Latest model saved to {latest_path}")
+
+
+# Function to load model checkpoint
+def load_model_checkpoint(model, optimizer, checkpoint_path):
+    """
+    Load model checkpoint
+
+    Parameters:
+        model: The neural network model to load weights into
+        optimizer: The optimizer to load state into
+        checkpoint_path: Path to the checkpoint file
+
+    Returns:
+        model: Updated model with loaded weights
+        optimizer: Updated optimizer with loaded state
+        epoch: Last saved epoch number
+        train_loss: List of training losses
+        val_loss: List of validation losses
+    """
+    if not os.path.exists(checkpoint_path):
+        print(f"No checkpoint found at {checkpoint_path}")
+        return model, optimizer, 0, [], []
+
+    print(f"Loading checkpoint from {checkpoint_path}")
+    checkpoint = torch.load(checkpoint_path)
+
+    model.load_state_dict(checkpoint['model_state_dict'])
+    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+    epoch = checkpoint['epoch']
+    train_loss = checkpoint['train_loss']
+    val_loss = checkpoint['val_loss']
+
+    print(f"Loaded checkpoint from epoch {epoch}")
+    return model, optimizer, epoch, train_loss, val_loss
+
+
 def main():
     args = parseCommandLineArgs()
     checkCommandLineArgs(args)
@@ -741,17 +813,35 @@ def main():
     criterion = DiceLoss(smooth=1.0)
     optimizer = optim.Adam(model.parameters(), lr=args.learningRate)
 
+    checkpoint_dir = "checkpoints"
+    os.makedirs(checkpoint_dir, exist_ok=True)
+    latest_checkpoint_path = os.path.join(checkpoint_dir, "xpoint_model_latest.pt")
+    start_epoch = 0
+    train_loss = []
+    val_loss = []
+
+    if os.path.exists(latest_checkpoint_path):
+        model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint(
+            model, optimizer, latest_checkpoint_path
+        )
+        print(f"Resuming training from epoch {start_epoch+1}")
+    else:
+        print("Starting training from scratch")
+
     t2 = timer()
     print("time (s) to prepare model: " + str(t2-t1))
 
     train_loss = []
     val_loss = []
-
+
     num_epochs = args.epochs
-    for epoch in range(num_epochs):
+    for epoch in range(start_epoch, num_epochs):
         train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device))
         val_loss.append(validate_one_epoch(model, val_loader, criterion, device))
         print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} ValLoss={val_loss[-1]}")
+
+        # Save model checkpoint after each epoch
+        save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch+1, checkpoint_dir)
 
     plot_training_history(train_loss, val_loss)
     print("time (s) to train model: " + str(timer()-t2))