Skip to content

fix: sanitize git clone repo input url #5234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 69 additions & 1 deletion src/sagemaker/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,78 @@
from __future__ import absolute_import

import os
from pathlib import Path
import re
import subprocess
import tempfile
import warnings
from pathlib import Path
from urllib.parse import urlparse

import six
from six.moves import urllib


def _sanitize_git_url(repo_url):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any sanitization happening here or is this just validation ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't fix anything on the user's behalf so I think validation is a better word

"""Sanitize Git repository URL to prevent URL injection attacks.

Args:
repo_url (str): The Git repository URL to sanitize

Returns:
str: The sanitized URL

Raises:
ValueError: If the URL contains suspicious patterns that could indicate injection
"""
at_count = repo_url.count("@")

if repo_url.startswith("git@"):
# git@ format requires exactly one @
if at_count != 1:
raise ValueError("Invalid SSH URL format: git@ URLs must have exactly one @ symbol")
elif repo_url.startswith("ssh://"):
# ssh:// format can have 0 or 1 @ symbols
if at_count > 1:
raise ValueError("Invalid SSH URL format: multiple @ symbols detected")
elif repo_url.startswith("https://") or repo_url.startswith("http://"):
# HTTPS format allows 0 or 1 @ symbols
if at_count > 1:
raise ValueError("Invalid HTTPS URL format: multiple @ symbols detected")

# Check for invalid characters in the URL before parsing
# These characters should not appear in legitimate URLs
invalid_chars = ["<", ">", "[", "]", "{", "}", "\\", "^", "`", "|"]
for char in invalid_chars:
if char in repo_url:
raise ValueError("Invalid characters in hostname")

try:
parsed = urlparse(repo_url)

# Check for suspicious characters in hostname that could indicate injection
if parsed.hostname:
# Check for URL-encoded characters that might be used for obfuscation
suspicious_patterns = ["%25", "%40", "%2F", "%3A"] # encoded %, @, /, :
for pattern in suspicious_patterns:
if pattern in parsed.hostname.lower():
raise ValueError(f"Suspicious URL encoding detected in hostname: {pattern}")

# Validate that the hostname looks legitimate
if not re.match(r"^[a-zA-Z0-9.-]+$", parsed.hostname):
raise ValueError("Invalid characters in hostname")

except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Failed to parse URL: {str(e)}")

Check warning on line 80 in src/sagemaker/git_utils.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/git_utils.py#L80

Added line #L80 was not covered by tests
else:
raise ValueError(
"Unsupported URL scheme: only https://, http://, git@, and ssh:// are allowed"
)

return repo_url


