Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Add discount_applied_at to subscriptions

Revision ID: ea11a3dc85a2
Revises: 3d212567b9a6
Create Date: 2026-01-06 12:00:00.000000

"""

import sqlalchemy as sa
from alembic import op

# Polar Custom Imports

# revision identifiers, used by Alembic.
revision = "ea11a3dc85a2"
down_revision = "3c48bf325744"
branch_labels: tuple[str] | None = None
depends_on: tuple[str] | None = None


def upgrade() -> None:
op.add_column(
"subscriptions",
sa.Column("discount_applied_at", sa.TIMESTAMP(timezone=True), nullable=True),
)

# Backfill discount_applied_at for existing subscriptions with discounts
# by finding the first order that used the discount
op.execute(
"""
UPDATE subscriptions s
SET discount_applied_at = o.created_at
FROM (
SELECT DISTINCT ON (o.subscription_id, o.discount_id)
o.subscription_id,
o.discount_id,
o.created_at
FROM orders o
WHERE o.subscription_id IS NOT NULL
AND o.discount_id IS NOT NULL
AND o.deleted_at IS NULL
ORDER BY o.subscription_id, o.discount_id, o.created_at ASC
) o
WHERE s.id = o.subscription_id
AND s.discount_id = o.discount_id
"""
)


def downgrade() -> None:
op.drop_column("subscriptions", "discount_applied_at")
23 changes: 17 additions & 6 deletions server/polar/models/discount.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,21 +137,32 @@ def is_applicable(self, product: "Product") -> bool:

def is_repetition_expired(
self,
started_at: datetime,
discount_applied_at: datetime,
current_period_start: datetime,
trial_ended: bool = False,
) -> bool:
"""
Check if a discount's repetition has expired for the current billing cycle.

Args:
discount_applied_at: The timestamp when the discount was first applied
to a billing cycle. This should be the cycle's start date.
current_period_start: The start date of the current billing period.

