Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions api/src/app/schemas/host_schemas.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from ipaddress import IPv4Address
from typing import Self

from pydantic import (
BaseModel,
ConfigDict,
Field,
ValidationInfo,
field_validator,
model_validator,
)

from src.app.enums.operating_systems import OpenLabsOS
Expand Down Expand Up @@ -82,20 +83,14 @@ def validate_hostname(cls, hostname: str) -> str:
raise ValueError(msg)
return hostname

@field_validator("size")
@classmethod
def validate_size(cls, size: int, info: ValidationInfo) -> int:
@model_validator(mode="after")
def validate_size(self) -> Self:
"""Check VM disk size is sufficient."""
os: OpenLabsOS | None = info.data.get("os")

if os is None:
msg = "OS field not set to OpenLabsOS type."
if not is_valid_disk_size(self.os, self.size):
msg = f"Disk size {self.size}GB too small for OS: {self.os.value}. Minimum disk size: {OS_SIZE_THRESHOLD[self.os]}GB."
raise ValueError(msg)

if not is_valid_disk_size(os, size):
msg = f"Disk size {size}GB too small for OS: {os.value}. Minimum disk size: {OS_SIZE_THRESHOLD[os]}GB."
raise ValueError(msg)
return size
return self


# ==================== Blueprints =====================
Expand Down
102 changes: 40 additions & 62 deletions api/src/app/schemas/range_schemas.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from datetime import datetime, timezone
from ipaddress import IPv4Address
from typing import Any
from typing import Any, Self, Sequence

from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
model_validator,
)

from ..enums.providers import OpenLabsProvider
from ..enums.range_states import RangeState
Expand All @@ -14,6 +19,7 @@
BlueprintVPCSchema,
DeployedVPCCreateSchema,
DeployedVPCSchema,
VPCCommonSchema,
)


Expand All @@ -32,6 +38,36 @@ class RangeCommonSchema(BaseModel):
vpn: bool = Field(default=False, description="Automatic VPN configuration.")


class RangeCreateValidationMixin(BaseModel):
"""Mixin class with common validation for all range creation schemas."""

vpcs: Sequence[VPCCommonSchema]

@model_validator(mode="after")
def validate_unique_vpc_names(self) -> Self:
"""Check VPC names are unique."""
if not self.vpcs:
return self

vpc_names = [vpc.name for vpc in self.vpcs]
if len(vpc_names) != len(set(vpc_names)):
msg = "All VPCs in the range must have unique names."
raise ValueError(msg)
return self

@model_validator(mode="after")
def validate_mutually_exclusive_vpcs(self) -> Self:
"""Check that VPCs do not overlap."""
if not self.vpcs:
return self

vpc_cidrs = [vpc.cidr for vpc in self.vpcs]
if not mutually_exclusive_networks_v4(vpc_cidrs):
msg = "All VPCs in the range must be mutually exclusive (not overlap)."
Copy link

Copilot AI Jul 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The error message differs from the original 'All VPCs in range should be mutually exclusive (not overlap).' Consider maintaining consistency with the original wording unless the change is intentional.

Suggested change
msg = "All VPCs in the range must be mutually exclusive (not overlap)."
msg = "All VPCs in range should be mutually exclusive (not overlap)."

Copilot uses AI. Check for mistakes.
raise ValueError(msg)
return self


# ==================== Blueprints =====================


Expand All @@ -48,42 +84,13 @@ class BlueprintRangeBaseSchema(RangeCommonSchema):
pass


class BlueprintRangeCreateSchema(BlueprintRangeBaseSchema):
class BlueprintRangeCreateSchema(BlueprintRangeBaseSchema, RangeCreateValidationMixin):
"""Schema to create blueprint range objects."""

vpcs: list[BlueprintVPCCreateSchema] = Field(
..., description="All blueprint VPCs in range."
)

@field_validator("vpcs")
@classmethod
def validate_unique_vpc_names(
cls, vpcs: list[BlueprintVPCCreateSchema], info: ValidationInfo
) -> list[BlueprintVPCCreateSchema]:
"""Check VPC names are unique."""
vpc_names = [vpc.name for vpc in vpcs]

