Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions src/codeflare_sdk/common/utils/generate_cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
import os
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
from cryptography import x509
from cryptography.x509.oid import NameOID
import ipaddress
import datetime
from ..kubernetes_cluster.auth import (
config_check,
Expand Down Expand Up @@ -151,7 +153,7 @@ def generate_tls_cert(cluster_name, namespace, days=30):
os.makedirs(tls_dir)

# Similar to:
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.key"}}'
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "tls.key"}}'
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.crt"}}'|base64 -d > ${TLSDIR}/ca.crt
config_check()
v1 = client.CoreV1Api(get_api_client())
Expand All @@ -160,11 +162,30 @@ def generate_tls_cert(cluster_name, namespace, days=30):
secret_name = get_secret_name(cluster_name, namespace, v1)
secret = v1.read_namespaced_secret(secret_name, namespace).data

ca_cert = secret.get("ca.crt")
ca_key = secret.get("ca.key")
ca_cert = secret.get("ca.crt") or secret.get("tls.crt")
ca_key = secret.get("tls.key") or secret.get("ca.key")

if not ca_cert:
raise ValueError(
f"CA certificate (ca.crt or tls.crt) not found in secret {secret_name}. "
f"Available keys: {list(secret.keys())}"
)
if not ca_key:
raise ValueError(
f"CA private key (tls.key or ca.key) not found in secret {secret_name}. "
f"Available keys: {list(secret.keys())}"
)

# Decode and write CA certificate
ca_cert_pem = base64.b64decode(ca_cert).decode("utf-8")
with open(os.path.join(tls_dir, "ca.crt"), "w") as f:
f.write(base64.b64decode(ca_cert).decode("utf-8"))
f.write(ca_cert_pem)

# Extract CA subject to use as issuer for client certificate
ca_cert_obj = x509.load_pem_x509_certificate(
ca_cert_pem.encode("utf-8"), default_backend()
)
ca_subject = ca_cert_obj.subject

# Generate tls.key and signed tls.cert locally for ray client
# Similar to running these commands:
Expand All @@ -191,16 +212,22 @@ def generate_tls_cert(cluster_name, namespace, days=30):
with open(os.path.join(tls_dir, "tls.key"), "w") as f:
f.write(tls_key.decode("utf-8"))

head_svc_name = f"{cluster_name}-head-svc"
service_dns = f"{head_svc_name}.{namespace}.svc"
service_dns_cluster_local = f"{head_svc_name}.{namespace}.svc.cluster.local"

san_list = [
x509.DNSName("localhost"),
x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
x509.DNSName(head_svc_name),
x509.DNSName(service_dns),
x509.DNSName(service_dns_cluster_local),
]

