diff --git a/src/codeflare_sdk/common/utils/generate_cert.py b/src/codeflare_sdk/common/utils/generate_cert.py index 7c072da0..8e7c7af6 100644 --- a/src/codeflare_sdk/common/utils/generate_cert.py +++ b/src/codeflare_sdk/common/utils/generate_cert.py @@ -16,8 +16,10 @@ import os from cryptography.hazmat.primitives import serialization, hashes from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.backends import default_backend from cryptography import x509 from cryptography.x509.oid import NameOID +import ipaddress import datetime from ..kubernetes_cluster.auth import ( config_check, @@ -151,7 +153,7 @@ def generate_tls_cert(cluster_name, namespace, days=30): os.makedirs(tls_dir) # Similar to: - # oc get secret ca-secret- -o template='{{index .data "ca.key"}}' + # oc get secret ca-secret- -o template='{{index .data "tls.key"}}' # oc get secret ca-secret- -o template='{{index .data "ca.crt"}}'|base64 -d > ${TLSDIR}/ca.crt config_check() v1 = client.CoreV1Api(get_api_client()) @@ -160,11 +162,30 @@ def generate_tls_cert(cluster_name, namespace, days=30): secret_name = get_secret_name(cluster_name, namespace, v1) secret = v1.read_namespaced_secret(secret_name, namespace).data - ca_cert = secret.get("ca.crt") - ca_key = secret.get("ca.key") + ca_cert = secret.get("ca.crt") or secret.get("tls.crt") + ca_key = secret.get("tls.key") or secret.get("ca.key") + if not ca_cert: + raise ValueError( + f"CA certificate (ca.crt or tls.crt) not found in secret {secret_name}. " + f"Available keys: {list(secret.keys())}" + ) + if not ca_key: + raise ValueError( + f"CA private key (tls.key or ca.key) not found in secret {secret_name}. " + f"Available keys: {list(secret.keys())}" + ) + + # Decode and write CA certificate + ca_cert_pem = base64.b64decode(ca_cert).decode("utf-8") with open(os.path.join(tls_dir, "ca.crt"), "w") as f: - f.write(base64.b64decode(ca_cert).decode("utf-8")) + f.write(ca_cert_pem) + + # Extract CA subject to use as issuer for client certificate + ca_cert_obj = x509.load_pem_x509_certificate( + ca_cert_pem.encode("utf-8"), default_backend() + ) + ca_subject = ca_cert_obj.subject # Generate tls.key and signed tls.cert locally for ray client # Similar to running these commands: @@ -191,16 +212,22 @@ def generate_tls_cert(cluster_name, namespace, days=30): with open(os.path.join(tls_dir, "tls.key"), "w") as f: f.write(tls_key.decode("utf-8")) + head_svc_name = f"{cluster_name}-head-svc" + service_dns = f"{head_svc_name}.{namespace}.svc" + service_dns_cluster_local = f"{head_svc_name}.{namespace}.svc.cluster.local" + + san_list = [ + x509.DNSName("localhost"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + x509.DNSName(head_svc_name), + x509.DNSName(service_dns), + x509.DNSName(service_dns_cluster_local), + ] + one_day = datetime.timedelta(1, 0, 0) tls_cert = ( x509.CertificateBuilder() - .issuer_name( - x509.Name( - [ - x509.NameAttribute(NameOID.COMMON_NAME, "root-ca"), - ] - ) - ) + .issuer_name(ca_subject) .subject_name( x509.Name( [ @@ -213,9 +240,7 @@ def generate_tls_cert(cluster_name, namespace, days=30): .not_valid_after(datetime.datetime.today() + (one_day * days)) .serial_number(x509.random_serial_number()) .add_extension( - x509.SubjectAlternativeName( - [x509.DNSName("localhost"), x509.DNSName("127.0.0.1")] - ), + x509.SubjectAlternativeName(san_list), False, ) .sign( diff --git a/src/codeflare_sdk/common/utils/test_generate_cert.py b/src/codeflare_sdk/common/utils/test_generate_cert.py index b4439c20..e821c48b 100644 --- a/src/codeflare_sdk/common/utils/test_generate_cert.py +++ b/src/codeflare_sdk/common/utils/test_generate_cert.py @@ -53,7 +53,7 @@ def test_generate_ca_cert(): def secret_ca_retreival(secret_name, namespace): ca_private_key_bytes, ca_cert = generate_ca_cert() - data = {"ca.crt": ca_cert, "ca.key": ca_private_key_bytes} + data = {"ca.crt": ca_cert, "tls.key": ca_private_key_bytes} assert secret_name == "ca-secret-cluster" assert namespace == "namespace" return client.models.V1Secret(data=data) @@ -87,6 +87,50 @@ def test_generate_tls_cert(mocker): assert tls_cert.verify_directly_issued_by(root_cert) == None +def secret_ca_retreival_with_ca_key(secret_name, namespace): + """Mock secret retrieval with ca.key instead of tls.key (KubeRay format)""" + ca_private_key_bytes, ca_cert = generate_ca_cert() + data = {"ca.crt": ca_cert, "ca.key": ca_private_key_bytes} + assert secret_name == "ca-secret-cluster2" + assert namespace == "namespace2" + return client.models.V1Secret(data=data) + + +def test_generate_tls_cert_with_ca_key_fallback(mocker): + """ + Test that generate_tls_cert works when secret contains ca.key instead of tls.key + This tests the fallback logic for KubeRay-created secrets + """ + mocker.patch("kubernetes.config.load_kube_config", return_value="ignore") + mocker.patch( + "codeflare_sdk.common.utils.generate_cert.get_secret_name", + return_value="ca-secret-cluster2", + ) + mocker.patch( + "kubernetes.client.CoreV1Api.read_namespaced_secret", + side_effect=secret_ca_retreival_with_ca_key, + ) + + generate_tls_cert("cluster2", "namespace2") + assert os.path.exists("tls-cluster2-namespace2") + assert os.path.exists(os.path.join("tls-cluster2-namespace2", "ca.crt")) + assert os.path.exists(os.path.join("tls-cluster2-namespace2", "tls.crt")) + assert os.path.exists(os.path.join("tls-cluster2-namespace2", "tls.key")) + + # verify the that the signed tls.crt is issued by the ca_cert (root cert) + with open(os.path.join("tls-cluster2-namespace2", "tls.crt"), "r") as f: + tls_cert = load_pem_x509_certificate(f.read().encode("utf-8")) + with open(os.path.join("tls-cluster2-namespace2", "ca.crt"), "r") as f: + root_cert = load_pem_x509_certificate(f.read().encode("utf-8")) + assert tls_cert.verify_directly_issued_by(root_cert) == None + + # Cleanup for this test + os.remove("tls-cluster2-namespace2/ca.crt") + os.remove("tls-cluster2-namespace2/tls.crt") + os.remove("tls-cluster2-namespace2/tls.key") + os.rmdir("tls-cluster2-namespace2") + + def test_export_env(): """ test the function codeflare_sdk.common.utils.generate_ca_cert.export_ev generates the correct outputs