if len(vpc_names) != len(set(vpc_names)):
msg = "All VPCs in the range must have unique names."
raise ValueError(msg)

return vpcs

@field_validator("vpcs")
@classmethod
def validate_mutually_exclusive_vpcs(
cls, vpcs: list[BlueprintVPCCreateSchema], info: ValidationInfo
) -> list[BlueprintVPCCreateSchema]:
"""Check that VPCs do not overlap."""
vpc_cidrs = [vpc.cidr for vpc in vpcs]

if not mutually_exclusive_networks_v4(vpc_cidrs):

msg = "All VPCs in range should be mutually exclusive (not overlap)."
raise ValueError(msg)

return vpcs


class BlueprintRangeSchema(BlueprintRangeBaseSchema):
"""Blueprint range object."""
Expand Down Expand Up @@ -152,42 +159,13 @@ class DeployedRangeBaseSchema(RangeCommonSchema):
)


class DeployedRangeCreateSchema(DeployedRangeBaseSchema):
class DeployedRangeCreateSchema(DeployedRangeBaseSchema, RangeCreateValidationMixin):
"""Schema to create deployed range object."""

vpcs: list[DeployedVPCCreateSchema] = Field(
..., description="Deployed VPCs in the range."
)

@field_validator("vpcs")
@classmethod
def validate_unique_vpc_names(
cls, vpcs: list[DeployedVPCCreateSchema], info: ValidationInfo
) -> list[DeployedVPCCreateSchema]:
"""Check VPC names are unique."""
vpc_names = [vpc.name for vpc in vpcs]

if len(vpc_names) != len(set(vpc_names)):
msg = "All VPCs in the range must have unique names."
raise ValueError(msg)

return vpcs

@field_validator("vpcs")
@classmethod
def validate_mutually_exclusive_vpcs(
cls, vpcs: list[DeployedVPCCreateSchema], info: ValidationInfo
) -> list[DeployedVPCCreateSchema]:
"""Check that VPCs do not overlap."""
vpc_cidrs = [vpc.cidr for vpc in vpcs]

if not mutually_exclusive_networks_v4(vpc_cidrs):

msg = "All VPCs in range should be mutually exclusive (not overlap)."
raise ValueError(msg)

return vpcs


class DeployedRangeSchema(DeployedRangeBaseSchema):
"""Deployed range object."""
Expand Down
130 changes: 49 additions & 81 deletions api/src/app/schemas/subnet_schemas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from ipaddress import IPv4Network

from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from typing import Self, Sequence

from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
)

from ..validators.names import OPENLABS_NAME_REGEX
from ..validators.network import max_num_hosts_in_subnet
Expand All @@ -9,6 +16,7 @@
BlueprintHostSchema,
DeployedHostCreateSchema,
DeployedHostSchema,
HostCommonSchema,
)


Expand All @@ -27,21 +35,13 @@ class SubnetCommonSchema(BaseModel):
)


# ==================== Blueprints =====================


class BlueprintSubnetBaseSchema(SubnetCommonSchema):
"""Base pydantic class for all blueprint subnet objects."""
class SubnetCreateValidationMixin(BaseModel):
"""Mixin class with common validation for all subnet creation schemas."""

pass


class BlueprintSubnetCreateSchema(BlueprintSubnetBaseSchema):
"""Schema to create blueprint subnet objects."""

hosts: list[BlueprintHostCreateSchema] = Field(
..., description="All blueprint hosts in the subnet."
)
# Forward references
name: str
cidr: IPv4Network
hosts: Sequence[HostCommonSchema]

@field_validator("cidr")
@classmethod
Expand All @@ -52,38 +52,48 @@ def validate_subnet_private_cidr_range(cls, cidr: IPv4Network) -> IPv4Network:
raise ValueError(msg)
return cidr

@field_validator("hosts")
@classmethod
def validate_unique_hostnames(
cls, hosts: list[BlueprintHostCreateSchema]
) -> list[BlueprintHostCreateSchema]:
@model_validator(mode="after")
def validate_unique_hostnames(self) -> Self:
"""Check hostnames are unique."""
hostnames = [host.hostname for host in hosts]
if not self.hosts:
return self

