Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export function getSubscriptionServiceMessenger(
'SubscriptionController:getState',
'SubscriptionController:submitShieldSubscriptionCryptoApproval',
'SubscriptionController:linkRewards',
'SubscriptionController:clearLastSelectedPaymentMethod',
'AppStateController:getState',
'AppStateController:setPendingShieldCohort',
'AuthenticationController:getBearerToken',
Expand Down
10 changes: 10 additions & 0 deletions app/scripts/services/subscription/subscription-service.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ const mockGetRewardSeasonMetadata = jest.fn();
const mockGetHasAccountOptedIn = jest.fn();
const mockLinkRewards = jest.fn();
const mockSubmitShieldSubscriptionCryptoApproval = jest.fn();
const mockClearLastSelectedPaymentMethod = jest.fn();

const rootMessenger: RootMessenger = new Messenger({
namespace: MOCK_ANY_NAMESPACE,
Expand Down Expand Up @@ -171,6 +172,10 @@ rootMessenger.registerActionHandler(
'SubscriptionController:submitShieldSubscriptionCryptoApproval',
mockSubmitShieldSubscriptionCryptoApproval,
);
rootMessenger.registerActionHandler(
'SubscriptionController:clearLastSelectedPaymentMethod',
mockClearLastSelectedPaymentMethod,
);

const messenger: SubscriptionServiceMessenger = new Messenger({
namespace: 'SubscriptionService',
Expand All @@ -184,6 +189,7 @@ rootMessenger.delegate({
'SubscriptionController:submitSponsorshipIntents',
'SubscriptionController:linkRewards',
'SubscriptionController:submitShieldSubscriptionCryptoApproval',
'SubscriptionController:clearLastSelectedPaymentMethod',
'TransactionController:getTransactions',
'PreferencesController:getState',
'AccountsController:getState',
Expand Down Expand Up @@ -441,6 +447,10 @@ describe('SubscriptionService - startSubscriptionWithCard', () => {
1000,
),
).rejects.toThrow(SHIELD_ERROR.subscriptionPollingTimedOut);

expect(mockClearLastSelectedPaymentMethod).toHaveBeenCalledWith(
PRODUCT_TYPES.SHIELD,
);
});
});

Expand Down
12 changes: 12 additions & 0 deletions app/scripts/services/subscription/subscription-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,12 @@ export class SubscriptionService {
},
);

// Clear cached payment method after failed/cancelled payment - metrics already captured
this.#messenger.call(
'SubscriptionController:clearLastSelectedPaymentMethod',
PRODUCT_TYPES.SHIELD,
);

// fetch latest subscriptions to update the state in case subscription already created error (not when polling timed out)
if (errorMessage.toLocaleLowerCase().includes('already exists')) {
await this.#messenger.call('SubscriptionController:getSubscriptions');
Expand Down Expand Up @@ -525,6 +531,12 @@ export class SubscriptionService {
);
}

// Clear cached payment method after failed crypto payment - metrics already captured
this.#messenger.call(
'SubscriptionController:clearLastSelectedPaymentMethod',
PRODUCT_TYPES.SHIELD,
);

this.#captureException(
createSentryError(
'Error on Shield subscription approval transaction',
Expand Down
2 changes: 2 additions & 0 deletions app/scripts/services/subscription/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
SubscriptionControllerGetStateAction,
SubscriptionControllerLinkRewardsAction,
SubscriptionControllerSubmitShieldSubscriptionCryptoApprovalAction,
SubscriptionControllerClearLastSelectedPaymentMethodAction,
} from '@metamask/subscription-controller';
import { AuthenticationControllerGetBearerToken } from '@metamask/profile-sync-controller/auth';
import {
Expand Down Expand Up @@ -57,6 +58,7 @@ export type SubscriptionServiceAction =
| SubscriptionControllerGetStateAction
| SubscriptionControllerLinkRewardsAction
| SubscriptionControllerSubmitShieldSubscriptionCryptoApprovalAction
| SubscriptionControllerClearLastSelectedPaymentMethodAction
| TransactionControllerGetTransactionsAction
| PreferencesControllerGetStateAction
| AccountsControllerGetStateAction
Expand Down
31 changes: 29 additions & 2 deletions ui/pages/shield-plan/shield-plan.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import React, { useCallback, useEffect, useMemo, useState } from 'react';
import React, {
useCallback,
useEffect,
useMemo,
useRef,
useState,
} from 'react';
import {
PAYMENT_TYPES,
PaymentType,
Expand Down Expand Up @@ -251,9 +257,23 @@ const ShieldPlan = () => {
}, [selectedPlan, setSelectedToken]);

const selectedTokenAddress = selectedToken?.address;

// Track if initial payment method selection has been done
// This prevents auto-switching after payment cancel/failure when cache is cleared
const hasInitializedPaymentMethod = useRef(false);

// set default selected payment method to crypto if selected token available
// should only trigger if selectedTokenAddress change (shouldn't trigger again if selected token object updated but still same token)
useEffect(() => {
// Skip auto-selection after initial setup to prevent switching after payment cancel
if (hasInitializedPaymentMethod.current) {
// Only handle the case when selectedTokenAddress becomes undefined (no tokens available)
if (!selectedTokenAddress) {
setSelectedPaymentMethod(PAYMENT_TYPES.byCard);
}
return;
}

const lastUsedPaymentMethod = lastUsedPaymentDetails?.type;
if (
selectedTokenAddress &&
Expand Down Expand Up @@ -295,6 +315,13 @@ const ShieldPlan = () => {
rewardPoints: claimedRewardsPoints ?? undefined,
});

const onStartSubscription = useCallback(() => {
// set flag to prevent auto-switching payment method after payment cancel/failure
hasInitializedPaymentMethod.current = true;

handleSubscription();
}, [handleSubscription]);

const handleUserChangeToken = useCallback(
async (token: TokenWithApprovalAmount) => {
setSelectedToken(token);
Expand Down Expand Up @@ -628,7 +655,7 @@ const ShieldPlan = () => {
size={ButtonSize.Lg}
variant={ButtonVariant.Primary}
isFullWidth
onClick={handleSubscription}
onClick={onStartSubscription}
data-testid="shield-plan-continue-button"
>
{t('continue')}
Expand Down
Loading