Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 68 additions & 41 deletions enterprise_access/apps/api_client/license_manager_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,74 @@ class LicenseManagerApiClient(BaseOAuthClient):
subscription_provisioning_endpoint = api_base_url + 'provisioning-admins/subscriptions/'
subscription_plan_renewal_provisioning_endpoint = api_base_url + 'provisioning-admins/subscription-plan-renewals/'

def list_subscriptions(self, enterprise_customer_uuid):
"""
List subscription plans for an enterprise.

Returns a paginated DRF list response: { count, next, previous, results: [...] }
"""
try:
params = {
'enterprise_customer_uuid': enterprise_customer_uuid,
}

response = self.client.get(
self.subscriptions_endpoint,
params=params,
timeout=settings.LICENSE_MANAGER_CLIENT_TIMEOUT,
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as exc:
logger.exception(
'Failed to list subscriptions for enterprise %s, response: %s, exc: %s',
enterprise_customer_uuid, safe_error_response_content(exc), exc,
)
raise

def update_subscription_plan(self, subscription_uuid, salesforce_opportunity_line_item=None, **kwargs):
"""
Update a SubscriptionPlan's Salesforce Opportunity Line Item.

Arguments:
subscription_uuid (str): UUID of the SubscriptionPlan to update
salesforce_opportunity_line_item (str): Salesforce OLI to associate with the plan

Returns:
dict: Updated subscription plan data from the API

Raises:
APIClientException: If the API call fails
"""
endpoint = f"{self.subscription_provisioning_endpoint}{subscription_uuid}/"
payload = {
'change_reason': OTHER_SUBSCRIPTION_CHANGE_REASON,
}
payload.update(kwargs)
if salesforce_opportunity_line_item:
payload['salesforce_opportunity_line_item'] = salesforce_opportunity_line_item

try:
response = self.client.patch(
endpoint,
json=payload,
timeout=settings.LICENSE_MANAGER_CLIENT_TIMEOUT
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as exc:
logger.exception(
'Failed to update subscription plan %s with OLI %s, response %s, exception: %s',
subscription_uuid,
salesforce_opportunity_line_item,
safe_error_response_content(exc),
exc,
)
raise APIClientException(
f'Could not update subscription plan {subscription_uuid}',
exc,
) from exc

def get_subscription_overview(self, subscription_uuid):
"""
Call license-manager API for data about a SubscriptionPlan.
Expand Down Expand Up @@ -213,47 +281,6 @@ def create_subscription_plan(
exc,
) from exc

def update_subscription_plan(self, subscription_uuid, salesforce_opportunity_line_item):
"""
Update a SubscriptionPlan's Salesforce Opportunity Line Item.

Arguments:
subscription_uuid (str): UUID of the SubscriptionPlan to update
salesforce_opportunity_line_item (str): Salesforce OLI to associate with the plan

Returns:
dict: Updated subscription plan data from the API

Raises:
APIClientException: If the API call fails
"""
endpoint = f"{self.subscription_provisioning_endpoint}{subscription_uuid}/"
payload = {
'salesforce_opportunity_line_item': salesforce_opportunity_line_item,
'change_reason': OTHER_SUBSCRIPTION_CHANGE_REASON,
}

try:
response = self.client.patch(
endpoint,
json=payload,
timeout=settings.LICENSE_MANAGER_CLIENT_TIMEOUT
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as exc:
logger.exception(
'Failed to update subscription plan %s with OLI %s, response %s, exception: %s',
subscription_uuid,
salesforce_opportunity_line_item,
safe_error_response_content(exc),
exc,
)
raise APIClientException(
f'Could not update subscription plan {subscription_uuid}',
exc,
) from exc

def create_subscription_plan_renewal(
self,
prior_subscription_plan_uuid: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,49 @@ def test_create_customer_agreement(self, mock_oauth_client):
json=expected_payload,
)

@mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient', autospec=True)
def test_list_subscriptions_params(self, mock_oauth_client):
mock_get = mock_oauth_client.return_value.get
mock_get.return_value.json.return_value = {'results': []}

lm_client = LicenseManagerApiClient()
enterprise_uuid = 'ec-uuid-123'

# Should only set enterprise_customer_uuid parameter
result = lm_client.list_subscriptions(enterprise_uuid)
self.assertEqual(result, {'results': []})

# Verify URL and params
expected_url = (
'http://license-manager.example.com'
'/api/v1/subscriptions/'
)
mock_get.assert_called_with(
expected_url,
params={'enterprise_customer_uuid': enterprise_uuid},
timeout=settings.LICENSE_MANAGER_CLIENT_TIMEOUT,
)

@mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient', autospec=True)
def test_update_subscription_plan_patch(self, mock_oauth_client):
mock_patch = mock_oauth_client.return_value.patch
mock_patch.return_value.json.return_value = {'uuid': 'plan-uuid', 'is_active': False}

lm_client = LicenseManagerApiClient()
payload = {'is_active': False, 'change_reason': 'delayed_payment'}
result = lm_client.update_subscription_plan('plan-uuid', **payload)

self.assertEqual(result, mock_patch.return_value.json.return_value)
expected_url = (
'http://license-manager.example.com'
'/api/v1/provisioning-admins/subscriptions/plan-uuid/'
)
mock_patch.assert_called_once_with(
expected_url,
json=payload,
timeout=settings.LICENSE_MANAGER_CLIENT_TIMEOUT,
)

@mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient', autospec=True)
def test_create_subscription_plan(self, mock_oauth_client):
mock_post = mock_oauth_client.return_value.post
Expand Down Expand Up @@ -135,7 +178,7 @@ def test_create_subscription_plan(self, mock_oauth_client):
)

@mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient', autospec=True)
def test_update_subscription_plan(self, mock_oauth_client):
def test_update_subscription_plan_oli(self, mock_oauth_client):
mock_patch = mock_oauth_client.return_value.patch
subs_plan_uuid = uuid.uuid4()
new_oli_value = '1234512345'
Expand Down
125 changes: 78 additions & 47 deletions enterprise_access/apps/customer_billing/stripe_event_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from collections.abc import Callable
from functools import wraps
from uuid import UUID

import stripe

Expand All @@ -16,6 +17,7 @@
)
from enterprise_access.apps.customer_billing.stripe_event_types import StripeEventType
from enterprise_access.apps.customer_billing.tasks import (
send_billing_error_email_task,
send_payment_receipt_email,
send_trial_cancellation_email_task,
send_trial_end_and_subscription_started_email_task,
Expand Down Expand Up @@ -102,30 +104,58 @@ def get_checkout_intent_or_raise(checkout_intent_id, event_id) -> CheckoutIntent
"""
try:
checkout_intent = CheckoutIntent.objects.get(id=checkout_intent_id)
return checkout_intent
except CheckoutIntent.DoesNotExist:
logger.warning(
'Could not find CheckoutIntent record with id %s for event %s',
checkout_intent_id, event_id,
)
raise

logger.info(
'Found existing CheckoutIntent record with id=%s, state=%s, for event=%s',
checkout_intent.id, checkout_intent.state, event_id,

def handle_pending_update(subscription_id: str, checkout_intent_id: int, pending_update):
"""
Log pending update information for visibility.
Assumes a pending_update is present.
"""
# TODO: take necessary action on the actual SubscriptionPlan and update the CheckoutIntent.
logger.warning(
"Subscription %s has pending update: %s. checkout_intent_id: %s",
subscription_id,
pending_update,
checkout_intent_id,
)
return checkout_intent


def link_event_data_to_checkout_intent(event, checkout_intent):
"""
Sets the StripeEventData record for the given event to point at the provided CheckoutIntent.
Set the StripeEventData record for the given event to point at the provided CheckoutIntent.
"""
event_data = StripeEventData.objects.get(event_id=event.id)
if not event_data.checkout_intent:
event_data.checkout_intent = checkout_intent
event_data.save() # this triggers a post_save signal that updates the related summary record


def cancel_all_future_plans(checkout_intent):
"""
Deactivate (cancel) all future renewal plans descending from the
anchor plan for this enterprise.
"""
unprocessed_renewals = checkout_intent.renewals.filter(processed_at__isnull=True)
client = LicenseManagerApiClient()
deactivated: list[UUID] = []

for renewal in unprocessed_renewals:
client.update_subscription_plan(
str(renewal.renewed_subscription_plan_uuid),
is_active=False,
)
deactivated.append(renewal.renewed_subscription_plan_uuid)

return deactivated


class StripeEventHandler:
"""
Container for Stripe event handler logic.
Expand Down Expand Up @@ -186,7 +216,8 @@ def invoice_paid(event: stripe.Event) -> None:

logger.info(
'Marking checkout_intent_id=%s as paid via invoice=%s',
checkout_intent_id, invoice.id,
checkout_intent_id,
invoice.id,
)
checkout_intent.mark_as_paid(stripe_customer_id=stripe_customer_id)
link_event_data_to_checkout_intent(event, checkout_intent)
Expand Down Expand Up @@ -224,7 +255,10 @@ def trial_will_end(event: stripe.Event) -> None:
link_event_data_to_checkout_intent(event, checkout_intent)

logger.info(
"Subscription %s trial ending in 72 hours. Queuing trial ending reminder email for checkout_intent_id=%s",
(
"Subscription %s trial ending in 72 hours. "
"Queuing trial ending reminder email for checkout_intent_id=%s"
),
subscription.id,
checkout_intent_id,
)
Expand Down Expand Up @@ -266,9 +300,9 @@ def subscription_created(event: stripe.Event) -> None:
payment_behavior='pending_if_incomplete',
)

logger.info(f'Successfully enabled pending updates for subscription {subscription.id}')
logger.info('Successfully enabled pending updates for subscription %s', subscription.id)
except stripe.StripeError as e:
logger.error(f'Failed to enable pending updates for subscription {subscription.id}: {e}')
logger.error('Failed to enable pending updates for subscription %s: %s', subscription.id, e)

summary = StripeEventSummary.objects.get(event_id=event.id)
summary.update_upcoming_invoice_amount_due()
Expand All @@ -282,30 +316,19 @@ def subscription_updated(event: stripe.Event) -> None:
Send cancellation notification email when a trial subscription is canceled.
"""
subscription = event.data.object
pending_update = getattr(subscription, "pending_update", None)

checkout_intent_id = get_checkout_intent_id_from_subscription(
subscription
)
checkout_intent = get_checkout_intent_or_raise(
checkout_intent_id, event.id
)
checkout_intent_id = get_checkout_intent_id_from_subscription(subscription)
checkout_intent = get_checkout_intent_or_raise(checkout_intent_id, event.id)
link_event_data_to_checkout_intent(event, checkout_intent)

# Pending update
pending_update = getattr(subscription, "pending_update", None)
if pending_update:
# TODO: take necessary action on the actual SubscriptionPlan
# and update the CheckoutIntent.
logger.warning(
"Subscription %s has pending update: %s. checkout_intent_id: %s",
subscription.id,
pending_update,
get_checkout_intent_id_from_subscription(subscription),
)
handle_pending_update(subscription.id, checkout_intent_id, pending_update)

# Handle trial-to-paid transition for renewal processing
current_status = subscription.get("status")
prior_status = getattr(checkout_intent.previous_summary(event), 'subscription_status', None)

# Handle trial-to-paid transition for renewal processing
if prior_status == "trialing" and current_status == "active":
logger.info(
f"Subscription {subscription.id} transitioned from trial to active. "
Expand All @@ -317,30 +340,38 @@ def subscription_updated(event: stripe.Event) -> None:
checkout_intent_id=checkout_intent.id,
)

# Handle trial subscription cancellation
# Check if status changed to canceled to avoid duplicate emails
if current_status == "canceled":
# Only send email if status changed from non-canceled to canceled
if prior_status != 'canceled':
trial_end = subscription.get("trial_end")
if trial_end:
logger.info(
f"Subscription {subscription.id} status changed from '{prior_status}' to 'canceled'. "
f"Queuing trial cancellation email for checkout_intent_id={checkout_intent_id}"
)

send_trial_cancellation_email_task.delay(
checkout_intent_id=checkout_intent.id,
trial_end_timestamp=trial_end,
)
else:
logger.info(
f"Subscription {subscription.id} canceled but has no trial_end, skipping cancellation email"
)
# Trial cancellation transition
if current_status == "canceled" and prior_status != "canceled":
logger.info(
f"Subscription {subscription.id} status changed from '{prior_status}' to 'canceled'. "
)
trial_end = subscription.get("trial_end")
if trial_end:
logger.info(f"Queuing trial cancellation email for checkout_intent_id={checkout_intent_id}")
send_trial_cancellation_email_task.delay(
checkout_intent_id=checkout_intent.id,
trial_end_timestamp=trial_end,
)
else:
logger.info(
f"Subscription {subscription.id} already canceled (status unchanged), skipping cancellation email"
f"Subscription {subscription.id} canceled but has no trial_end, skipping cancellation email"
)

# Past due transition
if current_status == "past_due" and prior_status != "past_due":
enterprise_uuid = checkout_intent.enterprise_uuid
if enterprise_uuid:
cancel_all_future_plans(checkout_intent)
else:
logger.error(
(
"Cannot deactivate future plans for subscription %s: "
"missing enterprise_uuid on CheckoutIntent %s"
),
subscription.id,
checkout_intent.id,
)
send_billing_error_email_task.delay(checkout_intent_id=checkout_intent.id)

@on_stripe_event("customer.subscription.deleted")
@staticmethod
Expand Down
Loading