1
+ import argparse
2
+ import json
3
+ import os
4
+ import semver
5
+ import sys
6
+ from github import Github
7
+ from packaging .version import Version
8
+
9
+ class NightlyBuildHelper :
10
+ def __init__ (self ):
11
+ """Initialize with GitHub credentials from environment variables."""
12
+ token = os .environ .get ('GH_TOKEN' )
13
+ repo_name = os .environ .get ('GITHUB_REPOSITORY' )
14
+
15
+ if not token :
16
+ raise ValueError ("GH_TOKEN environment variable is required" )
17
+ if not repo_name :
18
+ raise ValueError ("GITHUB_REPOSITORY environment variable is required" )
19
+
20
+ self .repo = Github (token ).get_repo (repo_name )
21
+ self .schedule_variable , self .current_schedule = self ._load_schedule ()
22
+
23
+ def _load_schedule (self ):
24
+ """Load the schedule from GitHub Actions variable."""
25
+ try :
26
+ # Get the NIGHTLY_BUILD_SCHEDULE variable
27
+ schedule_var = self .repo .get_variable ("NIGHTLY_BUILD_SCHEDULE" )
28
+ schedule = json .loads (schedule_var .value )
29
+
30
+ # Initialize lists if they don't exist
31
+ schedule .setdefault ("active_nightly_builds" , [])
32
+ schedule .setdefault ("patch_base_versions" , [])
33
+ schedule .setdefault ("minor_base_versions" , [])
34
+ return schedule_var , schedule
35
+ except Exception as e :
36
+ print (f"Error loading schedule from GitHub: { e } " )
37
+ sys .exit (1 )
38
+
39
+ def _save_schedule (self ):
40
+ """Save the current schedule to GitHub Actions variable."""
41
+ try :
42
+ schedule_json = json .dumps (self .current_schedule , indent = 4 , sort_keys = True )
43
+ print (f"Updated schedule: { schedule_json } " )
44
+ # Update the existing variable
45
+ self .schedule_variable .edit (schedule_json )
46
+ print ("Successfully updated NIGHTLY_BUILD_SCHEDULE variable: https://github.com/aws/sagemaker-distribution/settings/variables/actions" )
47
+ except Exception as e :
48
+ print (f"Error saving schedule to GitHub: { e } " )
49
+ sys .exit (1 )
50
+
51
+ def _sort_lists (self ):
52
+ """Sort all version lists in the schedule."""
53
+ self .current_schedule ["active_nightly_builds" ].sort (key = Version )
54
+ self .current_schedule ["patch_base_versions" ].sort (key = Version )
55
+ self .current_schedule ["minor_base_versions" ].sort (key = Version )
56
+
57
+ def remove_version (self , version ):
58
+ """Remove a version from active builds and update base versions accordingly."""
59
+ version_obj = semver .VersionInfo .parse (version )
60
+ if str (version_obj ) != version :
61
+ raise ValueError (f"Version must be in x.y.z format, got: { version } " )
62
+
63
+ print (f"Current schedule: { json .dumps (self .current_schedule , indent = 4 , sort_keys = True )} " )
64
+ print (f"Removing version: { version } " )
65
+
66
+ if version not in self .current_schedule ["active_nightly_builds" ]:
67
+ print (f"Version { version } not found in active nightly builds schedule." )
68
+ return
69
+
70
+ # Remove from active builds
71
+ self .current_schedule ["active_nightly_builds" ].remove (version )
72
+
73
+ if version_obj .patch == 0 : # Handling minor version
74
+ # Remove previous minor version from minor_base_versions
75
+ self .current_schedule ["minor_base_versions" ] = [
76
+ v for v in self .current_schedule ["minor_base_versions" ]
77
+ if not v .startswith (f"{ version_obj .major } .{ version_obj .minor - 1 } " )
78
+ ]
79
+ else : # Handling patch version
80
+ prev_version = str (version_obj .replace (patch = version_obj .patch - 1 ))
81
+ if prev_version in self .current_schedule ["patch_base_versions" ]:
82
+ self .current_schedule ["patch_base_versions" ].remove (prev_version )
83
+
84
+ self ._sort_lists ()
85
+ self ._save_schedule ()
86
+
87
+ def add_next_versions (self , version ):
88
+ """Add next version(s) based on the removed version."""
89
+ version_obj = semver .VersionInfo .parse (version )
90
+ if str (version_obj ) != version :
91
+ raise ValueError (f"Version must be in x.y.z format, got: { version } " )
92
+
93
+ print (f"Current schedule: { self .current_schedule } " )
94
+ print (f"Adding next versions for released: { version } " )
95
+
96
+ next_versions = [str (version_obj .bump_patch ())]
97
+ if version_obj .patch == 0 : # Handling minor version
98
+ next_versions .append (str (version_obj .bump_minor ()))
99
+ self .current_schedule ["active_nightly_builds" ].extend (next_versions )
100
+ self .current_schedule ["patch_base_versions" ].append (version )
101
+ self .current_schedule ["minor_base_versions" ].append (version )
102
+ else : # Handling patch version
103
+ self .current_schedule ["active_nightly_builds" ].extend (next_versions )
104
+ self .current_schedule ["patch_base_versions" ].append (version )
105
+ prev_version = str (version_obj .replace (patch = version_obj .patch - 1 ))
106
+ if prev_version in self .current_schedule ["minor_base_versions" ]:
107
+ self .current_schedule ["minor_base_versions" ].remove (prev_version )
108
+ self .current_schedule ["minor_base_versions" ].append (version )
109
+
110
+ self ._sort_lists ()
111
+ self ._save_schedule ()
112
+
113
+ def main ():
114
+ parser = argparse .ArgumentParser (description = 'Nightly build helper tool' )
115
+ subparsers = parser .add_subparsers (dest = 'command' , help = 'Commands' )
116
+
117
+ # Remove version command
118
+ remove_parser = subparsers .add_parser ('remove-version' ,
119
+ help = 'Remove a version from active builds' )
120
+ remove_parser .add_argument ('version' ,
121
+ help = 'Version to remove (e.g., 1.2.3)' )
122
+
123
+ # Add next versions command
124
+ add_parser = subparsers .add_parser ('add-next-versions' ,
125
+ help = 'Add next version(s) based on released version' )
126
+ add_parser .add_argument ('version' ,
127
+ help = 'Version that was released (e.g., 1.2.3)' )
128
+
129
+ args = parser .parse_args ()
130
+
131
+ if not args .command :
132
+ parser .print_help ()
133
+ sys .exit (1 )
134
+
135
+ try :
136
+ helper = NightlyBuildHelper ()
137
+
138
+ if args .command == 'remove-version' :
139
+ helper .remove_version (args .version )
140
+ elif args .command == 'add-next-versions' :
141
+ helper .add_next_versions (args .version )
142
+ else :
143
+ print (f"Unknown command: { args .command } " )
144
+ parser .print_help ()
145
+ sys .exit (1 )
146
+ except Exception as e :
147
+ print (f"Error: { e } " )
148
+ sys .exit (1 )
149
+
150
+ if __name__ == "__main__" :
151
+ main ()
0 commit comments