Returns:
True if the discount should no longer apply to this cycle.
"""
if self.duration == DiscountDuration.once:
# If transitioning from trial to active, this is the first billed cycle
# so the discount should still apply
return not trial_ended
# "once" discounts only apply to the first billing cycle where applied
# They're expired if current period is after the period when first applied
return current_period_start > discount_applied_at
if self.duration == DiscountDuration.forever:
return False
if self.duration_in_months is None:
return False

# For repeating discounts, calculate expiration from when discount was first applied
# -1 because the first month counts as a first repetition
end_at = started_at + relativedelta(months=self.duration_in_months - 1)
end_at = discount_applied_at + relativedelta(months=self.duration_in_months - 1)
return current_period_start > end_at

__mapper_args__ = {
Expand Down
11 changes: 11 additions & 0 deletions server/polar/models/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,13 @@ def product(cls) -> Mapped["Product"]:
Uuid, ForeignKey("discounts.id", ondelete="set null"), nullable=True
)

discount_applied_at: Mapped[datetime | None] = mapped_column(
TIMESTAMP(timezone=True), nullable=True, default=None
)
"""
Timestamp when the discount was first applied to a billing cycle.
"""

@declared_attr
def discount(cls) -> Mapped["Discount | None"]:
return relationship("Discount", lazy="joined")
Expand Down Expand Up @@ -448,3 +455,7 @@ def _discount_set(
initiator: Event,
) -> None:
target.update_amount_and_currency(target.subscription_product_prices, value)
# Reset discount_applied_at when discount changes so the new discount's
# expiration will be tracked from its first use in a billing cycle
if value != oldvalue:
target.discount_applied_at = None
21 changes: 15 additions & 6 deletions server/polar/subscription/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,10 @@ async def create_or_update_from_checkout(
subscription.product = product
subscription.subscription_product_prices = subscription_product_prices
subscription.discount = checkout.discount
# For non-trial checkouts with a discount, the discount is applied immediately
# (the first payment at checkout includes the discount)
if checkout.discount is not None and trial_end is None:
subscription.discount_applied_at = current_period_start
subscription.checkout = checkout
subscription.user_metadata = checkout.user_metadata
subscription.custom_field_data = checkout.custom_field_data
Expand Down Expand Up @@ -659,11 +663,13 @@ async def cycle(

# Check if discount is still applicable
if subscription.discount is not None:
assert subscription.started_at is not None
# Set discount_applied_at on first use (when discount is actually applied to a cycle)
if subscription.discount_applied_at is None:
subscription.discount_applied_at = subscription.current_period_start

if subscription.discount.is_repetition_expired(
subscription.started_at,
subscription.discount_applied_at,
subscription.current_period_start,
previous_status == SubscriptionStatus.trialing,
):
subscription.discount = None

Expand Down Expand Up @@ -1773,12 +1779,15 @@ async def calculate_charge_preview(

# Ensure the discount has not expired yet for the next charge (so at current_period_end)
if subscription.discount is not None:
assert subscription.started_at is not None
assert subscription.current_period_end is not None
# If discount hasn't been applied yet, it will be applied at the next cycle
# (current_period_end will become the new current_period_start)
discount_applied_at = (
subscription.discount_applied_at or subscription.current_period_end
)
if not subscription.discount.is_repetition_expired(
subscription.started_at,
discount_applied_at,
subscription.current_period_end,
subscription.status == SubscriptionStatus.trialing,
):
applicable_discount = subscription.discount

Expand Down
43 changes: 13 additions & 30 deletions server/tests/discount/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,12 +611,12 @@ async def test_code_case_insensitive(

@pytest.mark.asyncio
class TestIsRepetitionExpired:
async def test_once_not_trialing(
async def test_once_first_cycle(
self,
save_fixture: SaveFixture,
organization: Organization,
) -> None:
"""Test that 'once' discount expires immediately when not trialing."""
"""Test that 'once' discount applies only to its first billing cycle."""
discount = await create_discount(
save_fixture,
type=DiscountType.percentage,
Expand All @@ -626,26 +626,12 @@ async def test_once_not_trialing(
)

now = utc_now()
# For non-trialing subscriptions, 'once' discount should expire after first use
assert discount.is_repetition_expired(now, now, False) is True

async def test_once_was_trialing(
self,
save_fixture: SaveFixture,
organization: Organization,
) -> None:
"""Test that 'once' discount does NOT expire when transitioning from trial."""
discount = await create_discount(
save_fixture,
type=DiscountType.percentage,
basis_points=10_000,
duration=DiscountDuration.once,
organization=organization,
)

now = utc_now()
# When transitioning from trial, 'once' discount should still apply
assert discount.is_repetition_expired(now, now, True) is False
next_month = now + timedelta(days=30)
# 'once' discount should apply when discount_applied_at equals current_period_start
# (this is the first cycle where the discount is used)
assert discount.is_repetition_expired(now, now) is False
# 'once' discount should expire for subsequent cycles
assert discount.is_repetition_expired(now, next_month) is True

async def test_forever_never_expires(
self,
Expand All @@ -663,9 +649,8 @@ async def test_forever_never_expires(

now = utc_now()
future = now + timedelta(days=365)
# Forever discounts never expire
assert discount.is_repetition_expired(now, future, False) is False
assert discount.is_repetition_expired(now, future, True) is False
# Forever discounts never expire, regardless of when applied or current period
assert discount.is_repetition_expired(now, future) is False

async def test_repeating_expires_after_duration(
self,
Expand All @@ -686,9 +671,7 @@ async def test_repeating_expires_after_duration(
within_duration = now + timedelta(days=30) # ~1 month
after_duration = now + timedelta(days=120) # ~4 months

# Should not expire within duration
assert discount.is_repetition_expired(now, within_duration, False) is False
# Should not expire within duration (from when discount was first applied)
assert discount.is_repetition_expired(now, within_duration) is False
# Should expire after duration
assert discount.is_repetition_expired(now, after_duration, False) is True
# was_trialing should not affect repeating discounts
assert discount.is_repetition_expired(now, after_duration, True) is True
assert discount.is_repetition_expired(now, after_duration) is True
7 changes: 7 additions & 0 deletions server/tests/fixtures/random_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,13 @@ async def create_subscription(
seats=seats,
past_due_at=past_due_at,
)

# For non-trial subscriptions with a discount, set discount_applied_at to simulate
# the behavior of create_or_update_from_checkout where the discount is applied
# to the first payment at checkout time. Set this explicitly after constructor
# to ensure it's not cleared by the discount "set" event listener.
if discount is not None and trial_start is None:
subscription.discount_applied_at = current_period_start
await save_fixture(subscription)

return subscription
Expand Down
87 changes: 87 additions & 0 deletions server/tests/subscription/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,93 @@ async def test_trial_end_with_once_discount(
# Verify discount is NOW removed (used up after first billing cycle)
assert second_cycle_subscription.discount is None

async def test_trial_end_with_repeating_discount(
self,
session: AsyncSession,
enqueue_job_mock: MagicMock,
enqueue_email_mock: MagicMock,
save_fixture: SaveFixture,
product: Product,
customer: Customer,
organization: Organization,
) -> None:
"""Test that repeating discounts applied during checkout with trial
are properly tracked from the first billing cycle after trial ends."""
# Create a 3-month repeating discount
discount = await create_discount(
save_fixture,
type=DiscountType.fixed,
amount=1000,
currency="usd",
duration=DiscountDuration.repeating,
duration_in_months=3,
organization=organization,
)

# Create trialing subscription with the discount
subscription = await create_trialing_subscription(
save_fixture,
product=product,
customer=customer,
discount=discount,
scheduler_locked_at=utc_now(),
)

# Verify initial state: discount is set but discount_applied_at is None
# (discount hasn't been applied to a billing cycle yet)
assert subscription.discount == discount
assert subscription.discount_applied_at is None
assert subscription.status == SubscriptionStatus.trialing

# Cycle 1: Trial ends, first billing cycle
first_billing_subscription = await subscription_service.cycle(
session, subscription
)

# Verify discount_applied_at is now set to the first billing period start
assert first_billing_subscription.discount == discount
assert first_billing_subscription.discount_applied_at is not None
assert (
first_billing_subscription.discount_applied_at
== first_billing_subscription.current_period_start
)
assert first_billing_subscription.status == SubscriptionStatus.active

# Cycle 2: Second billing cycle (2nd month of discount)
second_billing_subscription = await subscription_service.cycle(
session, first_billing_subscription
)
assert second_billing_subscription.discount == discount

# Cycle 3: Third billing cycle (3rd month of discount)
third_billing_subscription = await subscription_service.cycle(
session, second_billing_subscription
)
assert third_billing_subscription.discount == discount

# Cycle 4: Fourth billing cycle - discount should now be expired
fourth_billing_subscription = await subscription_service.cycle(
session, third_billing_subscription
)
assert fourth_billing_subscription.discount is None

# Verify billing entries - 3 should have discount, 1 should not
billing_entry_repository = BillingEntryRepository.from_session(session)
billing_entries = await billing_entry_repository.get_pending_by_subscription(
subscription.id
)
cycle_entries = [
entry for entry in billing_entries if entry.type == BillingEntryType.cycle
]
assert len(cycle_entries) == 4

# First 3 entries should have discount applied
assert cycle_entries[0].discount == discount
assert cycle_entries[1].discount == discount
assert cycle_entries[2].discount == discount
# Fourth entry should have no discount
assert cycle_entries[3].discount is None


@pytest.mark.asyncio
class TestRevoke:
Expand Down
Loading