one_day = datetime.timedelta(1, 0, 0)
tls_cert = (
x509.CertificateBuilder()
.issuer_name(
x509.Name(
[
x509.NameAttribute(NameOID.COMMON_NAME, "root-ca"),
]
)
)
.issuer_name(ca_subject)
.subject_name(
x509.Name(
[
Expand All @@ -213,9 +240,7 @@ def generate_tls_cert(cluster_name, namespace, days=30):
.not_valid_after(datetime.datetime.today() + (one_day * days))
.serial_number(x509.random_serial_number())
.add_extension(
x509.SubjectAlternativeName(
[x509.DNSName("localhost"), x509.DNSName("127.0.0.1")]
),
x509.SubjectAlternativeName(san_list),
False,
)
.sign(
Expand Down
46 changes: 45 additions & 1 deletion src/codeflare_sdk/common/utils/test_generate_cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_generate_ca_cert():

def secret_ca_retreival(secret_name, namespace):
    """Return a fake CA secret holding ``ca.crt`` and ``tls.key`` entries.

    Presumably installed as the ``side_effect`` for the mocked
    ``read_namespaced_secret`` call in the TLS tests (the patch site is not
    visible from this block -- confirm against the caller).

    Args:
        secret_name: Name of the secret being requested; must be
            ``"ca-secret-cluster"``.
        namespace: Namespace being queried; must be ``"namespace"``.

    Returns:
        A ``V1Secret`` whose ``data`` maps ``"ca.crt"`` to the generated CA
        certificate and ``"tls.key"`` to the generated CA private key.
    """
    ca_private_key_bytes, ca_cert = generate_ca_cert()
    # Guard against the code under test asking for the wrong secret.
    assert secret_name == "ca-secret-cluster"
    assert namespace == "namespace"
    # The private key is stored under "tls.key"; the earlier duplicate
    # assignment that used "ca.key" was immediately overwritten and is dropped.
    data = {"ca.crt": ca_cert, "tls.key": ca_private_key_bytes}
    return client.models.V1Secret(data=data)
Expand Down Expand Up @@ -87,6 +87,50 @@ def test_generate_tls_cert(mocker):
assert tls_cert.verify_directly_issued_by(root_cert) == None


def secret_ca_retreival_with_ca_key(secret_name, namespace):
    """Mock secret retrieval with ca.key instead of tls.key (KubeRay format).

    Args:
        secret_name: Name of the requested secret; must be
            ``"ca-secret-cluster2"``.
        namespace: Namespace being queried; must be ``"namespace2"``.

    Returns:
        A ``V1Secret`` whose ``data`` carries the generated CA certificate
        under ``"ca.crt"`` and the generated CA private key under ``"ca.key"``.
    """
    key_bytes, cert_bytes = generate_ca_cert()
    secret_data = dict([("ca.crt", cert_bytes), ("ca.key", key_bytes)])
    # Fail loudly if the code under test looks up a different secret.
    assert secret_name == "ca-secret-cluster2"
    assert namespace == "namespace2"
    return client.models.V1Secret(data=secret_data)


def test_generate_tls_cert_with_ca_key_fallback(mocker):
    """
    Test that generate_tls_cert works when secret contains ca.key instead of tls.key
    This tests the fallback logic for KubeRay-created secrets
    """
    tls_dir = "tls-cluster2-namespace2"
    ca_crt_path = os.path.join(tls_dir, "ca.crt")
    tls_crt_path = os.path.join(tls_dir, "tls.crt")
    tls_key_path = os.path.join(tls_dir, "tls.key")

    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
    mocker.patch(
        "codeflare_sdk.common.utils.generate_cert.get_secret_name",
        return_value="ca-secret-cluster2",
    )
    mocker.patch(
        "kubernetes.client.CoreV1Api.read_namespaced_secret",
        side_effect=secret_ca_retreival_with_ca_key,
    )

    try:
        generate_tls_cert("cluster2", "namespace2")
        assert os.path.exists(tls_dir)
        assert os.path.exists(ca_crt_path)
        assert os.path.exists(tls_crt_path)
        assert os.path.exists(tls_key_path)

        # Verify that the signed tls.crt is issued by the ca_cert (root cert).
        with open(tls_crt_path, "r") as f:
            tls_cert = load_pem_x509_certificate(f.read().encode("utf-8"))
        with open(ca_crt_path, "r") as f:
            root_cert = load_pem_x509_certificate(f.read().encode("utf-8"))
        # verify_directly_issued_by raises on a mismatch and returns None on
        # success, so compare by identity rather than equality.
        assert tls_cert.verify_directly_issued_by(root_cert) is None
    finally:
        # Cleanup for this test: run it even when an assertion above fails so
        # stale certificate files never leak into other tests.
        for path in (ca_crt_path, tls_crt_path, tls_key_path):
            if os.path.exists(path):
                os.remove(path)
        if os.path.isdir(tls_dir):
            os.rmdir(tls_dir)


def test_export_env():
"""
test the function codeflare_sdk.common.utils.generate_cert.export_env generates the correct outputs
Expand Down
Loading