Skip to content

Commit e00d3a2

Browse files
committed
feat(scripts): auto-detect base branch for license formatting
1 parent 0ce324c commit e00d3a2

File tree

1 file changed

+75
-19
lines changed

1 file changed

+75
-19
lines changed

scripts/format_license.py

Lines changed: 75 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -85,42 +85,97 @@
8585
'chart/*' # Anything inside chart directory
8686
]
8787

88-
def get_changed_files_from_git(root_dir: str) -> List[str]:
88+
def get_default_branch(root_dir: str) -> str:
8989
"""
90-
Get list of files that have uncommitted changes in git.
91-
90+
Get the default branch name from git config.
91+
9292
Returns:
93-
List of absolute file paths with changes
93+
Default branch name (e.g., 'main', 'master')
9494
"""
9595
try:
96-
# Get modified and staged files
96+
# Try to get remote HEAD (most reliable method)
9797
result = subprocess.run(
98-
['git', 'diff', '--name-only', 'HEAD'],
98+
['git', 'symbolic-ref', 'refs/remotes/origin/HEAD'],
9999
capture_output=True,
100100
text=True,
101101
cwd=root_dir,
102102
check=True
103103
)
104-
105-
result2 = subprocess.run(
106-
['git', 'diff', '--name-only', '--cached'],
107-
capture_output=True,
108-
text=True,
109-
cwd=root_dir,
110-
check=True
111-
)
112-
104+
# Returns 'refs/remotes/origin/main', extract 'main'
105+
return result.stdout.strip().split('/')[-1]
106+
except subprocess.CalledProcessError:
107+
# Fallback: try init.defaultBranch config
108+
try:
109+
result = subprocess.run(
110+
['git', 'config', '--get', 'init.defaultBranch'],
111+
capture_output=True,
112+
text=True,
113+
cwd=root_dir,
114+
check=True
115+
)
116+
return result.stdout.strip()
117+
except subprocess.CalledProcessError:
118+
# Ultimate fallback
119+
return 'main'
120+
121+
def get_changed_files_from_git(root_dir: str, base_branch: str = None) -> List[str]:
122+
"""
123+
Get list of files that have uncommitted changes or differ from base branch.
124+
125+
Args:
126+
root_dir: Repository root directory
127+
base_branch: Optional base branch to compare against
128+
129+
Returns:
130+
List of absolute file paths with changes
131+
"""
132+
try:
113133
changed_files = set()
134+
135+
# Always check for uncommitted changes first
136+
result = subprocess.run(
137+
['git', 'diff', '--name-only', 'HEAD'],
138+
capture_output=True, text=True, cwd=root_dir, check=True
139+
)
114140
for line in result.stdout.strip().split('\n'):
115141
if line:
116142
changed_files.add(os.path.abspath(os.path.join(root_dir, line)))
117-
143+
144+
# Check staged files
145+
result2 = subprocess.run(
146+
['git', 'diff', '--name-only', '--cached'],
147+
capture_output=True, text=True, cwd=root_dir, check=True
148+
)
118149
for line in result2.stdout.strip().split('\n'):
119150
if line:
120151
changed_files.add(os.path.abspath(os.path.join(root_dir, line)))
121-
152+
153+
# If no uncommitted changes and no base branch specified, auto-detect
154+
if not changed_files and base_branch is None:
155+
# Get current branch
156+
current_result = subprocess.run(
157+
['git', 'branch', '--show-current'],
158+
capture_output=True, text=True, cwd=root_dir, check=True
159+
)
160+
current_branch = current_result.stdout.strip()
161+
default_branch = get_default_branch(root_dir)
162+
163+
# Only compare to default if we're on a different branch
164+
if current_branch and current_branch != default_branch:
165+
base_branch = default_branch
166+
167+
# If base branch specified or auto-detected, get branch diff
168+
if base_branch:
169+
result3 = subprocess.run(
170+
['git', 'diff', '--name-only', f'{base_branch}...HEAD'],
171+
capture_output=True, text=True, cwd=root_dir, check=True
172+
)
173+
for line in result3.stdout.strip().split('\n'):
174+
if line:
175+
changed_files.add(os.path.abspath(os.path.join(root_dir, line)))
176+
122177
return list(changed_files)
123-
178+
124179
except (subprocess.CalledProcessError, FileNotFoundError):
125180
return []
126181

@@ -534,6 +589,7 @@ def main():
534589
parser.add_argument('--license-file', default='LICENSE', help='Path to the license template file')
535590
parser.add_argument('--root-dir', default='.', help='Root directory to search for files')
536591
parser.add_argument('--year', help='Year to use in SPDX copyright header (default: current year)')
592+
parser.add_argument('--base-branch', help='Base branch to compare against (auto-detected if not specified)')
537593
parser.add_argument('--verbose', action='store_true', help='Show detailed messages, including when licenses are already formatted')
538594
args = parser.parse_args()
539595

@@ -542,7 +598,7 @@ def main():
542598
license_text = read_license_template(args.license_file)
543599

544600
# Step 2: Get changed files from git
545-
changed_files = get_changed_files_from_git(args.root_dir)
601+
changed_files = get_changed_files_from_git(args.root_dir, args.base_branch)
546602

547603
if not changed_files:
548604
print("No changed files found.")

0 commit comments

Comments
 (0)