-
Notifications
You must be signed in to change notification settings - Fork 0
99 lines (86 loc) · 2.96 KB
/
checkTrainingParams.yml
File metadata and controls
99 lines (86 loc) · 2.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Guards the training defaults in XPointMLTest.py: the check fails when a pull
# request changes any tracked argparse default relative to the main branch.
name: Check Training Parameters

on:
  pull_request:
    branches: [main]
    paths:
      - 'XPointMLTest.py'  # Only trigger when this specific file changes

# The job only reads repository contents, so grant nothing else.
permissions:
  contents: read

jobs:
  check-training-params:
    runs-on: ubuntu-latest
    steps:
      # Both revisions are checked out side by side so the script can diff them.
      - name: Checkout PR code
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.pull_request.head.sha }}
          path: pr-code
          fetch-depth: 1

      - name: Checkout main branch
        uses: actions/checkout@v4
        with:
          ref: main
          path: main-code

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'

      # No dependency-install step: the script below uses only the standard
      # library (re, sys).
      - name: Run parameter check
        # The heredoc delimiter is quoted so the shell passes the Python source
        # through verbatim (no $, backtick or backslash processing).
        run: |
          python - <<'EOF'
          import re
          import sys

          # Regex per tracked parameter. Each one expects the literal text
          #   --<name>', type=int, default=<digits>,
          # and captures the digits; any other formatting counts as "not found".
          params_to_check = {
              'epochs': r'--epochs\', type=int, default=(\d+),',
              'trainFrameFirst': r'--trainFrameFirst\', type=int, default=(\d+),',
              'trainFrameLast': r'--trainFrameLast\', type=int, default=(\d+),',
              'validationFrameFirst': r'--validationFrameFirst\', type=int, default=(\d+),',
              'validationFrameLast': r'--validationFrameLast\', type=int, default=(\d+),',
          }

          # Label used in messages -> file produced by the checkout steps above.
          files_to_check = {
              'main branch': 'main-code/XPointMLTest.py',
              'PR': 'pr-code/XPointMLTest.py',
          }


          def extract_params(path, label):
              """Return {parameter: default} for every tracked parameter found in path.

              A missing file is reported and treated as "nothing found" so the
              comparison below can decide the outcome instead of a traceback.
              """
              try:
                  with open(path, 'r', encoding='utf-8') as f:
                      content = f.read()
              except FileNotFoundError:
                  print(f"Warning: '{path}' does not exist in {label} code")
                  return {}

              found = {}
              for param, pattern in params_to_check.items():
                  match = re.search(pattern, content)
                  if match:
                      found[param] = int(match.group(1))
                  else:
                      print(f"Warning: Could not find parameter '{param}' in {label} code")
              return found


          main_params = extract_params(files_to_check['main branch'], 'main branch')
          pr_params = extract_params(files_to_check['PR'], 'PR')

          changed = []  # value differs between main and the PR
          lost = []     # found on main but not in the PR: a change cannot be ruled out
          skipped = []  # not found on main: there is no baseline to compare against

          for param in params_to_check:
              if param in main_params and param in pr_params:
                  if main_params[param] != pr_params[param]:
                      print(f"❌ Parameter '{param}' has changed: {main_params[param]} -> {pr_params[param]}")
                      changed.append(param)
                  else:
                      print(f"✅ Parameter '{param}' unchanged: {main_params[param]}")
              else:
                  print(f"⚠️ Could not compare '{param}' - missing from one or both branches")
                  if param in main_params:
                      lost.append(param)
                  else:
                      skipped.append(param)

          # Summary
          print("\n=== Parameter Check Summary ===")
          if changed:
              print("❌ Training parameters have been modified!")
              print("Detected changes to training configuration parameters.")
              print("Please verify these changes are intentional and approved.")
          if lost:
              # Reformatting or removing an argument must not slip past the
              # guard, so an unverifiable parameter fails the check too.
              print("❌ Parameters defined on main could not be found in the PR: " + ", ".join(lost))
              print("They may have been removed or reformatted; a change cannot be ruled out.")
          if changed or lost:
              sys.exit(1)

          if skipped:
              print("⚠️ No changes detected, but no baseline was found on main for: " + ", ".join(skipped))
          else:
              print("✅ All training parameters match the main branch!")
          sys.exit(0)
          EOF