Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Add slugs to products

Revision ID: 02ae611a9004
Revises: 59a5d45ae3fd
Create Date: 2025-10-17 11:41:38.305228

"""

import sqlalchemy as sa
from alembic import op
from slugify import slugify
from sqlalchemy.dialects import postgresql

# Polar Custom Imports

# revision identifiers, used by Alembic.
revision = "02ae611a9004"
down_revision = "59a5d45ae3fd"
branch_labels: tuple[str] | None = None
depends_on: tuple[str] | None = None


def upgrade() -> None:
# Add the slug column as nullable first
op.add_column("products", sa.Column("slug", postgresql.CITEXT(), nullable=True))

# Generate slugs for existing products
connection = op.get_bind()

# Get all organizations
organizations = connection.execute(
sa.text("SELECT id FROM organizations")
).fetchall()

# Process products per organization to ensure slug uniqueness within each org
for (organization_id,) in organizations:
# Get all products for this organization
products = connection.execute(
sa.text(
"SELECT id, name FROM products WHERE organization_id = :org_id ORDER BY created_at"
),
{"org_id": organization_id},
).fetchall()

# Track used slugs within this organization
used_slugs = set()

for product_id, product_name in products:
# Generate base slug from name
base_slug = slugify(
product_name,
max_length=128,
word_boundary=True,
)

# Handle collisions by appending a number
slug = base_slug
n = 0
while slug in used_slugs:
n += 1
slug = f"{base_slug}-{n}"
if n > 100: # Safety check
raise Exception(
f"Could not generate unique slug for product {product_id} in organization {organization_id}"
)

# Update the product with the slug
connection.execute(
sa.text("UPDATE products SET slug = :slug WHERE id = :id"),
{"slug": slug, "id": product_id},
)

used_slugs.add(slug)

# Make the column non-nullable
op.alter_column("products", "slug", nullable=False)

op.create_index(
op.f("ix_product_slug"),
"products",
["organization_id", "slug"],
unique=True,
)


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(
"ix_product_slug",
table_name="products",
)

op.drop_column("products", "slug")
# ### end Alembic commands ###
11 changes: 11 additions & 0 deletions server/polar/models/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Boolean,
ColumnElement,
ForeignKey,
Index,
String,
Text,
Uuid,
Expand Down Expand Up @@ -47,7 +48,17 @@ class ProductBillingType(StrEnum):
class Product(TrialConfigurationMixin, MetadataMixin, RecordModel):
__tablename__ = "products"

__table_args__ = (
Index(
"ix_product_slug",
"organization_id",
"slug",
unique=True,
),
)

name: Mapped[str] = mapped_column(CITEXT(), nullable=False)
slug: Mapped[str] = mapped_column(CITEXT(), nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
is_tax_applicable: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=True
Expand Down
12 changes: 8 additions & 4 deletions server/polar/product/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ async def list(
organization_id: MultipleQueryFilter[OrganizationID] | None = Query(
None, title="OrganizationID Filter", description="Filter by organization ID."
),
query: str | None = Query(None, description="Filter by product name."),
slug: MultipleQueryFilter[str] | None = Query(
None, title="Slug Filter", description="Filter by product slug."
),
query: str | None = Query(None, description="Filter by product name or slug."),
is_archived: bool | None = Query(None, description="Filter on archived products."),
is_recurring: bool | None = Query(
None,
Expand All @@ -82,6 +85,7 @@ async def list(
auth_subject,
id=id,
organization_id=organization_id,
slug=slug,
query=query,
is_archived=is_archived,
is_recurring=is_recurring,
Expand Down Expand Up @@ -110,7 +114,7 @@ async def get(
session: AsyncReadSession = Depends(get_db_read_session),
) -> Product:
"""Get a product by ID."""
product = await product_service.get(session, auth_subject, id)
product = await product_service.get(session, auth_subject, id=id)

if product is None:
raise ResourceNotFound()
Expand Down Expand Up @@ -154,7 +158,7 @@ async def update(
session: AsyncSession = Depends(get_db_session),
) -> Product:
"""Update a product."""
product = await product_service.get(session, auth_subject, id)
product = await product_service.get(session, auth_subject, id=id)

if product is None:
raise ResourceNotFound()
Expand Down Expand Up @@ -182,7 +186,7 @@ async def update_benefits(
session: AsyncSession = Depends(get_db_session),
) -> Product:
"""Update benefits granted by a product."""
product = await product_service.get(session, auth_subject, id)
product = await product_service.get(session, auth_subject, id=id)

if product is None:
raise ResourceNotFound()
Expand Down
25 changes: 25 additions & 0 deletions server/polar/product/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,31 @@ async def get_by_id_and_organization(
)
return await self.get_one_or_none(statement)

async def get_by_slug(
self,
slug: str,
*,
options: Options = (),
) -> Product | None:
statement = (
self.get_base_statement().where(Product.slug == slug).options(*options)
)
return await self.get_one_or_none(statement)

async def get_by_slug_and_organization(
self,
slug: str,
organization_id: UUID,
*,
options: Options = (),
) -> Product | None:
statement = (
self.get_base_statement()
.where(Product.slug == slug, Product.organization_id == organization_id)
.options(*options)
)
return await self.get_one_or_none(statement)

async def get_by_id_and_checkout(
self,
id: UUID,
Expand Down
13 changes: 13 additions & 0 deletions server/polar/product/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Schema,
SelectorWidget,
SetSchemaReference,
SlugValidator,
TimestampedSchema,
)
from polar.kit.trial import TrialConfigurationInputMixin, TrialConfigurationOutputMixin
Expand All @@ -50,6 +51,7 @@
from polar.organization.schemas import OrganizationID

PRODUCT_NAME_MIN_LENGTH = 3
PRODUCT_SLUG_MIN_LENGTH = 3

# PostgreSQL int4 range limit
INT_MAX_VALUE = 2_147_483_647
Expand Down Expand Up @@ -90,6 +92,14 @@
description="The name of the product.",
),
]
ProductSlug = Annotated[
str,
Field(
min_length=PRODUCT_SLUG_MIN_LENGTH,
description="The slug of the product.",
),
SlugValidator,
]
ProductDescription = Annotated[
str | None,
Field(description="The description of the product."),
Expand Down Expand Up @@ -295,6 +305,7 @@ def get_model_class(self) -> builtins.type[ProductPriceMeteredUnitModel]:

class ProductCreateBase(MetadataInputMixin, Schema):
name: ProductName
slug: ProductSlug = None
description: ProductDescription = None
prices: ProductPriceCreateList = Field(
...,
Expand Down Expand Up @@ -363,6 +374,7 @@ class ProductUpdate(TrialConfigurationInputMixin, MetadataInputMixin, Schema):
"""

name: ProductName | None = None
slug: ProductSlug | None = None
description: ProductDescription = None
recurring_interval: SubscriptionRecurringInterval | None = Field(
default=None,
Expand Down Expand Up @@ -632,6 +644,7 @@ def _get_discriminator_value(v: Any) -> Literal["legacy", "new"]:

class ProductBase(TrialConfigurationOutputMixin, TimestampedSchema, IDSchema):
name: str = Field(description="The name of the product.")
slug: str = Field(description="The slug of the product.")
description: str | None = Field(description="The description of the product.")
recurring_interval: SubscriptionRecurringInterval | None = Field(
description=(
Expand Down
Loading
Loading