Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions config/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ class AWSConfig(BaseConfig):
# Database
database: DatabaseConfig = Field(default_factory=DatabaseConfig)

# Custom AMI
custom_ami_id: str | None = None

# Custom tags from user
custom_tags: dict[str, str] = Field(default_factory=dict)

Expand Down
6 changes: 6 additions & 0 deletions pulumi_pinecone_byoc/aws/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ class PineconeAWSClusterArgs:
# gcp_project is needed by some helmfiles even for AWS clusters (cross-cloud monitoring/metrics)
gcp_project: str = "production-pinecone"

# custom AMI
custom_ami_id: str | None = None

# tags
tags: dict[str, str] | None = None

Expand Down Expand Up @@ -375,6 +378,8 @@ def __init__(
"aws_amp_remote_write_url": self._amp_access.amp_remote_write_endpoint,
"aws_amp_sigv4_role_arn": self._amp_access.pinecone_role_arn,
"aws_amp_ingest_role_arn": self._k8s_addons.amp_ingest_role.arn,
"base64_encoded_user_data": self._eks.base64_encoded_user_data,
"custom_ami_id": args.custom_ami_id,
}

self._k8s_configmaps = K8sConfigMaps(
Expand Down Expand Up @@ -509,6 +514,7 @@ def _build_config(self, args: PineconeAWSClusterArgs):
node_pools=node_pools,
parent_zone_name=args.parent_dns_zone_name,
database=DatabaseConfig(deletion_protection=args.deletion_protection),
custom_ami_id=args.custom_ami_id,
custom_tags=args.tags or {},
)

Expand Down
43 changes: 42 additions & 1 deletion pulumi_pinecone_byoc/aws/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Creates a managed EKS cluster with configurable node groups.
"""

import base64
import json

import pulumi
Expand Down Expand Up @@ -70,6 +71,8 @@ def __init__(
opts=child_opts,
)

self._base64_encoded_user_data = self._create_user_data() if config.custom_ami_id else None

self.node_groups: list[aws.eks.NodeGroup] = []
for np_config in config.node_pools:
node_group = self._create_node_group(name, np_config, vpc, self._node_role, child_opts)
Expand Down Expand Up @@ -166,6 +169,38 @@ def _create_node_role(self, name: str, opts: pulumi.ResourceOptions) -> aws.iam.

return role

def _create_user_data(self) -> pulumi.Output[str]:
"""Create base64-encoded user data for custom AMI bootstrap via nodeadm.

See https://awslabs.github.io/amazon-eks-ami/nodeadm/
"""
return pulumi.Output.all(
self.cluster.eks_cluster.name,
self.cluster.eks_cluster.endpoint,
self.cluster.eks_cluster.certificate_authorities[0].data,
self.cluster.eks_cluster.kubernetes_network_configs[0].service_ipv4_cidr,
).apply(
lambda args: base64.b64encode(
f"""MIME-Version: 1.0
Content-Type: multipart/mixed; boundary="==BOUNDARY=="

--==BOUNDARY==
Content-Type: application/node.eks.aws

