Skip to content

Commit d2d9e19

Browse files
committed
fix formatting
1 parent 171b8a1 commit d2d9e19

File tree

2 files changed

+55
-46
lines changed

2 files changed

+55
-46
lines changed

src/sagemaker/git_utils.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,17 @@ def _sanitize_git_url(repo_url):
3737
Raises:
3838
ValueError: If the URL contains suspicious patterns that could indicate injection
3939
"""
40-
at_count = repo_url.count('@')
40+
at_count = repo_url.count("@")
4141

42-
if repo_url.startswith('git@'):
42+
if repo_url.startswith("git@"):
4343
# git@ format requires exactly one @
4444
if at_count != 1:
4545
raise ValueError("Invalid SSH URL format: git@ URLs must have exactly one @ symbol")
46-
elif repo_url.startswith('ssh://'):
46+
elif repo_url.startswith("ssh://"):
4747
# ssh:// format can have 0 or 1 @ symbols
4848
if at_count > 1:
4949
raise ValueError("Invalid SSH URL format: multiple @ symbols detected")
50-
elif repo_url.startswith('https://') or repo_url.startswith('http://'):
50+
elif repo_url.startswith("https://") or repo_url.startswith("http://"):
5151
# HTTPS format allows 0 or 1 @ symbols
5252
if at_count > 1:
5353
raise ValueError("Invalid HTTPS URL format: multiple @ symbols detected")
@@ -58,21 +58,23 @@ def _sanitize_git_url(repo_url):
5858
# Check for suspicious characters in hostname that could indicate injection
5959
if parsed.hostname:
6060
# Check for URL-encoded characters that might be used for obfuscation
61-
suspicious_patterns = ['%25', '%40', '%2F', '%3A'] # encoded %, @, /, :
61+
suspicious_patterns = ["%25", "%40", "%2F", "%3A"] # encoded %, @, /, :
6262
for pattern in suspicious_patterns:
6363
if pattern in parsed.hostname.lower():
6464
raise ValueError(f"Suspicious URL encoding detected in hostname: {pattern}")
6565

6666
# Validate that the hostname looks legitimate
67-
if not re.match(r'^[a-zA-Z0-9.-]+$', parsed.hostname):
67+
if not re.match(r"^[a-zA-Z0-9.-]+$", parsed.hostname):
6868
raise ValueError("Invalid characters in hostname")
6969

7070
except Exception as e:
7171
if isinstance(e, ValueError):
7272
raise
7373
raise ValueError(f"Failed to parse URL: {str(e)}")
7474
else:
75-
raise ValueError("Unsupported URL scheme: only https://, http://, git@, and ssh:// are allowed")
75+
raise ValueError(
76+
"Unsupported URL scheme: only https://, http://, git@, and ssh:// are allowed"
77+
)
7678

7779
return repo_url
7880

@@ -142,10 +144,10 @@ def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
142144
if entry_point is None:
143145
raise ValueError("Please provide an entry point.")
144146
_validate_git_config(git_config)
145-
147+
146148
# SECURITY: Sanitize the repository URL to prevent injection attacks
147149
git_config["repo"] = _sanitize_git_url(git_config["repo"])
148-
150+
149151
dest_dir = tempfile.mkdtemp()
150152
_generate_and_run_clone_command(git_config, dest_dir)
151153

tests/unit/test_git_utils.py

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ def test_git_clone_repo_codecommit_https_creds_not_stored_locally(tempdir, mkdte
501501
# URL Sanitization Tests - Security vulnerability prevention
502502
# ============================================================================
503503

504+
504505
class TestGitUrlSanitization:
505506
"""Test cases for Git URL sanitization to prevent injection attacks."""
506507

@@ -513,7 +514,7 @@ def test_sanitize_git_url_valid_https_urls(self):
513514
"https://user:pass@github.com/user/repo.git",
514515
"http://internal-git.company.com/repo.git",
515516
]
516-
517+
517518
for url in valid_urls:
518519
# Should not raise any exception
519520
result = git_utils._sanitize_git_url(url)
@@ -528,7 +529,7 @@ def test_sanitize_git_url_valid_ssh_urls(self):
528529
"ssh://git-codecommit.us-west-2.amazonaws.com/v1/repos/test-repo/", # 0 @ symbols - valid for ssh://
529530
"git@internal-git.company.com:repo.git",
530531
]
531-
532+
532533
for url in valid_urls:
533534
# Should not raise any exception
534535
result = git_utils._sanitize_git_url(url)
@@ -542,7 +543,7 @@ def test_sanitize_git_url_blocks_multiple_at_https(self):
542543
"https://a@b@c@github.com/repo.git",
543544
"https://user@[email protected]/legit/repo.git",
544545
]
545-
546+
546547
for url in malicious_urls:
547548
with pytest.raises(ValueError) as error:
548549
git_utils._sanitize_git_url(url)
@@ -556,34 +557,34 @@ def test_sanitize_git_url_blocks_multiple_at_ssh(self):
556557
"ssh://git@[email protected]/repo.git",
557558
"git@a@b@c:repo.git",
558559
]
559-
560+
560561
for url in malicious_urls:
561562
with pytest.raises(ValueError) as error:
562563
git_utils._sanitize_git_url(url)
563564
# git@ URLs should give "exactly one @ symbol" error
564565
# ssh:// URLs should give "multiple @ symbols detected" error
565-
assert any(phrase in str(error.value) for phrase in [
566-
"multiple @ symbols detected",
567-
"exactly one @ symbol"
568-
])
566+
assert any(
567+
phrase in str(error.value)
568+
for phrase in ["multiple @ symbols detected", "exactly one @ symbol"]
569+
)
569570

570571
def test_sanitize_git_url_blocks_invalid_schemes_and_git_at_format(self):
571572
"""Test that invalid schemes and git@ format violations are blocked."""
572573
# Test unsupported schemes
573574
unsupported_scheme_urls = [
574575
"git-github.com:user/repo.git", # Doesn't start with git@, ssh://, http://, https://
575576
]
576-
577+
577578
for url in unsupported_scheme_urls:
578579
with pytest.raises(ValueError) as error:
579580
git_utils._sanitize_git_url(url)
580581
assert "Unsupported URL scheme" in str(error.value)
581-
582+
582583
# Test git@ URLs with wrong @ count
583584
invalid_git_at_urls = [
584585
"git@github.com@evil.com:repo.git", # 2 @ symbols
585586
]
586-
587+
587588
for url in invalid_git_at_urls:
588589
with pytest.raises(ValueError) as error:
589590
git_utils._sanitize_git_url(url)
@@ -597,15 +598,15 @@ def test_sanitize_git_url_blocks_url_encoding_obfuscation(self):
597598
"https://github.com%2Fevil.com/repo.git",
598599
"https://github.com%3Aevil.com/repo.git",
599600
]
600-
601+
601602
for url in obfuscated_urls:
602603
with pytest.raises(ValueError) as error:
603604
git_utils._sanitize_git_url(url)
604605
# The error could be either suspicious encoding or invalid characters
605-
assert any(phrase in str(error.value) for phrase in [
606-
"Suspicious URL encoding detected",
607-
"Invalid characters in hostname"
608-
])
606+
assert any(
607+
phrase in str(error.value)
608+
for phrase in ["Suspicious URL encoding detected", "Invalid characters in hostname"]
609+
)
609610

610611
def test_sanitize_git_url_blocks_invalid_hostname_chars(self):
611612
"""Test that hostnames with invalid characters are blocked."""
@@ -615,16 +616,19 @@ def test_sanitize_git_url_blocks_invalid_hostname_chars(self):
615616
"https://github[].com/repo.git",
616617
"https://github{}.com/repo.git",
617618
]
618-
619+
619620
for url in invalid_urls:
620621
with pytest.raises(ValueError) as error:
621622
git_utils._sanitize_git_url(url)
622623
# The error could be various types due to URL parsing edge cases
623-
assert any(phrase in str(error.value) for phrase in [
624-
"Invalid characters in hostname",
625-
"Failed to parse URL",
626-
"does not appear to be an IPv4 or IPv6 address"
627-
])
624+
assert any(
625+
phrase in str(error.value)
626+
for phrase in [
627+
"Invalid characters in hostname",
628+
"Failed to parse URL",
629+
"does not appear to be an IPv4 or IPv6 address",
630+
]
631+
)
628632

629633
def test_sanitize_git_url_blocks_unsupported_schemes(self):
630634
"""Test that unsupported URL schemes are blocked."""
@@ -634,7 +638,7 @@ def test_sanitize_git_url_blocks_unsupported_schemes(self):
634638
"javascript:alert('xss')",
635639
"data:text/html,<script>alert('xss')</script>",
636640
]
637-
641+
638642
for url in unsupported_urls:
639643
with pytest.raises(ValueError) as error:
640644
git_utils._sanitize_git_url(url)
@@ -644,10 +648,10 @@ def test_git_clone_repo_blocks_malicious_https_url(self):
644648
"""Test that git_clone_repo blocks malicious HTTPS URLs."""
645649
malicious_git_config = {
646650
"repo": "https://[email protected]@github.com/legit/repo.git",
647-
"branch": "main"
651+
"branch": "main",
648652
}
649653
entry_point = "train.py"
650-
654+
651655
with pytest.raises(ValueError) as error:
652656
git_utils.git_clone_repo(malicious_git_config, entry_point)
653657
assert "multiple @ symbols detected" in str(error.value)
@@ -656,10 +660,10 @@ def test_git_clone_repo_blocks_malicious_ssh_url(self):
656660
"""Test that git_clone_repo blocks malicious SSH URLs."""
657661
malicious_git_config = {
658662
"repo": "git@[email protected]:sage-maker/temp-sev2.git",
659-
"branch": "main"
663+
"branch": "main",
660664
}
661665
entry_point = "train.py"
662-
666+
663667
with pytest.raises(ValueError) as error:
664668
git_utils.git_clone_repo(malicious_git_config, entry_point)
665669
assert "exactly one @ symbol" in str(error.value)
@@ -668,10 +672,10 @@ def test_git_clone_repo_blocks_url_encoded_attack(self):
668672
"""Test that git_clone_repo blocks URL-encoded attacks."""
669673
malicious_git_config = {
670674
"repo": "https://github.com%40attacker.com/repo.git",
671-
"branch": "main"
675+
"branch": "main",
672676
}
673677
entry_point = "train.py"
674-
678+
675679
with pytest.raises(ValueError) as error:
676680
git_utils.git_clone_repo(malicious_git_config, entry_point)
677681
assert "Suspicious URL encoding detected" in str(error.value)
@@ -690,17 +694,20 @@ def test_sanitize_git_url_comprehensive_attack_scenarios(self):
690694
"https://github.com%40evil.com/repo.git",
691695
"https://[email protected]%2Fevil.com/repo.git",
692696
]
693-
697+
694698
entry_point = "train.py"
695-
699+
696700
for malicious_url in attack_scenarios:
697701
git_config = {"repo": malicious_url}
698702
with pytest.raises(ValueError) as error:
699703
git_utils.git_clone_repo(git_config, entry_point)
700704
# Should be blocked by sanitization
701-
assert any(phrase in str(error.value) for phrase in [
702-
"multiple @ symbols detected",
703-
"exactly one @ symbol",
704-
"Suspicious URL encoding detected",
705-
"Invalid characters in hostname"
706-
])
705+
assert any(
706+
phrase in str(error.value)
707+
for phrase in [
708+
"multiple @ symbols detected",
709+
"exactly one @ symbol",
710+
"Suspicious URL encoding detected",
711+
"Invalid characters in hostname",
712+
]
713+
)

0 commit comments

Comments
 (0)