diff --git a/billing/tests/test_views.py b/billing/tests/test_views.py index bc7d57d440..5a33de67d7 100644 --- a/billing/tests/test_views.py +++ b/billing/tests/test_views.py @@ -1150,29 +1150,51 @@ def test_subscription_schedule_updated_logs_changes_to_schedule( assert self.owner.plan == original_plan assert self.owner.plan_user_count == original_quantity - def test_checkout_session_completed_sets_stripe_customer_id(self): + def test_checkout_session_completed_sets_stripe_ids(self): self.owner.stripe_customer_id = None self.owner.save() - expected_id = "fhjtwoo40" + expected_customer_id = "cus_1234" + expected_subscription_id = "sub_7890" self._send_event( payload={ "type": "checkout.session.completed", "data": { "object": { - "customer": expected_id, + "customer": expected_customer_id, "client_reference_id": str(self.owner.ownerid), + "subscription": expected_subscription_id, } }, } ) self.owner.refresh_from_db() - assert self.owner.stripe_customer_id == expected_id + assert self.owner.stripe_customer_id == expected_customer_id + assert self.owner.stripe_subscription_id == expected_subscription_id @patch("billing.views.stripe.Subscription.modify") def test_customer_update_but_not_payment_method(self, subscription_modify_mock): + payment_method = "pm_123" + self._send_event( + payload={ + "type": "customer.updated", + "data": { + "object": { + "invoice_settings": {"default_payment_method": None}, + "subscriptions": { + "data": [{"default_payment_method": payment_method}] + }, + } + }, + } + ) + + subscription_modify_mock.assert_not_called() + + @patch("billing.views.stripe.Subscription.modify") + def test_customer_update_but_payment_method_is_same(self, subscription_modify_mock): payment_method = "pm_123" self._send_event( payload={ diff --git a/billing/views.py b/billing/views.py index cda860b5c5..6c8229831e 100644 --- a/billing/views.py +++ b/billing/views.py @@ -376,6 +376,10 @@ def customer_updated(self, customer: stripe.Customer) -> None: new_default_payment_method = customer["invoice_settings"][ "default_payment_method" ] + + if new_default_payment_method is None: + return + for subscription in customer.get("subscriptions", {}).get("data", []): if new_default_payment_method == subscription["default_payment_method"]: continue @@ -401,6 +405,7 @@ def checkout_session_completed( ) owner = Owner.objects.get(ownerid=checkout_session.client_reference_id) owner.stripe_customer_id = checkout_session.customer + owner.stripe_subscription_id = checkout_session.subscription owner.save() self._log_updated([owner])