---
apiVersion: node.eks.aws/v1alpha1
kind: NodeConfig
spec:
cluster:
name: {args[0]}
apiServerEndpoint: {args[1]}
certificateAuthority: {args[2]}
cidr: {args[3]}
--==BOUNDARY==--
""".encode()
).decode("utf-8")
)

def _create_launch_template(
self,
name: str,
Expand Down Expand Up @@ -195,6 +230,8 @@ def _create_launch_template(
http_put_response_hop_limit=2,
http_tokens="optional",
),
image_id=self.config.custom_ami_id,
user_data=self._base64_encoded_user_data,
tag_specifications=[
aws.ec2.LaunchTemplateTagSpecificationArgs(
resource_type="instance",
Expand Down Expand Up @@ -237,7 +274,7 @@ def _create_node_group(
),
node_role_arn=node_role.arn,
subnet_ids=vpc.private_subnet_ids,
ami_type="AL2023_x86_64_STANDARD",
ami_type="AL2023_x86_64_STANDARD" if self.config.custom_ami_id is None else "CUSTOM",
instance_types=[np_config.instance_type],
# disk_size is configured in launch template
scaling_config=aws.eks.NodeGroupScalingConfigArgs(
Expand Down Expand Up @@ -290,3 +327,7 @@ def node_role_name(self) -> pulumi.Output[str]:
@property
def cluster_security_group_id(self) -> pulumi.Output[str]:
return self.cluster.eks_cluster.vpc_config.cluster_security_group_id

@property
def base64_encoded_user_data(self) -> pulumi.Output[str] | None:
return self._base64_encoded_user_data
21 changes: 21 additions & 0 deletions setup/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ def _check_cidr_conflicts(self):


class AWSSetupWizard(BaseSetupWizard):
TOTAL_STEPS = 14
HEADER_TITLE = "Pinecone BYOC Setup Wizard"
HEADER_SUBTITLE = "This wizard will set up everything you need to deploy Pinecone BYOC."
DEFAULT_CIDR = "10.0.0.0/16"
Expand All @@ -786,6 +787,7 @@ def run(self, output_dir: str = ".") -> bool:

region = self._get_region()
azs = self._get_azs(region)
custom_ami_id = self._get_custom_ami_id()
cidr = self._get_cidr()
deletion_protection = self._get_deletion_protection()
public_access = self._get_public_access()
Expand All @@ -809,6 +811,7 @@ def run(self, output_dir: str = ".") -> bool:
deletion_protection,
public_access,
tags,
custom_ami_id=custom_ami_id,
)

def _run_headless(self, output_dir: str) -> bool:
Expand All @@ -828,6 +831,7 @@ def _run_headless(self, output_dir: str) -> bool:
)
public_access = os.environ.get("PINECONE_PUBLIC_ACCESS", "true").lower() == "true"
project_name = os.environ.get("PINECONE_PROJECT_NAME", "pinecone-byoc")
custom_ami_id = os.environ.get("PINECONE_CUSTOM_AMI_ID", "") or None

return self._generate_project(
output_dir,
Expand All @@ -839,6 +843,7 @@ def _run_headless(self, output_dir: str) -> bool:
deletion_protection,
public_access,
{},
custom_ami_id=custom_ami_id,
)

def _validate_aws_creds(self) -> bool:
Expand Down Expand Up @@ -900,6 +905,16 @@ def _get_azs(self, region: str) -> list[str]:
azs = [az.strip() for az in azs_input.split(",")]
return azs

def _get_custom_ami_id(self) -> str | None:
console.print()
console.print(f" {self._step('Custom AMI (Optional)')}")
console.print(
" [dim]Specify a custom AMI ID for EKS nodes (leave blank for default AWS AMI)[/]"
)
console.print()
ami_id = self._prompt("Enter AMI ID (or press Enter to skip)", "")
return ami_id or None

def _run_preflight_checks(self, region: str, azs: list[str], cidr: str) -> bool:
console.print()
console.print(f" {self._step('Preflight Checks')}")
Expand All @@ -926,6 +941,7 @@ def _generate_project(
deletion_protection: bool,
public_access: bool,
tags: dict[str, str],
custom_ami_id: str | None = None,
):
console.print()

Expand Down Expand Up @@ -970,6 +986,7 @@ def _generate_project(
availability_zones=config.require_object("availability-zones"),
deletion_protection=config.get_bool("deletion-protection") if config.get_bool("deletion-protection") is not None else True,
public_access_enabled=config.get_bool("public-access-enabled") if config.get_bool("public-access-enabled") is not None else True,
custom_ami_id=config.get("custom-ami-id"),
tags=config.get_object("tags"),
),
)
Expand Down Expand Up @@ -1014,6 +1031,10 @@ def _generate_project(
for az in azs:
config_content += f" - {az}\n"

# add custom AMI ID if provided
if custom_ami_id:
config_content += f" {project_name}:custom-ami-id: {custom_ami_id}\n"

# add tags if provided (quote values to handle YAML special chars)
if tags:
config_content += f" {project_name}:tags:\n"
Expand Down
Loading