hostnames = [host.hostname for host in self.hosts]
if len(hostnames) != len(set(hostnames)):
msg = "All hostnames must be unique."
msg = f"All hostnames in subnet: {self.name} must be unique."
raise ValueError(msg)
return hosts

@field_validator("hosts")
@classmethod
def validate_max_number_hosts(
cls, hosts: list[BlueprintHostCreateSchema], info: ValidationInfo
) -> list[BlueprintHostCreateSchema]:
return self

@model_validator(mode="after")
def validate_max_number_hosts(self) -> Self:
"""Check that the number of hosts does not exceed subnet CIDR."""
subnet_cidr = info.data.get("cidr")
max_num_hosts = max_num_hosts_in_subnet(self.cidr)

if not subnet_cidr:
msg = "Subnet missing CIDR."
if len(self.hosts) > max_num_hosts:
msg = f"Too many hosts in subnet: {self.name}! Max: {max_num_hosts}, Requested: {len(self.hosts)}"
raise ValueError(msg)

max_num_hosts = max_num_hosts_in_subnet(subnet_cidr)
num_requested_hosts = len(hosts)
return self

if num_requested_hosts > max_num_hosts:
msg = f"Too many hosts in subnet! Max: {max_num_hosts}, Requested: {num_requested_hosts}"
raise ValueError(msg)

return hosts
# ==================== Blueprints =====================


class BlueprintSubnetBaseSchema(SubnetCommonSchema):
"""Base pydantic class for all blueprint subnet objects."""

pass


class BlueprintSubnetCreateSchema(
BlueprintSubnetBaseSchema, SubnetCreateValidationMixin
):
"""Schema to create blueprint subnet objects."""

hosts: list[BlueprintHostCreateSchema] = Field(
..., description="All blueprint hosts in the subnet."
)


class BlueprintSubnetSchema(BlueprintSubnetBaseSchema):
Expand Down Expand Up @@ -119,55 +129,13 @@ class DeployedSubnetBaseSchema(SubnetCommonSchema):
)


class DeployedSubnetCreateSchema(DeployedSubnetBaseSchema):
class DeployedSubnetCreateSchema(DeployedSubnetBaseSchema, SubnetCreateValidationMixin):
"""Schema to create deployed subnet objects."""

hosts: list[DeployedHostCreateSchema] = Field(
..., description="Deployed hosts within subnet."
)

@field_validator("cidr")
@classmethod
def validate_subnet_private_cidr_range(cls, cidr: IPv4Network) -> IPv4Network:
"""Check subnet CIDR ranges are private."""
if not cidr.is_private:
msg = "Subnets should only use private CIDR ranges."
raise ValueError(msg)
return cidr

@field_validator("hosts")
@classmethod
def validate_unique_hostnames(
cls, hosts: list[DeployedHostCreateSchema]
) -> list[DeployedHostCreateSchema]:
"""Check hostnames are unique."""
hostnames = [host.hostname for host in hosts]
if len(hostnames) != len(set(hostnames)):
msg = "All hostnames must be unique."
raise ValueError(msg)
return hosts

@field_validator("hosts")
@classmethod
def validate_max_number_hosts(
cls, hosts: list[DeployedHostCreateSchema], info: ValidationInfo
) -> list[DeployedHostCreateSchema]:
"""Check that the number of hosts does not exceed subnet CIDR."""
subnet_cidr = info.data.get("cidr")

if not subnet_cidr:
msg = "Subnet missing CIDR."
raise ValueError(msg)

max_num_hosts = max_num_hosts_in_subnet(subnet_cidr)
num_requested_hosts = len(hosts)

if num_requested_hosts > max_num_hosts:
msg = f"Too many hosts in subnet! Max: {max_num_hosts}, Requested: {num_requested_hosts}"
raise ValueError(msg)

return hosts

model_config = ConfigDict(from_attributes=True)


Expand Down
Loading