def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
"""Git clone repo containing the training code and serving code.

Expand Down Expand Up @@ -87,6 +151,10 @@
if entry_point is None:
raise ValueError("Please provide an entry point.")
_validate_git_config(git_config)

# SECURITY: Sanitize the repository URL to prevent injection attacks
git_config["repo"] = _sanitize_git_url(git_config["repo"])

dest_dir = tempfile.mkdtemp()
_generate_and_run_clone_command(git_config, dest_dir)

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2794,7 +2794,7 @@ def test_git_support_bad_repo_url_format(sagemaker_session):
)
with pytest.raises(ValueError) as error:
fw.fit()
assert "Invalid Git url provided." in str(error)
assert "Unsupported URL scheme" in str(error)


@patch(
Expand Down
216 changes: 213 additions & 3 deletions tests/unit/test_git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest
import os
from pathlib import Path
import subprocess
from mock import patch, ANY
from pathlib import Path

import pytest
from mock import ANY, patch

from sagemaker import git_utils

Expand Down Expand Up @@ -494,3 +495,212 @@ def test_git_clone_repo_codecommit_https_creds_not_stored_locally(tempdir, mkdte
with pytest.raises(subprocess.CalledProcessError) as error:
git_utils.git_clone_repo(git_config, entry_point)
assert "returned non-zero exit status" in str(error.value)


class TestGitUrlSanitization:
"""Test cases for Git URL sanitization to prevent injection attacks."""

def test_sanitize_git_url_valid_https_urls(self):
"""Test that valid HTTPS URLs pass sanitization."""
valid_urls = [
"https://github.com/user/repo.git",
"https://gitlab.com/user/repo.git",
"https://[email protected]/user/repo.git",
"https://user:[email protected]/user/repo.git",
"http://internal-git.company.com/repo.git",
]

for url in valid_urls:
# Should not raise any exception
result = git_utils._sanitize_git_url(url)
assert result == url

def test_sanitize_git_url_valid_ssh_urls(self):
"""Test that valid SSH URLs pass sanitization."""
valid_urls = [
"[email protected]:user/repo.git",
"[email protected]:user/repo.git",
"ssh://[email protected]/user/repo.git",
"ssh://git-codecommit.us-west-2.amazonaws.com/v1/repos/test-repo/", # 0 @ symbols - valid for ssh://
"[email protected]:repo.git",
]

for url in valid_urls:
# Should not raise any exception
result = git_utils._sanitize_git_url(url)
assert result == url

def test_sanitize_git_url_blocks_multiple_at_https(self):
"""Test that HTTPS URLs with multiple @ symbols are blocked."""
malicious_urls = [
"https://[email protected]@github.com/repo.git",
"https://[email protected]@gitlab.com/user/repo.git",
"https://a@b@[email protected]/repo.git",
"https://user@[email protected]/legit/repo.git",
]

for url in malicious_urls:
with pytest.raises(ValueError) as error:
git_utils._sanitize_git_url(url)
assert "multiple @ symbols detected" in str(error.value)

def test_sanitize_git_url_blocks_multiple_at_ssh(self):
"""Test that SSH URLs with multiple @ symbols are blocked."""
malicious_urls = [
"[email protected]@github.com:repo.git",
"git@[email protected]:user/repo.git",
"ssh://git@[email protected]/repo.git",
"git@a@b@c:repo.git",
]

for url in malicious_urls:
with pytest.raises(ValueError) as error:
git_utils._sanitize_git_url(url)
# git@ URLs should give "exactly one @ symbol" error
# ssh:// URLs should give "multiple @ symbols detected" error
assert any(
phrase in str(error.value)
for phrase in ["multiple @ symbols detected", "exactly one @ symbol"]
)

def test_sanitize_git_url_blocks_invalid_schemes_and_git_at_format(self):
"""Test that invalid schemes and git@ format violations are blocked."""
# Test unsupported schemes
unsupported_scheme_urls = [
"git-github.com:user/repo.git", # Doesn't start with git@, ssh://, http://, https://
]

for url in unsupported_scheme_urls:
with pytest.raises(ValueError) as error:
git_utils._sanitize_git_url(url)
assert "Unsupported URL scheme" in str(error.value)

# Test git@ URLs with wrong @ count
invalid_git_at_urls = [
"[email protected]@evil.com:repo.git", # 2 @ symbols
]

for url in invalid_git_at_urls:
with pytest.raises(ValueError) as error:
git_utils._sanitize_git_url(url)
assert "exactly one @ symbol" in str(error.value)

def test_sanitize_git_url_blocks_url_encoding_obfuscation(self):
"""Test that URL-encoded obfuscation attempts are blocked."""
obfuscated_urls = [
"https://github.com%25evil.com/repo.git",
"https://[email protected]%40attacker.com/repo.git",
"https://github.com%2Fevil.com/repo.git",
"https://github.com%3Aevil.com/repo.git",
]

for url in obfuscated_urls:
with pytest.raises(ValueError) as error:
git_utils._sanitize_git_url(url)
# The error could be either suspicious encoding or invalid characters
assert any(
phrase in str(error.value)
for phrase in ["Suspicious URL encoding detected", "Invalid characters in hostname"]
)

def test_sanitize_git_url_blocks_invalid_hostname_chars(self):
"""Test that hostnames with invalid characters are blocked."""
invalid_urls = [
"https://github<script>.com/repo.git",
"https://github>.com/repo.git",
"https://github[].com/repo.git",
"https://github{}.com/repo.git",
]

for url in invalid_urls:
with pytest.raises(ValueError) as error:
git_utils._sanitize_git_url(url)
# The error could be various types due to URL parsing edge cases
assert any(
phrase in str(error.value)
for phrase in [
"Invalid characters in hostname",
"Failed to parse URL",
"does not appear to be an IPv4 or IPv6 address",
]
)

def test_sanitize_git_url_blocks_unsupported_schemes(self):
"""Test that unsupported URL schemes are blocked."""
unsupported_urls = [
"ftp://github.com/repo.git",
"file:///local/repo.git",
"javascript:alert('xss')",
"data:text/html,<script>alert('xss')</script>",
]

for url in unsupported_urls:
with pytest.raises(ValueError) as error:
git_utils._sanitize_git_url(url)
assert "Unsupported URL scheme" in str(error.value)

def test_git_clone_repo_blocks_malicious_https_url(self):
"""Test that git_clone_repo blocks malicious HTTPS URLs."""
malicious_git_config = {
"repo": "https://[email protected]@github.com/legit/repo.git",
"branch": "main",
}
entry_point = "train.py"

with pytest.raises(ValueError) as error:
git_utils.git_clone_repo(malicious_git_config, entry_point)
assert "multiple @ symbols detected" in str(error.value)

def test_git_clone_repo_blocks_malicious_ssh_url(self):
"""Test that git_clone_repo blocks malicious SSH URLs."""
malicious_git_config = {
"repo": "git@[email protected]:sage-maker/temp-sev2.git",
"branch": "main",
}
entry_point = "train.py"

with pytest.raises(ValueError) as error:
git_utils.git_clone_repo(malicious_git_config, entry_point)
assert "exactly one @ symbol" in str(error.value)

def test_git_clone_repo_blocks_url_encoded_attack(self):
"""Test that git_clone_repo blocks URL-encoded attacks."""
malicious_git_config = {
"repo": "https://github.com%40attacker.com/repo.git",
"branch": "main",
}
entry_point = "train.py"

with pytest.raises(ValueError) as error:
git_utils.git_clone_repo(malicious_git_config, entry_point)
assert "Suspicious URL encoding detected" in str(error.value)

def test_sanitize_git_url_comprehensive_attack_scenarios(self):
attack_scenarios = [
# Original PoC attack
"https://USER@YOUR_NGROK_OR_LOCALHOST/[email protected]%25legit%25repo.git",
# Variations of the attack
"https://user@[email protected]/legit/repo.git",
"[email protected]@github.com:user/repo.git",
"ssh://[email protected]@github.com/repo.git",
# URL encoding variations
"https://github.com%40evil.com/repo.git",
"https://[email protected]%2Fevil.com/repo.git",
]

entry_point = "train.py"

for malicious_url in attack_scenarios:
git_config = {"repo": malicious_url}
with pytest.raises(ValueError) as error:
git_utils.git_clone_repo(git_config, entry_point)
# Should be blocked by sanitization
assert any(
phrase in str(error.value)
for phrase in [
"multiple @ symbols detected",
"exactly one @ symbol",
"Suspicious URL encoding detected",
"Invalid characters in hostname",
]
)