diff --git a/CLAUDE.md b/CLAUDE.md index 55c93a068c..c37eddc2a0 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -142,6 +142,7 @@ docker compose up -d # Start PostgreSQL, Redis, Minio - Include HTTP status codes in custom exceptions - Use dependency injection for database sessions - All DB queries should be in the Repository class. Use the right repository class. +- Always place imports at the top of files, not inside functions or methods. In most cases, you should never call `session.commit()` directly in business logic. We have established patterns for that: the API backend automatically commits the session at the end of each request, and background workers commit the session at the end of each task. It avoids to have a database in an inconsistent state in case of exceptions. If you have a `session.commit()` in your code, it's likely a mistake. Otherwise, please explicitly document why it's necessary. diff --git a/server/polar/customer_portal/endpoints/customer_seat.py b/server/polar/customer_portal/endpoints/customer_seat.py index b7207f6877..f24289aba2 100644 --- a/server/polar/customer_portal/endpoints/customer_seat.py +++ b/server/polar/customer_portal/endpoints/customer_seat.py @@ -168,12 +168,12 @@ async def assign_seat( metadata=seat_assign.metadata, ) - # Reload seat with customer relationship + # Reload seat with customer and member relationships seat_repository = CustomerSeatRepository.from_session(session) seat_statement = ( seat_repository.get_base_statement() .where(CustomerSeat.id == seat.id) - .options(joinedload(CustomerSeat.customer)) + .options(joinedload(CustomerSeat.customer), joinedload(CustomerSeat.member)) ) reloaded_seat = await seat_repository.get_one_or_none(seat_statement) diff --git a/server/polar/customer_seat/endpoints.py b/server/polar/customer_seat/endpoints.py index d3f6966352..afbcec82f3 100644 --- a/server/polar/customer_seat/endpoints.py +++ b/server/polar/customer_seat/endpoints.py @@ -341,12 +341,22 @@ async def get_claim_info( if not organization: raise ResourceNotFound("Organization not found") + # Get customer email with priority: seat.email > seat.member.email > seat.customer.email + # This handles both member_model_enabled=True (email on seat) and False (email on customer) + customer_email = "" + if seat.email: + customer_email = seat.email + elif seat.member: + customer_email = seat.member.email + elif seat.customer: + customer_email = seat.customer.email + return SeatClaimInfo( product_name=product.name, product_id=product.id, organization_name=organization.name, organization_slug=organization.slug, - customer_email=seat.customer.email if seat.customer else "", + customer_email=customer_email, can_claim=seat.status == SeatStatus.pending, ) diff --git a/server/polar/customer_seat/repository.py b/server/polar/customer_seat/repository.py index 1972fe6f7d..d48e16dbbd 100644 --- a/server/polar/customer_seat/repository.py +++ b/server/polar/customer_seat/repository.py @@ -59,6 +59,59 @@ async def get_by_container_and_customer( container.id, customer_id, options=options ) + async def get_by_container_and_email( + self, + container: SeatContainer, + email: str, + *, + options: Options = (), + ) -> CustomerSeat | None: + """Get seat by container and email (for member_model_enabled path).""" + if isinstance(container, Subscription): + return await self.get_by_subscription_and_email( + container.id, email, options=options + ) + else: + return await self.get_by_order_and_email( + container.id, email, options=options + ) + + async def get_by_subscription_and_email( + self, + subscription_id: UUID, + email: str, + *, + options: Options = (), + ) -> CustomerSeat | None: + """Get seat by subscription ID and email.""" + statement = ( + select(CustomerSeat) + .where( + CustomerSeat.subscription_id == subscription_id, + func.lower(CustomerSeat.email) == email.lower(), + ) + .options(*options) + ) + return await self.get_one_or_none(statement) + + async def get_by_order_and_email( + self, + order_id: UUID, + email: str, + *, + options: Options = (), + ) -> CustomerSeat | None: + """Get seat by order ID and email.""" + statement = ( + select(CustomerSeat) + .where( + CustomerSeat.order_id == order_id, + func.lower(CustomerSeat.email) == email.lower(), + ) + .options(*options) + ) + return await self.get_one_or_none(statement) + async def get_revoked_seat_by_container( self, container: SeatContainer, @@ -358,10 +411,11 @@ def get_eager_options(self) -> Options: joinedload(Subscription.customer), ), joinedload(CustomerSeat.order).options( - joinedload(Order.product), + joinedload(Order.product).joinedload(Product.organization), joinedload(Order.customer).joinedload(Customer.organization), ), joinedload(CustomerSeat.customer), + joinedload(CustomerSeat.member), ) def get_eager_options_with_prices(self) -> Options: diff --git a/server/polar/customer_seat/schemas.py b/server/polar/customer_seat/schemas.py index 92bb875f70..e739618775 100644 --- a/server/polar/customer_seat/schemas.py +++ b/server/polar/customer_seat/schemas.py @@ -95,7 +95,20 @@ class CustomerSeat(TimestampedSchema): None, description="The order ID (for one-time purchase seats)" ) status: SeatStatus = Field(..., description="Status of the seat") - customer_id: UUID | None = Field(None, description="The assigned customer ID") + customer_id: UUID | None = Field( + None, + description=( + "The customer ID. When member_model_enabled is true, this is the billing " + "customer (purchaser). When false, this is the seat member customer." + ), + ) + member_id: UUID | None = Field( + None, description="The member ID of the seat occupant" + ) + email: str | None = Field( + None, + description="Email of the seat member (set when member_model_enabled is true)", + ) customer_email: str | None = Field(None, description="The assigned customer email") invitation_token_expires_at: datetime | None = Field( None, description="When the invitation token expires" @@ -110,20 +123,30 @@ class CustomerSeat(TimestampedSchema): @classmethod def extract_customer_email(cls, data: Any) -> Any: if isinstance(data, dict): - # For dict data - if "customer" in data and data["customer"]: + # For dict data - priority: email > member.email > customer.email + if "email" in data and data["email"]: + data["customer_email"] = data["email"] + elif "member" in data and data["member"]: + data["customer_email"] = data.get("member", {}).get("email") + elif "customer" in data and data["customer"]: data["customer_email"] = data.get("customer", {}).get("email") return data elif hasattr(data, "__dict__"): - # For SQLAlchemy models - check if customer is loaded - state = inspect(data) - if "customer" not in state.unloaded: - # Customer is loaded, we can extract the email - # But we need to let Pydantic handle the model conversion - # We'll just add the customer_email field if customer is available - if hasattr(data, "customer") and data.customer: - # Add customer_email as a temporary attribute - object.__setattr__(data, "customer_email", data.customer.email) + # For SQLAlchemy models - check if email is set on the seat first + # Priority: seat.email > seat.member.email > seat.customer.email + if hasattr(data, "email") and data.email: + object.__setattr__(data, "customer_email", data.email) + else: + state = inspect(data) + # Try member first + if "member" not in state.unloaded: + if hasattr(data, "member") and data.member: + object.__setattr__(data, "customer_email", data.member.email) + return data + # Fall back to customer + if "customer" not in state.unloaded: + if hasattr(data, "customer") and data.customer: + object.__setattr__(data, "customer_email", data.customer.email) return data diff --git a/server/polar/customer_seat/service.py b/server/polar/customer_seat/service.py index 2c90f8b825..f91c2d8d03 100644 --- a/server/polar/customer_seat/service.py +++ b/server/polar/customer_seat/service.py @@ -1,6 +1,7 @@ import secrets import uuid from collections.abc import Sequence +from dataclasses import dataclass from datetime import UTC, datetime, timedelta from typing import Any @@ -15,10 +16,12 @@ from polar.eventstream.service import publish as eventstream_publish from polar.exceptions import PolarError from polar.kit.db.postgres import AsyncSession +from polar.member.repository import MemberRepository from polar.member.service import member_service from polar.models import ( Customer, CustomerSeat, + Member, Order, Organization, Product, @@ -26,6 +29,7 @@ User, ) from polar.models.customer_seat import SeatStatus +from polar.models.member import MemberRole from polar.models.order import OrderStatus from polar.models.webhook_endpoint import WebhookEventType from polar.organization.repository import OrganizationRepository @@ -75,8 +79,9 @@ def __init__(self) -> None: class InvalidSeatAssignmentRequest(SeatError): - def __init__(self) -> None: - message = "Exactly one of email, external_customer_id, or customer_id must be provided" + def __init__(self, message: str | None = None) -> None: + if message is None: + message = "Exactly one of email, external_customer_id, or customer_id must be provided" super().__init__(message, 400) @@ -90,6 +95,33 @@ def __init__(self, customer_identifier: str) -> None: SeatContainer = Subscription | Order +@dataclass +class SeatAssignmentTarget: + """Resolved target for a seat assignment. + + This dataclass unifies the result of resolving who a seat is being assigned to, + regardless of whether member_model_enabled is True or False. + """ + + customer_id: uuid.UUID + """The customer_id to store on the seat. + - member_model_enabled=True: billing customer (purchaser) + - member_model_enabled=False: seat member's customer + """ + + member_id: uuid.UUID | None + """The member_id to store on the seat (if member was created).""" + + email: str | None + """The email to store on the seat. + - member_model_enabled=True: seat member's email + - member_model_enabled=False: None (email comes from customer) + """ + + seat_member_email: str + """The email of the person getting the seat (for invitation emails).""" + + class SeatService: def _get_customer_id(self, container: SeatContainer) -> uuid.UUID: return container.customer_id @@ -219,62 +251,66 @@ async def assign_seat( metadata: dict[str, Any] | None = None, immediate_claim: bool = False, ) -> CustomerSeat: + # 1. Common setup and validation product = self._get_product(container) source_id = self._get_container_id(container) if product is None: - raise SeatNotAvailable( - source_id, - "Container has no associated product", - ) + raise SeatNotAvailable(source_id, "Container has no associated product") organization_id = self._get_organization_id(container) billing_manager_customer = container.customer + billing_customer_id = container.customer_id is_subscription = self._is_subscription(container) await self.check_seat_feature_enabled(session, organization_id) - # Validate order payment status - if isinstance(container, Order): - if container.status == OrderStatus.pending: - raise SeatNotAvailable( - source_id, "Order must be paid before assigning seats" - ) + if isinstance(container, Order) and container.status == OrderStatus.pending: + raise SeatNotAvailable( + source_id, "Order must be paid before assigning seats" + ) repository = CustomerSeatRepository.from_session(session) - available_seats = await repository.get_available_seats_count_for_container( container ) - if available_seats <= 0: raise SeatNotAvailable(source_id) - customer = await self._find__or_create_customer( - session, - organization_id, - email, - external_customer_id, - customer_id, - ) - + # 2. Get organization and check feature flag organization_repository = OrganizationRepository.from_session(session) organization = await organization_repository.get_by_id(organization_id) - member = None - if organization: - member = await member_service.get_or_create_seat_member( - session, customer, organization - ) - - existing_seat = await repository.get_by_container_and_customer( - container, customer.id + member_model_enabled = ( + organization.feature_settings.get("member_model_enabled", False) + if organization + else False ) - if existing_seat and not existing_seat.is_revoked(): - identifier = email or external_customer_id or str(customer_id) - raise SeatAlreadyAssigned(identifier) + # 3. Resolve seat assignment target (the ONLY branching point) + if member_model_enabled: + target = await self._resolve_member_model_target( + session, + repository, + container, + billing_customer_id, + organization_id, + email, + customer_id, + external_customer_id, + ) + else: + target = await self._resolve_legacy_target( + session, + repository, + container, + organization, + organization_id, + email, + customer_id, + external_customer_id, + ) - # Only generate invitation token for standard (non-immediate) claims + # 4. Generate invitation token (unified) if immediate_claim: invitation_token = None token_expires_at = None @@ -282,17 +318,17 @@ async def assign_seat( invitation_token = secrets.token_urlsafe(32) token_expires_at = datetime.now(UTC) + timedelta(days=1) + # 5. Create or reuse seat (unified) revoked_seat = await repository.get_revoked_seat_by_container(container) - member_id = member.id if member else None - if revoked_seat: seat = revoked_seat seat.status = SeatStatus.claimed if immediate_claim else SeatStatus.pending seat.invitation_token = invitation_token seat.invitation_token_expires_at = token_expires_at - seat.customer_id = customer.id - seat.member_id = member_id + seat.customer_id = target.customer_id + seat.member_id = target.member_id + seat.email = target.email seat.seat_metadata = metadata or {} seat.revoked_at = None seat.claimed_at = datetime.now(UTC) if immediate_claim else None @@ -301,8 +337,9 @@ async def assign_seat( "status": SeatStatus.claimed if immediate_claim else SeatStatus.pending, "invitation_token": invitation_token, "invitation_token_expires_at": token_expires_at, - "customer_id": customer.id, - "member_id": member_id, + "customer_id": target.customer_id, + "member_id": target.member_id, + "email": target.email, "seat_metadata": metadata or {}, "claimed_at": datetime.now(UTC) if immediate_claim else None, } @@ -316,39 +353,37 @@ async def assign_seat( await session.flush() + # 6. Post-creation actions (unified) if immediate_claim: - # Immediate claim flow: grant benefits and trigger claimed webhook log.info( "Seat immediately claimed", subscription_id=seat.subscription_id, order_id=seat.order_id, - email=email, - customer_id=customer.id, + email=target.seat_member_email, + customer_id=seat.customer_id, + member_model_enabled=member_model_enabled, ) - await self._publish_seat_claimed_event(seat, product.id) await self._enqueue_benefit_grant(seat, product.id) await self._send_seat_claimed_webhook(session, organization_id, seat) else: - # Standard flow: send invitation email and trigger assigned webhook log.info( "Seat assigned", subscription_id=seat.subscription_id, order_id=seat.order_id, - email=email, - customer_id=customer.id, + email=target.seat_member_email, + customer_id=seat.customer_id, invitation_token=invitation_token or "none", + member_model_enabled=member_model_enabled, ) - if organization: send_seat_invitation_email( - customer_email=customer.email, + customer_email=target.seat_member_email, seat=seat, organization=organization, product_name=product.name, billing_manager_email=billing_manager_customer.email, ) - await webhook_service.send( session, organization, @@ -406,24 +441,50 @@ async def claim_seat( if seat.is_claimed(): raise InvalidInvitationToken(invitation_token) - # Get product and organization_id from either subscription or order + # Get product and organization from either subscription or order if seat.subscription_id and seat.subscription: product = seat.subscription.product - organization_id = product.organization_id - product_id = product.id + organization = product.organization elif seat.order_id and seat.order: assert seat.order.product is not None product = seat.order.product - organization_id = product.organization_id - product_id = product.id + organization = seat.order.organization else: raise InvalidInvitationToken(invitation_token) + organization_id = product.organization_id + product_id = product.id + await self.check_seat_feature_enabled(session, organization_id) - if not seat.customer_id or not seat.customer: + # Validate seat has required data + if not seat.customer_id: raise InvalidInvitationToken(invitation_token) + # Get customer for session creation + # Both paths use seat.customer_id, but it points to different customers: + # - member_model: billing customer (purchaser) + # - legacy: seat member's customer + member_model_enabled = organization.feature_settings.get( + "member_model_enabled", False + ) + + if member_model_enabled: + # Validate member exists for member model + if not seat.member_id: + raise InvalidInvitationToken(invitation_token) + # Load billing customer for session + customer_repository = CustomerRepository.from_session(session) + session_customer = await customer_repository.get_by_id(seat.customer_id) + if not session_customer: + raise InvalidInvitationToken(invitation_token) + else: + # Use seat's customer relationship for legacy model + if not seat.customer: + raise InvalidInvitationToken(invitation_token) + session_customer = seat.customer + + # Claim the seat (unified) seat.status = SeatStatus.claimed seat.claimed_at = datetime.now(UTC) seat.invitation_token = None # Single-use token @@ -432,15 +493,18 @@ async def claim_seat( await self._publish_seat_claimed_event(seat, product_id) await self._enqueue_benefit_grant(seat, product_id) + session_token, _ = await customer_session_service.create_customer_session( - session, seat.customer + session, session_customer ) log.info( "Seat claimed", seat_id=seat.id, customer_id=seat.customer_id, + member_id=seat.member_id, subscription_id=seat.subscription_id, + member_model_enabled=member_model_enabled, **(request_metadata or {}), ) @@ -453,18 +517,25 @@ async def revoke_seat( session: AsyncSession, seat: CustomerSeat, ) -> CustomerSeat: - # Get product from either subscription or order + # Get product and organization from either subscription or order if seat.subscription_id and seat.subscription: organization_id = seat.subscription.product.organization_id product_id = seat.subscription.product_id + organization = seat.subscription.product.organization elif seat.order_id and seat.order and seat.order.product_id: organization_id = seat.order.organization.id product_id = seat.order.product_id + organization = seat.order.organization else: raise ValueError("Seat must have either subscription or order") await self.check_seat_feature_enabled(session, organization_id) + # Check feature flag + member_model_enabled = organization.feature_settings.get( + "member_model_enabled", False + ) + # Capture customer_id and member_id before clearing to avoid race condition original_customer_id = seat.customer_id original_member_id = seat.member_id @@ -492,8 +563,10 @@ async def revoke_seat( seat.status = SeatStatus.revoked seat.revoked_at = datetime.now(UTC) - seat.customer_id = None seat.invitation_token = None + seat.customer_id = None + seat.member_id = None + seat.email = None await session.flush() @@ -502,17 +575,15 @@ async def revoke_seat( seat_id=seat.id, subscription_id=seat.subscription_id, order_id=seat.order_id, + member_model_enabled=member_model_enabled, ) - organization_repository = OrganizationRepository.from_session(session) - organization = await organization_repository.get_by_id(organization_id) - if organization: - await webhook_service.send( - session, - organization, - WebhookEventType.customer_seat_revoked, - seat, - ) + await webhook_service.send( + session, + organization, + WebhookEventType.customer_seat_revoked, + seat, + ) return seat @@ -554,13 +625,15 @@ async def resend_invitation( session: AsyncSession, seat: CustomerSeat, ) -> CustomerSeat: - # Get product info from either subscription or order + # Get product info and organization from either subscription or order if seat.subscription_id and seat.subscription and seat.subscription.product: organization_id = seat.subscription.product.organization_id + organization = seat.subscription.product.organization product_name = seat.subscription.product.name billing_manager_email = seat.subscription.customer.email elif seat.order_id and seat.order and seat.order.product: organization_id = seat.order.product.organization_id + organization = seat.order.organization product_name = seat.order.product.name billing_manager_email = seat.order.customer.email else: @@ -571,27 +644,42 @@ async def resend_invitation( if not seat.is_pending(): raise SeatNotPending() - if not seat.customer or not seat.invitation_token: + if not seat.invitation_token: raise InvalidInvitationToken(seat.invitation_token or "") + # Check feature flag + member_model_enabled = organization.feature_settings.get( + "member_model_enabled", False + ) + + # Determine the seat member email based on feature flag + if member_model_enabled: + # NEW PATH: Use seat.email + if not seat.email: + raise InvalidInvitationToken(seat.invitation_token or "") + seat_member_email = seat.email + else: + # OLD PATH: Use seat.customer.email + if not seat.customer: + raise InvalidInvitationToken(seat.invitation_token or "") + seat_member_email = seat.customer.email + log.info( "Resending seat invitation", seat_id=seat.id, customer_id=seat.customer_id, subscription_id=seat.subscription_id, order_id=seat.order_id, + member_model_enabled=member_model_enabled, ) - organization_repository = OrganizationRepository.from_session(session) - organization = await organization_repository.get_by_id(organization_id) - if organization: - send_seat_invitation_email( - customer_email=seat.customer.email, - seat=seat, - organization=organization, - product_name=product_name, - billing_manager_email=billing_manager_email, - ) + send_seat_invitation_email( + customer_email=seat_member_email, + seat=seat, + organization=organization, + product_name=product_name, + billing_manager_email=billing_manager_email, + ) return seat @@ -649,7 +737,7 @@ async def revoke_all_seats_for_subscription( return revoked_count - async def _find__or_create_customer( + async def _find_or_create_customer( self, session: AsyncSession, organization_id: uuid.UUID, @@ -694,5 +782,146 @@ async def _find__or_create_customer( await session.flush() return customer + async def _get_or_create_member_for_seat( + self, + session: AsyncSession, + billing_customer_id: uuid.UUID, + organization_id: uuid.UUID, + email: str, + ) -> Member: + """ + Get or create a Member for a seat assignment under the billing customer. + + This is used when member_model_enabled = True. Instead of creating a + separate Customer for each seat member, we create Members under the + billing customer (the purchaser). + + Args: + session: Database session + billing_customer_id: The customer who purchased (billing manager) + organization_id: Organization ID + email: Email of the seat member + + Returns: + Member entity for the seat member + """ + member_repository = MemberRepository.from_session(session) + + # Check if member already exists under this customer with this email + existing_member = await member_repository.get_by_customer_id_and_email( + billing_customer_id, email + ) + if existing_member: + return existing_member + + # Create new member under billing customer + member = Member( + customer_id=billing_customer_id, + organization_id=organization_id, + email=email, + role=MemberRole.member, + ) + session.add(member) + await session.flush() + + log.info( + "Created member for seat assignment", + member_id=member.id, + customer_id=billing_customer_id, + organization_id=organization_id, + email=email, + ) + + return member + + async def _resolve_member_model_target( + self, + session: AsyncSession, + repository: CustomerSeatRepository, + container: SeatContainer, + billing_customer_id: uuid.UUID, + organization_id: uuid.UUID, + email: str | None, + customer_id: uuid.UUID | None, + external_customer_id: str | None, + ) -> SeatAssignmentTarget: + """Resolve seat assignment target when member_model_enabled=True. + + In the member model: + - Only email is accepted (customer_id/external_customer_id rejected) + - No Customer is created for the seat member + - A Member is created under the billing customer + - seat.customer_id = billing customer (purchaser) + - seat.email = seat member's email + """ + if not email or customer_id or external_customer_id: + raise InvalidSeatAssignmentRequest( + "Only email is supported when member_model_enabled is true. " + "customer_id and external_customer_id are not allowed." + ) + + # Check if seat already assigned to this email + existing_seat = await repository.get_by_container_and_email(container, email) + if existing_seat and not existing_seat.is_revoked(): + raise SeatAlreadyAssigned(email) + + # Create Member under billing customer + member = await self._get_or_create_member_for_seat( + session, billing_customer_id, organization_id, email + ) + + return SeatAssignmentTarget( + customer_id=billing_customer_id, + member_id=member.id, + email=email, + seat_member_email=email, + ) + + async def _resolve_legacy_target( + self, + session: AsyncSession, + repository: CustomerSeatRepository, + container: SeatContainer, + organization: Organization | None, + organization_id: uuid.UUID, + email: str | None, + customer_id: uuid.UUID | None, + external_customer_id: str | None, + ) -> SeatAssignmentTarget: + """Resolve seat assignment target when member_model_enabled=False. + + In the legacy model: + - email, customer_id, or external_customer_id accepted (exactly one) + - A Customer is created/found for the seat member + - A Member may be created under that customer + - seat.customer_id = seat member's customer + - seat.email = None (email comes from customer relationship) + """ + customer = await self._find_or_create_customer( + session, organization_id, email, external_customer_id, customer_id + ) + + # Check if seat already assigned to this customer + existing_seat = await repository.get_by_container_and_customer( + container, customer.id + ) + if existing_seat and not existing_seat.is_revoked(): + identifier = email or external_customer_id or str(customer_id) + raise SeatAlreadyAssigned(identifier) + + # Optionally create member under this customer + member = None + if organization: + member = await member_service.get_or_create_seat_member( + session, customer, organization + ) + + return SeatAssignmentTarget( + customer_id=customer.id, + member_id=member.id if member else None, + email=None, + seat_member_email=customer.email, + ) + seat_service = SeatService() diff --git a/server/polar/member/repository.py b/server/polar/member/repository.py index fa1ab59360..4fd2dcb126 100644 --- a/server/polar/member/repository.py +++ b/server/polar/member/repository.py @@ -43,6 +43,24 @@ async def get_by_customer_and_email( result = await session.execute(statement) return result.scalar_one_or_none() + async def get_by_customer_id_and_email( + self, + customer_id: UUID, + email: str, + ) -> Member | None: + """ + Get a member by customer ID and email. + + Returns: + Member if found, None otherwise + """ + statement = select(Member).where( + Member.customer_id == customer_id, + Member.email == email, + Member.deleted_at.is_(None), + ) + return await self.get_one_or_none(statement) + async def list_by_customer( self, session: AsyncReadSession, diff --git a/server/tests/customer_seat/test_service.py b/server/tests/customer_seat/test_service.py index 0c685c1811..fa70023804 100644 --- a/server/tests/customer_seat/test_service.py +++ b/server/tests/customer_seat/test_service.py @@ -641,7 +641,14 @@ async def test_assign_seat_with_member_model_enabled( session: AsyncSession, save_fixture: SaveFixture, ) -> None: - """Test that assign_seat creates a member when member_model_enabled is true.""" + """Test that assign_seat creates a member when member_model_enabled is true. + + When member_model_enabled=True: + - seat.customer_id = billing customer (subscription owner) + - seat.member_id = member created under billing customer + - seat.email = email of the seat member + - No separate Customer is created for the seat member + """ organization = await create_organization( save_fixture, feature_settings={ @@ -667,25 +674,22 @@ async def test_assign_seat_with_member_model_enabled( subscription = await create_subscription_with_seats( save_fixture, product=product, customer=billing_customer, seats=5 ) - # Seat customer (to be assigned a seat) - seat_customer = await create_customer( - save_fixture, - organization=organization, - email="seat@example.com", - ) seat = await seat_service.assign_seat( session, subscription, email="seat@example.com" ) - assert seat.customer_id == seat_customer.id + # customer_id should be the billing customer (purchaser), not a new customer + assert seat.customer_id == billing_customer.id assert seat.member_id is not None + assert seat.email == "seat@example.com" # Verify member was created with correct properties await session.refresh(seat, ["member"]) assert seat.member is not None - assert seat.member.customer_id == seat_customer.id - assert seat.member.email == seat_customer.email + # Member is created under the billing customer + assert seat.member.customer_id == billing_customer.id + assert seat.member.email == "seat@example.com" assert seat.member.organization_id == organization.id @pytest.mark.asyncio @@ -709,6 +713,84 @@ async def test_assign_seat_without_member_model_enabled( assert seat.customer_id == customer.id assert seat.member_id is None + @pytest.mark.asyncio + async def test_assign_seat_rejects_customer_id_when_member_model_enabled( + self, + session: AsyncSession, + save_fixture: SaveFixture, + ) -> None: + """Test that assign_seat rejects customer_id when member_model_enabled is true.""" + organization = await create_organization( + save_fixture, + feature_settings={ + "seat_based_pricing_enabled": True, + "member_model_enabled": True, + }, + ) + product = await create_product( + save_fixture, + organization=organization, + recurring_interval=SubscriptionRecurringInterval.month, + prices=[], + ) + await create_product_price_seat_unit( + save_fixture, product=product, price_per_seat=1000 + ) + billing_customer = await create_customer( + save_fixture, + organization=organization, + email="billing@example.com", + ) + subscription = await create_subscription_with_seats( + save_fixture, product=product, customer=billing_customer, seats=5 + ) + seat_customer = await create_customer( + save_fixture, + organization=organization, + email="seat@example.com", + ) + + with pytest.raises(InvalidSeatAssignmentRequest): + await seat_service.assign_seat( + session, subscription, customer_id=seat_customer.id + ) + + @pytest.mark.asyncio + async def test_assign_seat_requires_email_when_member_model_enabled( + self, + session: AsyncSession, + save_fixture: SaveFixture, + ) -> None: + """Test that assign_seat requires email when member_model_enabled is true.""" + organization = await create_organization( + save_fixture, + feature_settings={ + "seat_based_pricing_enabled": True, + "member_model_enabled": True, + }, + ) + product = await create_product( + save_fixture, + organization=organization, + recurring_interval=SubscriptionRecurringInterval.month, + prices=[], + ) + await create_product_price_seat_unit( + save_fixture, product=product, price_per_seat=1000 + ) + billing_customer = await create_customer( + save_fixture, + organization=organization, + email="billing@example.com", + ) + subscription = await create_subscription_with_seats( + save_fixture, product=product, customer=billing_customer, seats=5 + ) + + # No email, customer_id, or external_customer_id - should fail + with pytest.raises(InvalidSeatAssignmentRequest): + await seat_service.assign_seat(session, subscription) + class TestGetSeatByToken: @pytest.mark.asyncio @@ -925,6 +1007,59 @@ async def test_claim_seat_sends_webhook( assert args[0][2] == WebhookEventType.customer_seat_claimed assert args[0][3].id == seat.id + @pytest.mark.asyncio + async def test_claim_seat_with_member_model_enabled( + self, + session: AsyncSession, + save_fixture: SaveFixture, + ) -> None: + """Test that claim_seat creates session for billing customer when member_model_enabled.""" + organization = await create_organization( + save_fixture, + feature_settings={ + "seat_based_pricing_enabled": True, + "member_model_enabled": True, + }, + ) + product = await create_product( + save_fixture, + organization=organization, + recurring_interval=SubscriptionRecurringInterval.month, + prices=[], + ) + await create_product_price_seat_unit( + save_fixture, product=product, price_per_seat=1000 + ) + billing_customer = await create_customer( + save_fixture, + organization=organization, + email="billing@example.com", + ) + subscription = await create_subscription_with_seats( + save_fixture, product=product, customer=billing_customer, seats=5 + ) + + # Assign seat using the new model + seat = await seat_service.assign_seat( + session, subscription, email="seat@example.com" + ) + + assert seat.invitation_token is not None + assert seat.customer_id == billing_customer.id + assert seat.member_id is not None + assert seat.email == "seat@example.com" + + # Claim the seat + claimed_seat, session_token = await seat_service.claim_seat( + session, seat.invitation_token + ) + + assert claimed_seat.status == SeatStatus.claimed + assert claimed_seat.claimed_at is not None + assert claimed_seat.invitation_token is None # Token should be cleared + assert session_token is not None + assert len(session_token) > 0 + class TestRevokeSeat: @pytest.mark.asyncio @@ -980,6 +1115,63 @@ async def test_revoke_seat_sends_webhook( assert args[0][2] == WebhookEventType.customer_seat_revoked assert args[0][3].id == seat.id + @pytest.mark.asyncio + async def test_revoke_seat_with_member_model_enabled( + self, + session: AsyncSession, + save_fixture: SaveFixture, + ) -> None: + """Test that revoke_seat clears member_id, email and customer_id when member_model_enabled.""" + organization = await create_organization( + save_fixture, + feature_settings={ + "seat_based_pricing_enabled": True, + "member_model_enabled": True, + }, + ) + product = await create_product( + save_fixture, + organization=organization, + recurring_interval=SubscriptionRecurringInterval.month, + prices=[], + ) + await create_product_price_seat_unit( + save_fixture, product=product, price_per_seat=1000 + ) + billing_customer = await create_customer( + save_fixture, + organization=organization, + email="billing@example.com", + ) + subscription = await create_subscription_with_seats( + save_fixture, product=product, customer=billing_customer, seats=5 + ) + + # Assign and claim a seat + seat = await seat_service.assign_seat( + session, subscription, email="seat@example.com", immediate_claim=True + ) + + assert seat.customer_id == billing_customer.id + assert seat.member_id is not None + assert seat.email == "seat@example.com" + assert seat.status == SeatStatus.claimed + + # Revoke the seat + await session.refresh(seat, ["subscription"]) + assert seat.subscription is not None + await session.refresh(seat.subscription, ["product"]) + await session.refresh(seat.subscription.product, ["organization"]) + + revoked_seat = await seat_service.revoke_seat(session, seat) + + assert revoked_seat.status == SeatStatus.revoked + assert revoked_seat.revoked_at is not None + # All identifiers should be cleared + assert revoked_seat.customer_id is None + assert revoked_seat.member_id is None + assert revoked_seat.email is None + class TestGetSeat: @pytest.mark.asyncio @@ -1244,6 +1436,64 @@ async def test_resend_invitation_revoked_seat( with pytest.raises(SeatNotPending): await seat_service.resend_invitation(session, seat) + @pytest.mark.asyncio + async def test_resend_invitation_with_member_model_enabled( + self, + session: AsyncSession, + save_fixture: SaveFixture, + ) -> None: + """Test that resend_invitation uses seat.email when member_model_enabled.""" + organization = await create_organization( + save_fixture, + feature_settings={ + "seat_based_pricing_enabled": True, + "member_model_enabled": True, + }, + ) + product = await create_product( + save_fixture, + organization=organization, + recurring_interval=SubscriptionRecurringInterval.month, + prices=[], + ) + await create_product_price_seat_unit( + save_fixture, product=product, price_per_seat=1000 + ) + billing_customer = await create_customer( + save_fixture, + organization=organization, + email="billing@example.com", + ) + subscription = await create_subscription_with_seats( + save_fixture, product=product, customer=billing_customer, seats=5 + ) + + # Assign a seat (creates pending seat with email on seat) + seat = await seat_service.assign_seat( + session, subscription, email="seat@example.com" + ) + + assert seat.status == SeatStatus.pending + assert seat.email == "seat@example.com" + assert seat.invitation_token is not None + + # Reload seat with relationships + await session.refresh(seat, ["subscription"]) + assert seat.subscription is not None + await session.refresh(seat.subscription, ["product", "customer"]) + await session.refresh(seat.subscription.product, ["organization"]) + + with patch( + "polar.customer_seat.service.send_seat_invitation_email" + ) as mock_send_email: + result_seat = await seat_service.resend_invitation(session, seat) + + assert result_seat.status == SeatStatus.pending + mock_send_email.assert_called_once() + call_kwargs = mock_send_email.call_args[1] + # Should use seat.email, not seat.customer.email + assert call_kwargs["customer_email"] == "seat@example.com" + class TestBenefitGranting: """Tests for benefit granting when claiming and revoking seats."""