diff --git a/server/migrations/versions/2025-10-17-1141_add_slugs_to_products.py b/server/migrations/versions/2025-10-17-1141_add_slugs_to_products.py new file mode 100644 index 0000000000..78f1f9a17f --- /dev/null +++ b/server/migrations/versions/2025-10-17-1141_add_slugs_to_products.py @@ -0,0 +1,94 @@ +"""Add slugs to products + +Revision ID: 02ae611a9004 +Revises: 59a5d45ae3fd +Create Date: 2025-10-17 11:41:38.305228 + +""" + +import sqlalchemy as sa +from alembic import op +from slugify import slugify +from sqlalchemy.dialects import postgresql + +# Polar Custom Imports + +# revision identifiers, used by Alembic. +revision = "02ae611a9004" +down_revision = "59a5d45ae3fd" +branch_labels: tuple[str] | None = None +depends_on: tuple[str] | None = None + + +def upgrade() -> None: + # Add the slug column as nullable first + op.add_column("products", sa.Column("slug", postgresql.CITEXT(), nullable=True)) + + # Generate slugs for existing products + connection = op.get_bind() + + # Get all organizations + organizations = connection.execute( + sa.text("SELECT id FROM organizations") + ).fetchall() + + # Process products per organization to ensure slug uniqueness within each org + for (organization_id,) in organizations: + # Get all products for this organization + products = connection.execute( + sa.text( + "SELECT id, name FROM products WHERE organization_id = :org_id ORDER BY created_at" + ), + {"org_id": organization_id}, + ).fetchall() + + # Track used slugs within this organization + used_slugs = set() + + for product_id, product_name in products: + # Generate base slug from name + base_slug = slugify( + product_name, + max_length=128, + word_boundary=True, + ) + + # Handle collisions by appending a number + slug = base_slug + n = 0 + while slug in used_slugs: + n += 1 + slug = f"{base_slug}-{n}" + if n > 100: # Safety check + raise Exception( + f"Could not generate unique slug for product {product_id} in organization {organization_id}" + ) + + # Update the product with the slug + connection.execute( + sa.text("UPDATE products SET slug = :slug WHERE id = :id"), + {"slug": slug, "id": product_id}, + ) + + used_slugs.add(slug) + + # Make the column non-nullable + op.alter_column("products", "slug", nullable=False) + + op.create_index( + op.f("ix_product_slug"), + "products", + ["organization_id", "slug"], + unique=True, + ) + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index( + "ix_product_slug", + table_name="products", + ) + + op.drop_column("products", "slug") + # ### end Alembic commands ### diff --git a/server/polar/models/product.py b/server/polar/models/product.py index d6f16f3c87..3b9c327056 100644 --- a/server/polar/models/product.py +++ b/server/polar/models/product.py @@ -7,6 +7,7 @@ Boolean, ColumnElement, ForeignKey, + Index, String, Text, Uuid, @@ -47,7 +48,17 @@ class ProductBillingType(StrEnum): class Product(TrialConfigurationMixin, MetadataMixin, RecordModel): __tablename__ = "products" + __table_args__ = ( + Index( + "ix_product_slug", + "organization_id", + "slug", + unique=True, + ), + ) + name: Mapped[str] = mapped_column(CITEXT(), nullable=False) + slug: Mapped[str] = mapped_column(CITEXT(), nullable=False) description: Mapped[str | None] = mapped_column(Text, nullable=True) is_tax_applicable: Mapped[bool] = mapped_column( Boolean, nullable=False, default=True diff --git a/server/polar/product/endpoints.py b/server/polar/product/endpoints.py index 33c4a94bcc..e0fff55a81 100644 --- a/server/polar/product/endpoints.py +++ b/server/polar/product/endpoints.py @@ -59,7 +59,10 @@ async def list( organization_id: MultipleQueryFilter[OrganizationID] | None = Query( None, title="OrganizationID Filter", description="Filter by organization ID." ), - query: str | None = Query(None, description="Filter by product name."), + slug: MultipleQueryFilter[str] | None = Query( + None, title="Slug Filter", description="Filter by product slug." + ), + query: str | None = Query(None, description="Filter by product name or slug."), is_archived: bool | None = Query(None, description="Filter on archived products."), is_recurring: bool | None = Query( None, @@ -82,6 +85,7 @@ async def list( auth_subject, id=id, organization_id=organization_id, + slug=slug, query=query, is_archived=is_archived, is_recurring=is_recurring, @@ -110,7 +114,7 @@ async def get( session: AsyncReadSession = Depends(get_db_read_session), ) -> Product: """Get a product by ID.""" - product = await product_service.get(session, auth_subject, id) + product = await product_service.get(session, auth_subject, id=id) if product is None: raise ResourceNotFound() @@ -154,7 +158,7 @@ async def update( session: AsyncSession = Depends(get_db_session), ) -> Product: """Update a product.""" - product = await product_service.get(session, auth_subject, id) + product = await product_service.get(session, auth_subject, id=id) if product is None: raise ResourceNotFound() @@ -182,7 +186,7 @@ async def update_benefits( session: AsyncSession = Depends(get_db_session), ) -> Product: """Update benefits granted by a product.""" - product = await product_service.get(session, auth_subject, id) + product = await product_service.get(session, auth_subject, id=id) if product is None: raise ResourceNotFound() diff --git a/server/polar/product/repository.py b/server/polar/product/repository.py index 9c68d2b9e1..18193a2a1c 100644 --- a/server/polar/product/repository.py +++ b/server/polar/product/repository.py @@ -40,6 +40,31 @@ async def get_by_id_and_organization( ) return await self.get_one_or_none(statement) + async def get_by_slug( + self, + slug: str, + *, + options: Options = (), + ) -> Product | None: + statement = ( + self.get_base_statement().where(Product.slug == slug).options(*options) + ) + return await self.get_one_or_none(statement) + + async def get_by_slug_and_organization( + self, + slug: str, + organization_id: UUID, + *, + options: Options = (), + ) -> Product | None: + statement = ( + self.get_base_statement() + .where(Product.slug == slug, Product.organization_id == organization_id) + .options(*options) + ) + return await self.get_one_or_none(statement) + async def get_by_id_and_checkout( self, id: UUID, diff --git a/server/polar/product/schemas.py b/server/polar/product/schemas.py index 7cf0d84bee..676d121be4 100644 --- a/server/polar/product/schemas.py +++ b/server/polar/product/schemas.py @@ -25,6 +25,7 @@ Schema, SelectorWidget, SetSchemaReference, + SlugValidator, TimestampedSchema, ) from polar.kit.trial import TrialConfigurationInputMixin, TrialConfigurationOutputMixin @@ -50,6 +51,7 @@ from polar.organization.schemas import OrganizationID PRODUCT_NAME_MIN_LENGTH = 3 +PRODUCT_SLUG_MIN_LENGTH = 3 # PostgreSQL int4 range limit INT_MAX_VALUE = 2_147_483_647 @@ -90,6 +92,14 @@ description="The name of the product.", ), ] +ProductSlug = Annotated[ + str, + Field( + min_length=PRODUCT_SLUG_MIN_LENGTH, + description="The slug of the product.", + ), + SlugValidator, +] ProductDescription = Annotated[ str | None, Field(description="The description of the product."), @@ -295,6 +305,7 @@ def get_model_class(self) -> builtins.type[ProductPriceMeteredUnitModel]: class ProductCreateBase(MetadataInputMixin, Schema): name: ProductName + slug: ProductSlug = None description: ProductDescription = None prices: ProductPriceCreateList = Field( ..., @@ -363,6 +374,7 @@ class ProductUpdate(TrialConfigurationInputMixin, MetadataInputMixin, Schema): """ name: ProductName | None = None + slug: ProductSlug | None = None description: ProductDescription = None recurring_interval: SubscriptionRecurringInterval | None = Field( default=None, @@ -632,6 +644,7 @@ def _get_discriminator_value(v: Any) -> Literal["legacy", "new"]: class ProductBase(TrialConfigurationOutputMixin, TimestampedSchema, IDSchema): name: str = Field(description="The name of the product.") + slug: str = Field(description="The slug of the product.") description: str | None = Field(description="The description of the product.") recurring_interval: SubscriptionRecurringInterval | None = Field( description=( diff --git a/server/polar/product/service.py b/server/polar/product/service.py index b7de856041..326081f5d2 100644 --- a/server/polar/product/service.py +++ b/server/polar/product/service.py @@ -4,7 +4,8 @@ from typing import Any, Literal import stripe -from sqlalchemy import UnaryExpression, asc, case, desc, func, select +from slugify import slugify +from sqlalchemy import UnaryExpression, asc, case, desc, func, or_, select from sqlalchemy.orm import contains_eager, selectinload from polar.auth.models import AuthSubject, is_user @@ -68,6 +69,7 @@ async def list( *, id: Sequence[uuid.UUID] | None = None, organization_id: Sequence[uuid.UUID] | None = None, + slug: Sequence[str] | None = None, query: str | None = None, is_archived: bool | None = None, is_recurring: bool | None = None, @@ -104,8 +106,13 @@ async def list( if organization_id is not None: statement = statement.where(Product.organization_id.in_(organization_id)) + if slug is not None: + statement = statement.where(Product.slug.in_(slug)) + if query is not None: - statement = statement.where(Product.name.ilike(f"%{query}%")) + statement = statement.where( + or_(Product.name.ilike(f"%{query}%"), Product.slug.ilike(f"%{query}%")) + ) if is_archived is not None: statement = statement.where(Product.is_archived.is_(is_archived)) @@ -187,14 +194,25 @@ async def get( self, session: AsyncReadSession, auth_subject: AuthSubject[User | Organization], - id: uuid.UUID, + id: uuid.UUID | None, + slug: str | None, ) -> Product | None: repository = ProductRepository.from_session(session) - statement = ( - repository.get_readable_statement(auth_subject) - .where(Product.id == id) - .options(*repository.get_eager_options()) - ) + if id is not None: + statement = ( + repository.get_readable_statement(auth_subject) + .where(Product.id == id) + .options(*repository.get_eager_options()) + ) + elif slug is not None: + statement = ( + repository.get_readable_statement(auth_subject) + .where(Product.slug == slug) + .options(*repository.get_eager_options()) + ) + else: + raise ValueError("Either id or slug must be provided") + return await repository.get_one_or_none(statement) async def get_embed( @@ -208,6 +226,47 @@ async def get_embed( ) return await repository.get_one_or_none(statement) + async def slugify( + self, + session: AsyncSession, + schema: ProductCreate | ProductUpdate, + organization_id: uuid.UUID, + ) -> str: + repository = ProductRepository.from_session(session) + + slug = ( + slugify( + schema.name, + max_length=128, # arbitrary + word_boundary=True, + ) + if schema.slug is None + else schema.slug + ) + + orig_slug = slug + + for n in range(0, 100): + test_slug = orig_slug if n == 0 else f"{orig_slug}-{n}" + + exists = await repository.get_by_slug_and_organization( + test_slug, organization_id + ) + + # slug is unused, continue with creating a product with this slug + if exists is None: + slug = test_slug + break + + # continue until a free slug has been found + else: + # if no free slug has been found in 100 attempts, error out + raise Exception( + "This slug has been used more than 100 times in this organization." + ) + + return slug + async def create( self, session: AsyncSession, @@ -240,6 +299,8 @@ async def create( ) errors.extend(prices_errors) + slug = await self.slugify(session, create_schema, organization.id) + product = await repository.create( Product( organization=organization, @@ -248,12 +309,14 @@ async def create( product_benefits=[], product_medias=[], attached_custom_fields=[], + slug=slug, **create_schema.model_dump( exclude={ "organization_id", "prices", "medias", "attached_custom_fields", + "slug", }, by_alias=True, ), @@ -511,6 +574,11 @@ async def update( ) price.stripe_price_id = stripe_price.id + if update_schema.slug is not None: + product.slug = await self.slugify( + session, update_schema, product.organization_id + ) + if update_schema.is_archived: product = await self._archive(session, product) diff --git a/server/tests/fixtures/random_objects.py b/server/tests/fixtures/random_objects.py index f76bc0ad16..6fed17d4cd 100644 --- a/server/tests/fixtures/random_objects.py +++ b/server/tests/fixtures/random_objects.py @@ -360,6 +360,7 @@ async def create_product( organization: Organization, recurring_interval: SubscriptionRecurringInterval | None, name: str = "Product", + slug: str | None = None, is_archived: bool = False, prices: Sequence[PriceFixtureType] = [(1000,)], attached_custom_fields: Sequence[tuple[CustomField, bool]] = [], @@ -369,6 +370,7 @@ async def create_product( ) -> Product: product = Product( name=name, + slug=slug or rstr("product-slug"), description="Description", is_tax_applicable=is_tax_applicable, recurring_interval=recurring_interval, diff --git a/server/tests/product/test_endpoints.py b/server/tests/product/test_endpoints.py index c886f6c5bf..e6a4eb9f01 100644 --- a/server/tests/product/test_endpoints.py +++ b/server/tests/product/test_endpoints.py @@ -481,3 +481,140 @@ async def test_valid( json = response.json() assert len(json["benefits"]) == 1 + + +@pytest.mark.asyncio +class TestProductSlug: + """Test slug functionality in product endpoints.""" + + @pytest.mark.auth + async def test_auto_slug_generation( + self, + session: AsyncSession, + client: AsyncClient, + organization: Organization, + user_organization: UserOrganization, + stripe_service_mock: MagicMock, + ) -> None: + """Test that slug is automatically generated from product name.""" + create_product_mock: MagicMock = stripe_service_mock.create_product + create_product_mock.return_value = SimpleNamespace(id="PRODUCT_ID") + create_price_mock: MagicMock = stripe_service_mock.create_price_for_product + create_price_mock.return_value = SimpleNamespace(id="PRICE_ID") + + response = await client.post( + "/v1/products/", + json={ + "name": "Premium Subscription", + "organization_id": str(organization.id), + "prices": [{"amount_type": "fixed", "price_amount": 1000}], + }, + ) + + assert response.status_code == 201 + json = response.json() + assert json["name"] == "Premium Subscription" + assert json["slug"] == "premium-subscription" + assert "slug" in json + + @pytest.mark.auth + async def test_custom_slug( + self, + session: AsyncSession, + client: AsyncClient, + organization: Organization, + user_organization: UserOrganization, + stripe_service_mock: MagicMock, + ) -> None: + """Test providing a custom slug.""" + create_product_mock: MagicMock = stripe_service_mock.create_product + create_product_mock.return_value = SimpleNamespace(id="PRODUCT_ID") + create_price_mock: MagicMock = stripe_service_mock.create_price_for_product + create_price_mock.return_value = SimpleNamespace(id="PRICE_ID") + + response = await client.post( + "/v1/products/", + json={ + "name": "Premium Subscription", + "slug": "my-custom-slug", + "organization_id": str(organization.id), + "prices": [{"amount_type": "fixed", "price_amount": 1000}], + }, + ) + + assert response.status_code == 201 + json = response.json() + assert json["slug"] == "my-custom-slug" + + @pytest.mark.auth + async def test_slug_in_list_response( + self, + session: AsyncSession, + client: AsyncClient, + organization: Organization, + user_organization: UserOrganization, + product: Product, + ) -> None: + """Test that slug is included in list response.""" + response = await client.get( + "/v1/products/", + params={"organization_id": str(organization.id)}, + ) + + assert response.status_code == 200 + json = response.json() + assert len(json["items"]) > 0 + item = json["items"][0] + assert "slug" in item + assert isinstance(item["slug"], str) + + @pytest.mark.auth + async def test_filter_by_slug( + self, + session: AsyncSession, + save_fixture: SaveFixture, + client: AsyncClient, + organization: Organization, + user_organization: UserOrganization, + product: Product, + ) -> None: + """Test filtering products by slug.""" + response = await client.get( + "/v1/products/", + params={ + "organization_id": str(organization.id), + "slug": product.slug, + }, + ) + + assert response.status_code == 200 + json = response.json() + assert json["pagination"]["total_count"] == 1 + items = json["items"] + assert len(items) == 1 + assert items[0]["id"] == str(product.id) + assert items[0]["slug"] == product.slug + + @pytest.mark.auth + async def test_search_by_slug( + self, + session: AsyncSession, + client: AsyncClient, + organization: Organization, + user_organization: UserOrganization, + product: Product, + ) -> None: + """Test searching products by slug using query parameter.""" + response = await client.get( + "/v1/products/", + params={ + "organization_id": str(organization.id), + "query": product.slug, + }, + ) + + assert response.status_code == 200 + json = response.json() + # Product should be found + product_ids = [item["id"] for item in json["items"]] + assert str(product.id) in product_ids diff --git a/server/tests/product/test_service.py b/server/tests/product/test_service.py index 05b1f4cabd..99a6d976a6 100644 --- a/server/tests/product/test_service.py +++ b/server/tests/product/test_service.py @@ -278,7 +278,9 @@ async def test_user( # then session.expunge_all() - retrieved_product = await product_service.get(session, auth_subject, product.id) + retrieved_product = await product_service.get( + session, auth_subject, product.id, slug=None + ) assert retrieved_product is None @pytest.mark.auth @@ -294,12 +296,12 @@ async def test_user_organization( session.expunge_all() not_existing_product = await product_service.get( - session, auth_subject, uuid.uuid4() + session, auth_subject, uuid.uuid4(), slug=None ) assert not_existing_product is None accessible_product = await product_service.get( - session, auth_subject, product.id + session, auth_subject, product.id, slug=None ) assert accessible_product is not None assert accessible_product.id == product.id @@ -315,12 +317,12 @@ async def test_organization( session.expunge_all() not_existing_product = await product_service.get( - session, auth_subject, uuid.uuid4() + session, auth_subject, uuid.uuid4(), slug=None ) assert not_existing_product is None accessible_product = await product_service.get( - session, auth_subject, product.id + session, auth_subject, product.id, slug=None ) assert accessible_product is not None assert accessible_product.id == product.id @@ -1838,3 +1840,515 @@ async def test_has_seat_based_price( assert product_with_seats.has_seat_based_price is True assert product_without_seats.has_seat_based_price is False + + +@pytest.mark.asyncio +class TestSlugGeneration: + """Test slug generation and handling in product service.""" + + @pytest.mark.auth + async def test_slugify_from_name( + self, + session: AsyncSession, + save_fixture: SaveFixture, + auth_subject: AuthSubject[User], + organization: Organization, + user_organization: UserOrganization, + enqueue_job_mock: AsyncMock, + stripe_service_mock: MagicMock, + ) -> None: + """Test automatic slug generation from product name.""" + create_product_mock: MagicMock = stripe_service_mock.create_product + create_product_mock.return_value = SimpleNamespace(id="PRODUCT_ID") + + create_price_for_product_mock: MagicMock = ( + stripe_service_mock.create_price_for_product + ) + create_price_for_product_mock.return_value = SimpleNamespace(id="PRICE_ID") + + product = await product_service.create( + session, + ProductCreateOneTime( + name="Premium Subscription Plan", + organization_id=organization.id, + prices=[ + ProductPriceFixedCreate( + amount_type=ProductPriceAmountType.fixed, + price_amount=1000, + price_currency="usd", + ) + ], + ), + auth_subject, + ) + + assert product.slug == "premium-subscription-plan" + + @pytest.mark.auth + async def test_slugify_with_special_characters( + self, + session: AsyncSession, + save_fixture: SaveFixture, + auth_subject: AuthSubject[User], + organization: Organization, + user_organization: UserOrganization, + enqueue_job_mock: AsyncMock, + stripe_service_mock: MagicMock, + ) -> None: + """Test slug generation handles special characters.""" + create_product_mock: MagicMock = stripe_service_mock.create_product + create_product_mock.return_value = SimpleNamespace(id="PRODUCT_ID") + + create_price_for_product_mock: MagicMock = ( + stripe_service_mock.create_price_for_product + ) + create_price_for_product_mock.return_value = SimpleNamespace(id="PRICE_ID") + + product = await product_service.create( + session, + ProductCreateOneTime( + name="Pro++ Plan (2024) @ $99/mo", + organization_id=organization.id, + prices=[ + ProductPriceFixedCreate( + amount_type=ProductPriceAmountType.fixed, + price_amount=9900, + price_currency="usd", + ) + ], + ), + auth_subject, + ) + + # Should strip special chars and normalize + assert "pro" in product.slug + assert "plan" in product.slug + assert "2024" in product.slug + # Special characters should be removed or replaced + assert "++" not in product.slug + assert "@" not in product.slug + assert "$" not in product.slug + + @pytest.mark.auth + async def test_custom_slug( + self, + session: AsyncSession, + save_fixture: SaveFixture, + auth_subject: AuthSubject[User], + organization: Organization, + user_organization: UserOrganization, + enqueue_job_mock: AsyncMock, + stripe_service_mock: MagicMock, + ) -> None: + """Test providing a custom slug.""" + create_product_mock: MagicMock = stripe_service_mock.create_product + create_product_mock.return_value = SimpleNamespace(id="PRODUCT_ID") + + create_price_for_product_mock: MagicMock = ( + stripe_service_mock.create_price_for_product + ) + create_price_for_product_mock.return_value = SimpleNamespace(id="PRICE_ID") + + product = await product_service.create( + session, + ProductCreateOneTime( + name="Some Product Name", + slug="my-custom-slug", + organization_id=organization.id, + prices=[ + ProductPriceFixedCreate( + amount_type=ProductPriceAmountType.fixed, + price_amount=1000, + price_currency="usd", + ) + ], + ), + auth_subject, + ) + + assert product.slug == "my-custom-slug" + + @pytest.mark.auth + async def test_duplicate_slug_handling( + self, + session: AsyncSession, + save_fixture: SaveFixture, + auth_subject: AuthSubject[User], + organization: Organization, + user_organization: UserOrganization, + enqueue_job_mock: AsyncMock, + stripe_service_mock: MagicMock, + ) -> None: + """Test that duplicate slugs within the same organization are handled with numeric suffixes.""" + create_product_mock: MagicMock = stripe_service_mock.create_product + create_product_mock.return_value = SimpleNamespace(id="PRODUCT_ID") + + create_price_for_product_mock: MagicMock = ( + stripe_service_mock.create_price_for_product + ) + create_price_for_product_mock.return_value = SimpleNamespace(id="PRICE_ID") + + # Create first product + product1 = await product_service.create( + session, + ProductCreateOneTime( + name="Test Product", + organization_id=organization.id, + prices=[ + ProductPriceFixedCreate( + amount_type=ProductPriceAmountType.fixed, + price_amount=1000, + price_currency="usd", + ) + ], + ), + auth_subject, + ) + assert product1.slug == "test-product" + + # Create second product with same name in same organization + product2 = await product_service.create( + session, + ProductCreateOneTime( + name="Test Product", + organization_id=organization.id, + prices=[ + ProductPriceFixedCreate( + amount_type=ProductPriceAmountType.fixed, + price_amount=2000, + price_currency="usd", + ) + ], + ), + auth_subject, + ) + assert product2.slug == "test-product-1" + + # Create third product with same name in same organization + product3 = await product_service.create( + session, + ProductCreateOneTime( + name="Test Product", + organization_id=organization.id, + prices=[ + ProductPriceFixedCreate( + amount_type=ProductPriceAmountType.fixed, + price_amount=3000, + price_currency="usd", + ) + ], + ), + auth_subject, + ) + assert product3.slug == "test-product-2" + + @pytest.mark.auth + async def test_slug_unique_per_organization( + self, + session: AsyncSession, + save_fixture: SaveFixture, + auth_subject: AuthSubject[User], + user: User, + enqueue_job_mock: AsyncMock, + stripe_service_mock: MagicMock, + ) -> None: + """Test that the same slug can be used across different organizations.""" + create_product_mock: MagicMock = stripe_service_mock.create_product + create_product_mock.return_value = SimpleNamespace(id="PRODUCT_ID") + + create_price_for_product_mock: MagicMock = ( + stripe_service_mock.create_price_for_product + ) + create_price_for_product_mock.return_value = SimpleNamespace(id="PRICE_ID") + + # Create first organization and product + from tests.fixtures.random_objects import create_organization + + org1 = await create_organization(save_fixture) + user_org1 = UserOrganization(user=user, organization=org1) + await save_fixture(user_org1) + + product1 = await product_service.create( + session, + ProductCreateOneTime( + name="Premium Plan", + organization_id=org1.id, + prices=[ + ProductPriceFixedCreate( + amount_type=ProductPriceAmountType.fixed, + price_amount=1000, + price_currency="usd", + ) + ], + ), + auth_subject, + ) + assert product1.slug == "premium-plan" + assert product1.organization_id == org1.id + + # Create second organization and product with same name + org2 = await create_organization(save_fixture) + user_org2 = UserOrganization(user=user, organization=org2) + await save_fixture(user_org2) + + product2 = await product_service.create( + session, + ProductCreateOneTime( + name="Premium Plan", + organization_id=org2.id, + prices=[ + ProductPriceFixedCreate( + amount_type=ProductPriceAmountType.fixed, + price_amount=2000, + price_currency="usd", + ) + ], + ), + auth_subject, + ) + # Should have the same slug since it's in a different organization + assert product2.slug == "premium-plan" + assert product2.organization_id == org2.id + assert product1.organization_id != product2.organization_id + + +@pytest.mark.asyncio +class TestListBySlug: + """Test filtering and searching products by slug.""" + + @pytest.mark.auth + async def test_filter_by_slug( + self, + session: AsyncSession, + save_fixture: SaveFixture, + auth_subject: AuthSubject[User], + organization: Organization, + user_organization: UserOrganization, + ) -> None: + """Test filtering products by exact slug match.""" + product = await create_product( + save_fixture, + organization=organization, + recurring_interval=None, + name="Test Product", + slug="test-product-slug", + ) + await session.refresh(product) + + # Filter by slug + results, count = await product_service.list( + session, + auth_subject, + slug=["test-product-slug"], + pagination=PaginationParams(1, 10), + ) + + assert count == 1 + assert len(results) == 1 + assert results[0].id == product.id + assert results[0].slug == "test-product-slug" + + @pytest.mark.auth + async def test_filter_by_multiple_slugs( + self, + session: AsyncSession, + save_fixture: SaveFixture, + auth_subject: AuthSubject[User], + organization: Organization, + user_organization: UserOrganization, + ) -> None: + """Test filtering products by multiple slugs.""" + product1 = await create_product( + save_fixture, + organization=organization, + recurring_interval=None, + name="Product One", + slug="product-one", + ) + product2 = await create_product( + save_fixture, + organization=organization, + recurring_interval=None, + name="Product Two", + slug="product-two", + ) + await create_product( + save_fixture, + organization=organization, + recurring_interval=None, + name="Product Three", + slug="product-three", + ) + await session.refresh(product1) + await session.refresh(product2) + + # Filter by multiple slugs + results, count = await product_service.list( + session, + auth_subject, + slug=["product-one", "product-two"], + pagination=PaginationParams(1, 10), + ) + + assert count == 2 + assert len(results) == 2 + product_ids = {p.id for p in results} + assert product1.id in product_ids + assert product2.id in product_ids + + @pytest.mark.auth + async def test_search_by_slug( + self, + session: AsyncSession, + save_fixture: SaveFixture, + auth_subject: AuthSubject[User], + organization: Organization, + user_organization: UserOrganization, + ) -> None: + """Test searching products by slug using query parameter.""" + product = await create_product( + save_fixture, + organization=organization, + recurring_interval=None, + name="Different Name", + slug="searchable-slug-test", + ) + await session.refresh(product) + + # Search by slug + results, count = await product_service.list( + session, + auth_subject, + query="searchable-slug", + pagination=PaginationParams(1, 10), + ) + + assert count >= 1 + product_ids = [p.id for p in results] + assert product.id in product_ids + + @pytest.mark.auth + async def test_search_matches_name_and_slug( + self, + session: AsyncSession, + save_fixture: SaveFixture, + auth_subject: AuthSubject[User], + organization: Organization, + user_organization: UserOrganization, + ) -> None: + """Test that query searches both name and slug.""" + # Product with keyword in name + product_with_name = await create_product( + save_fixture, + organization=organization, + recurring_interval=None, + name="Enterprise Edition", + slug="enterprise-slug", + ) + # Product with keyword in slug + product_with_slug = await create_product( + save_fixture, + organization=organization, + recurring_interval=None, + name="Pro Edition", + slug="pro-enterprise-plan", + ) + await session.refresh(product_with_name) + await session.refresh(product_with_slug) + + # Search for "enterprise" - should match both + results, count = await product_service.list( + session, + auth_subject, + query="enterprise", + pagination=PaginationParams(1, 10), + ) + + assert count >= 2 + product_ids = [p.id for p in results] + assert product_with_name.id in product_ids + assert product_with_slug.id in product_ids + + +@pytest.mark.asyncio +class TestGetBySlug: + """Test getting products by slug.""" + + @pytest.mark.auth + async def test_get_by_slug( + self, + session: AsyncSession, + save_fixture: SaveFixture, + auth_subject: AuthSubject[User], + organization: Organization, + user_organization: UserOrganization, + ) -> None: + """Test getting a product by slug.""" + product = await create_product( + save_fixture, + organization=organization, + recurring_interval=None, + name="Test Product", + slug="unique-test-slug", + ) + await session.refresh(product) + + # Get by slug + result = await product_service.get( + session, + auth_subject, + id=None, + slug="unique-test-slug", + ) + + assert result is not None + assert result.id == product.id + assert result.slug == "unique-test-slug" + + @pytest.mark.auth + async def test_get_by_id_still_works( + self, + session: AsyncSession, + save_fixture: SaveFixture, + auth_subject: AuthSubject[User], + organization: Organization, + user_organization: UserOrganization, + ) -> None: + """Test that getting by ID still works (backward compatibility).""" + product = await create_product( + save_fixture, + organization=organization, + recurring_interval=None, + name="Test Product", + slug="test-slug", + ) + await session.refresh(product) + + # Get by ID + result = await product_service.get( + session, + auth_subject, + id=product.id, + slug=None, + ) + + assert result is not None + assert result.id == product.id + + @pytest.mark.auth + async def test_get_nonexistent_slug( + self, + session: AsyncSession, + auth_subject: AuthSubject[User], + organization: Organization, + user_organization: UserOrganization, + ) -> None: + """Test getting a product with non-existent slug returns None.""" + result = await product_service.get( + session, + auth_subject, + id=None, + slug="nonexistent-slug", + ) + + assert result is None