diff --git a/src/DIRAC/Core/Security/DiracX.py b/src/DIRAC/Core/Security/DiracX.py index 8e866dd28d7..521188cae93 100644 --- a/src/DIRAC/Core/Security/DiracX.py +++ b/src/DIRAC/Core/Security/DiracX.py @@ -40,7 +40,7 @@ PEM_BEGIN = "-----BEGIN DIRACX-----" PEM_END = "-----END DIRACX-----" -RE_DIRACX_PEM = re.compile(rf"{PEM_BEGIN}\n(.*)\n{PEM_END}", re.MULTILINE | re.DOTALL) +RE_DIRACX_PEM = re.compile(rf"{PEM_BEGIN}\n(.*?)\n{PEM_END}", re.DOTALL) @convertToReturnValue @@ -62,21 +62,26 @@ def addTokenToPEM(pemPath, group): token_type=token_content.get("token_type"), refresh_token=token_content.get("refresh_token"), ) - token_pem = f"{PEM_BEGIN}\n" data = base64.b64encode(serialize_credentials(token).encode("utf-8")).decode() token_pem += textwrap.fill(data, width=64) token_pem += f"\n{PEM_END}\n" - with open(pemPath, "a") as f: - f.write(token_pem) + pem = Path(pemPath).read_text() + # Remove any existing DiracX token there would be + new_pem = re.sub(RE_DIRACX_PEM, "", pem) + new_pem += token_pem + + Path(pemPath).write_text(new_pem) def diracxTokenFromPEM(pemPath) -> dict[str, Any] | None: """Extract the DiracX token from the proxy PEM file""" pem = Path(pemPath).read_text() - if match := RE_DIRACX_PEM.search(pem): - match = match.group(1) + if match := RE_DIRACX_PEM.findall(pem): + if len(match) > 1: + raise ValueError("Found multiple DiracX tokens, this should never happen") + match = match[0] return json.loads(base64.b64decode(match).decode("utf-8")) diff --git a/src/DIRAC/Core/Security/test/test_diracx_token_from_pem.py b/src/DIRAC/Core/Security/test/test_diracx_token_from_pem.py new file mode 100644 index 00000000000..5f57768a55c --- /dev/null +++ b/src/DIRAC/Core/Security/test/test_diracx_token_from_pem.py @@ -0,0 +1,161 @@ +import base64 +import json +import pytest +import tempfile +from pathlib import Path +from unittest.mock import patch, mock_open + +from DIRAC.Core.Security.DiracX import diracxTokenFromPEM, PEM_BEGIN, PEM_END, RE_DIRACX_PEM + + +class TestDiracxTokenFromPEM: + """Test cases for diracxTokenFromPEM function""" + + def create_valid_token_data(self): + """Create valid token data for testing""" + return { + "access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.test", + "refresh_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.refresh", + "expires_in": 3600, + "token_type": "Bearer", + } + + def create_pem_content(self, token_data=None, include_other_content=True): + """Create PEM content with embedded DiracX token""" + if token_data is None: + token_data = self.create_valid_token_data() + + # Encode token data + token_json = json.dumps(token_data) + encoded_token = base64.b64encode(token_json.encode("utf-8")).decode() + + # Create PEM content + pem_content = "" + if include_other_content: + pem_content += "-----BEGIN CERTIFICATE-----\n" + pem_content += "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA...\n" + pem_content += "-----END CERTIFICATE-----\n" + + pem_content += f"{PEM_BEGIN}\n" + pem_content += encoded_token + "\n" + pem_content += f"{PEM_END}\n" + + return pem_content + + def test_valid_token_extraction(self): + """Test successful extraction of valid token from PEM file""" + token_data = self.create_valid_token_data() + pem_content = self.create_pem_content(token_data) + + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".pem") as f: + f.write(pem_content) + temp_path = f.name + + try: + result = diracxTokenFromPEM(temp_path) + assert result == token_data + finally: + Path(temp_path).unlink() + + def test_no_token_in_pem(self): + """Test behavior when no DiracX token is present in PEM file""" + pem_content = """-----BEGIN CERTIFICATE----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA... +-----END CERTIFICATE-----""" + + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".pem") as f: + f.write(pem_content) + temp_path = f.name + + try: + result = diracxTokenFromPEM(temp_path) + assert result is None + finally: + Path(temp_path).unlink() + + def test_multiple_tokens_error(self): + """Test that multiple tokens raise ValueError""" + token_data = self.create_valid_token_data() + + # Create PEM with two tokens + pem_content = self.create_pem_content(token_data) + pem_content += "\n" + self.create_pem_content(token_data, include_other_content=False) + + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".pem") as f: + f.write(pem_content) + temp_path = f.name + + try: + with pytest.raises(ValueError, match="Found multiple DiracX tokens"): + diracxTokenFromPEM(temp_path) + finally: + Path(temp_path).unlink() + + def test_malformed_base64(self): + """Test behavior with malformed base64 data""" + pem_content = f"""{PEM_BEGIN} +invalid_base64_data_that_will_cause_error! +{PEM_END}""" + + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".pem") as f: + f.write(pem_content) + temp_path = f.name + + try: + with pytest.raises(Exception): # base64.b64decode will raise an exception + diracxTokenFromPEM(temp_path) + finally: + Path(temp_path).unlink() + + def test_invalid_json_in_token(self): + """Test behavior with invalid JSON in token data""" + invalid_json = "this is not valid json" + encoded_invalid = base64.b64encode(invalid_json.encode("utf-8")).decode() + + pem_content = f"""{PEM_BEGIN} +{encoded_invalid} +{PEM_END}""" + + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".pem") as f: + f.write(pem_content) + temp_path = f.name + + try: + with pytest.raises(json.JSONDecodeError): + diracxTokenFromPEM(temp_path) + finally: + Path(temp_path).unlink() + + def test_token_with_unicode_characters(self): + """Test token with unicode characters""" + unicode_token = { + "access_token": "token_with_unicode_ñ_é_ü", + "refresh_token": "refresh_with_emoji_🚀_🎉", + "expires_in": 3600, + "token_type": "Bearer", + } + + pem_content = self.create_pem_content(unicode_token) + + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".pem") as f: + f.write(pem_content) + temp_path = f.name + + try: + result = diracxTokenFromPEM(temp_path) + assert result == unicode_token + finally: + Path(temp_path).unlink() + + def test_regex_pattern_validation(self): + """Test that the regex pattern correctly identifies DiracX tokens""" + # Test that the regex matches the expected pattern + token_data = self.create_valid_token_data() + token_json = json.dumps(token_data) + encoded_token = base64.b64encode(token_json.encode("utf-8")).decode() + + test_content = f"{PEM_BEGIN}\n{encoded_token}\n{PEM_END}" + matches = RE_DIRACX_PEM.findall(test_content) + + assert len(matches) == 1 + assert matches[0] == encoded_token diff --git a/src/DIRAC/Core/Utilities/test/Test_Profiler.py b/src/DIRAC/Core/Utilities/test/Test_Profiler.py index aea819e70e4..601025c7479 100644 --- a/src/DIRAC/Core/Utilities/test/Test_Profiler.py +++ b/src/DIRAC/Core/Utilities/test/Test_Profiler.py @@ -29,50 +29,50 @@ def test_base(): time.sleep(1) p = Profiler(mainProcess.pid) res = p.pid() - assert res["OK"] is True + assert res["OK"] is True, res res = p.status() - assert res["OK"] is True + assert res["OK"] is True, res res = p.runningTime() - assert res["OK"] is True + assert res["OK"] is True, res assert res["Value"] > 0 res = p.memoryUsage() - assert res["OK"] is True + assert res["OK"] is True, res assert res["Value"] > 0 resWC = p.memoryUsage(withChildren=True) - assert resWC["OK"] is True + assert resWC["OK"] is True, res assert resWC["Value"] > 0 assert resWC["Value"] >= res["Value"] res = p.vSizeUsage() - assert res["OK"] is True + assert res["OK"] is True, res assert res["Value"] > 0 resWC = p.vSizeUsage(withChildren=True) - assert resWC["OK"] is True + assert resWC["OK"] is True, res assert resWC["Value"] > 0 assert resWC["Value"] >= res["Value"] res = p.vSizeUsage() - assert res["OK"] is True + assert res["OK"] is True, res assert res["Value"] > 0 resWC = p.vSizeUsage(withChildren=True) - assert resWC["OK"] is True + assert resWC["OK"] is True, res assert resWC["Value"] > 0 assert resWC["Value"] >= res["Value"] res = p.numThreads() - assert res["OK"] is True + assert res["OK"] is True, res assert res["Value"] > 0 resWC = p.numThreads(withChildren=True) - assert resWC["OK"] is True + assert resWC["OK"] is True, res assert resWC["Value"] > 0 assert resWC["Value"] >= res["Value"] res = p.cpuPercentage() - assert res["OK"] is True + assert res["OK"] is True, res assert res["Value"] >= 0 resWC = p.cpuPercentage(withChildren=True) - assert resWC["OK"] is True + assert resWC["OK"] is True, res assert resWC["Value"] >= 0 assert resWC["Value"] >= res["Value"] @@ -88,13 +88,13 @@ def test_cpuUsage(): time.sleep(2) p = Profiler(mainProcess.pid) res = p.pid() - assert res["OK"] is True + assert res["OK"] is True, res res = p.status() - assert res["OK"] is True + assert res["OK"] is True, res # user res = p.cpuUsageUser() - assert res["OK"] is True + assert res["OK"] is True, res assert res["Value"] > 0 resC = p.cpuUsageUser(withChildren=True) assert resC["OK"] is True @@ -102,7 +102,7 @@ def test_cpuUsage(): assert resC["Value"] >= res["Value"] res = p.cpuUsageUser() - assert res["OK"] is True + assert res["OK"] is True, res assert res["Value"] > 0 resC = p.cpuUsageUser(withChildren=True) assert resC["OK"] is True @@ -121,15 +121,15 @@ def test_cpuUsage(): # system res = p.cpuUsageSystem() - assert res["OK"] is True + assert res["OK"] is True, res assert res["Value"] >= 0 resWC = p.cpuUsageSystem(withChildren=True) - assert resWC["OK"] is True + assert resWC["OK"] is True, res assert resWC["Value"] >= 0 assert resWC["Value"] >= res["Value"] res = p.cpuUsageSystem() - assert res["OK"] is True + assert res["OK"] is True, res assert res["Value"] > 0 resC = p.cpuUsageSystem(withChildren=True) assert resC["OK"] is True