8585 'chart/*' # Anything inside chart directory
8686]
8787
88- def get_changed_files_from_git (root_dir : str ) -> List [ str ] :
88+ def get_default_branch (root_dir : str ) -> str :
8989 """
90- Get list of files that have uncommitted changes in git.
91-
90+ Get the default branch name from git config .
91+
9292 Returns:
93- List of absolute file paths with changes
93+ Default branch name (e.g., 'main', 'master')
9494 """
9595 try :
96- # Get modified and staged files
96+ # Try to get remote HEAD (most reliable method)
9797 result = subprocess .run (
98- ['git' , 'diff' , '--name-only ' , 'HEAD' ],
98+ ['git' , 'symbolic-ref ' , 'refs/remotes/origin/ HEAD' ],
9999 capture_output = True ,
100100 text = True ,
101101 cwd = root_dir ,
102102 check = True
103103 )
104-
105- result2 = subprocess .run (
106- ['git' , 'diff' , '--name-only' , '--cached' ],
107- capture_output = True ,
108- text = True ,
109- cwd = root_dir ,
110- check = True
111- )
112-
104+ # Returns 'refs/remotes/origin/main', extract 'main'
105+ return result .stdout .strip ().split ('/' )[- 1 ]
106+ except subprocess .CalledProcessError :
107+ # Fallback: try init.defaultBranch config
108+ try :
109+ result = subprocess .run (
110+ ['git' , 'config' , '--get' , 'init.defaultBranch' ],
111+ capture_output = True ,
112+ text = True ,
113+ cwd = root_dir ,
114+ check = True
115+ )
116+ return result .stdout .strip ()
117+ except subprocess .CalledProcessError :
118+ # Ultimate fallback
119+ return 'main'
120+
121+ def get_changed_files_from_git (root_dir : str , base_branch : str = None ) -> List [str ]:
122+ """
123+ Get list of files that have uncommitted changes or differ from base branch.
124+
125+ Args:
126+ root_dir: Repository root directory
127+ base_branch: Optional base branch to compare against
128+
129+ Returns:
130+ List of absolute file paths with changes
131+ """
132+ try :
113133 changed_files = set ()
134+
135+ # Always check for uncommitted changes first
136+ result = subprocess .run (
137+ ['git' , 'diff' , '--name-only' , 'HEAD' ],
138+ capture_output = True , text = True , cwd = root_dir , check = True
139+ )
114140 for line in result .stdout .strip ().split ('\n ' ):
115141 if line :
116142 changed_files .add (os .path .abspath (os .path .join (root_dir , line )))
117-
143+
144+ # Check staged files
145+ result2 = subprocess .run (
146+ ['git' , 'diff' , '--name-only' , '--cached' ],
147+ capture_output = True , text = True , cwd = root_dir , check = True
148+ )
118149 for line in result2 .stdout .strip ().split ('\n ' ):
119150 if line :
120151 changed_files .add (os .path .abspath (os .path .join (root_dir , line )))
121-
152+
153+ # If no uncommitted changes and no base branch specified, auto-detect
154+ if not changed_files and base_branch is None :
155+ # Get current branch
156+ current_result = subprocess .run (
157+ ['git' , 'branch' , '--show-current' ],
158+ capture_output = True , text = True , cwd = root_dir , check = True
159+ )
160+ current_branch = current_result .stdout .strip ()
161+ default_branch = get_default_branch (root_dir )
162+
163+ # Only compare to default if we're on a different branch
164+ if current_branch and current_branch != default_branch :
165+ base_branch = default_branch
166+
167+ # If base branch specified or auto-detected, get branch diff
168+ if base_branch :
169+ result3 = subprocess .run (
170+ ['git' , 'diff' , '--name-only' , f'{ base_branch } ...HEAD' ],
171+ capture_output = True , text = True , cwd = root_dir , check = True
172+ )
173+ for line in result3 .stdout .strip ().split ('\n ' ):
174+ if line :
175+ changed_files .add (os .path .abspath (os .path .join (root_dir , line )))
176+
122177 return list (changed_files )
123-
178+
124179 except (subprocess .CalledProcessError , FileNotFoundError ):
125180 return []
126181
@@ -534,6 +589,7 @@ def main():
534589 parser .add_argument ('--license-file' , default = 'LICENSE' , help = 'Path to the license template file' )
535590 parser .add_argument ('--root-dir' , default = '.' , help = 'Root directory to search for files' )
536591 parser .add_argument ('--year' , help = 'Year to use in SPDX copyright header (default: current year)' )
592+ parser .add_argument ('--base-branch' , help = 'Base branch to compare against (auto-detected if not specified)' )
537593 parser .add_argument ('--verbose' , action = 'store_true' , help = 'Show detailed messages, including when licenses are already formatted' )
538594 args = parser .parse_args ()
539595
@@ -542,7 +598,7 @@ def main():
542598 license_text = read_license_template (args .license_file )
543599
544600 # Step 2: Get changed files from git
545- changed_files = get_changed_files_from_git (args .root_dir )
601+ changed_files = get_changed_files_from_git (args .root_dir , args . base_branch )
546602
547603 if not changed_files :
548604 print ("No changed files found." )
0 commit comments