diff --git a/ui/contexts/shield/shield-subscription.test.tsx b/ui/contexts/shield/shield-subscription.test.tsx
new file mode 100644
index 000000000000..a8186fb49b76
--- /dev/null
+++ b/ui/contexts/shield/shield-subscription.test.tsx
@@ -0,0 +1,281 @@
+import React, { useEffect, useRef } from 'react';
+import { render, waitFor } from '@testing-library/react';
+import '@testing-library/jest-dom';
+import * as redux from 'react-redux';
+import * as useSubscription from '../../hooks/subscription/useSubscription';
+import * as useSubscriptionMetrics from '../../hooks/shield/metrics/useSubscriptionMetrics';
+import * as selectors from '../../selectors';
+import * as authSelectors from '../../selectors/identity/authentication';
+import * as subscriptionSelectors from '../../selectors/subscription';
+import * as metamaskDucks from '../../ducks/metamask/metamask';
+import * as environment from '../../../shared/modules/environment';
+import {
+ ShieldSubscriptionProvider,
+ useShieldSubscriptionContext,
+} from './shield-subscription';
+
+jest.mock('../../hooks/subscription/useSubscription');
+jest.mock('../../hooks/shield/metrics/useSubscriptionMetrics');
+jest.mock('../../store/actions', () => ({
+ assignUserToCohort: jest.fn(),
+ setPendingShieldCohort: jest.fn(),
+ setShowShieldEntryModalOnce: jest.fn(),
+ subscriptionsStartPolling: jest.fn(),
+}));
+
+describe('ShieldSubscriptionProvider', () => {
+ const mockDispatch = jest.fn();
+ const mockGetSubscriptionEligibility = jest.fn();
+ const mockCaptureShieldEligibilityCohortEvent = jest.fn();
+
+ beforeEach(() => {
+ jest.clearAllMocks();
+
+ // Mock Redux hooks
+ jest.spyOn(redux, 'useDispatch').mockReturnValue(mockDispatch);
+ jest.spyOn(redux, 'useSelector').mockImplementation((selector) => {
+ if (selector === selectors.getUseExternalServices) {
+ return true;
+ }
+ if (selector === metamaskDucks.getIsUnlocked) {
+ return true;
+ }
+ if (selector === authSelectors.selectIsSignedIn) {
+ return true;
+ }
+ if (selector === subscriptionSelectors.getIsActiveShieldSubscription) {
+ return false;
+ }
+ if (selector === subscriptionSelectors.getHasShieldEntryModalShownOnce) {
+ return false;
+ }
+ return false;
+ });
+
+ // Mock environment
+ jest
+ .spyOn(environment, 'getIsMetaMaskShieldFeatureEnabled')
+ .mockReturnValue(true);
+
+ // Mock hooks
+ jest.spyOn(useSubscription, 'useSubscriptionEligibility').mockReturnValue({
+ getSubscriptionEligibility: mockGetSubscriptionEligibility,
+ });
+
+ jest
+ .spyOn(useSubscriptionMetrics, 'useSubscriptionMetrics')
+ .mockReturnValue({
+ captureShieldEligibilityCohortEvent:
+ mockCaptureShieldEligibilityCohortEvent,
+ } as unknown as ReturnType<
+ typeof useSubscriptionMetrics.useSubscriptionMetrics
+ >);
+ });
+
+ it('renders children correctly', () => {
+ const { getByTestId } = render(
+
+ Child Component
+ ,
+ );
+
+ expect(getByTestId('child')).toBeInTheDocument();
+ });
+
+ it('provides context with evaluateCohortEligibility function', () => {
+ const TestConsumer = () => {
+ const { evaluateCohortEligibility } = useShieldSubscriptionContext();
+ return (
+
+ {typeof evaluateCohortEligibility === 'function' ? 'function' : 'not'}
+
+ );
+ };
+
+ const { getByTestId } = render(
+
+
+ ,
+ );
+
+ expect(getByTestId('consumer')).toHaveTextContent('function');
+ });
+
+ describe('Memoization', () => {
+ it('provides stable evaluateCohortEligibility callback across re-renders', () => {
+ const callbacks: ((cohort: string) => Promise)[] = [];
+
+ const TestConsumer = () => {
+ const { evaluateCohortEligibility } = useShieldSubscriptionContext();
+ const renderCount = useRef(0);
+
+ useEffect(() => {
+ callbacks.push(evaluateCohortEligibility);
+ renderCount.current += 1;
+ });
+
+ return Consumer
;
+ };
+
+ const { rerender } = render(
+
+
+ ,
+ );
+
+ // Force re-render
+ rerender(
+
+
+ ,
+ );
+
+ // The callback should be the same reference across renders
+ expect(callbacks.length).toBeGreaterThanOrEqual(2);
+ expect(callbacks[0]).toBe(callbacks[1]);
+ });
+
+ it('provides stable context value object across re-renders', () => {
+ const contexts: {
+ evaluateCohortEligibility: (cohort: string) => Promise;
+ }[] = [];
+
+ const TestConsumer = () => {
+ const context = useShieldSubscriptionContext();
+
+ useEffect(() => {
+ contexts.push(context);
+ });
+
+ return Consumer
;
+ };
+
+ const { rerender } = render(
+
+
+ ,
+ );
+
+ // Force re-render
+ rerender(
+
+
+ ,
+ );
+
+ // The context object should be the same reference across renders
+ expect(contexts.length).toBeGreaterThanOrEqual(2);
+ expect(contexts[0]).toBe(contexts[1]);
+ });
+ });
+
+ describe('evaluateCohortEligibility', () => {
+ it('can be called successfully', async () => {
+ mockGetSubscriptionEligibility.mockResolvedValue({
+ canSubscribe: true,
+ canViewEntryModal: true,
+ cohorts: [],
+ assignedCohort: null,
+ hasAssignedCohortExpired: false,
+ modalType: 'entry',
+ });
+
+ const evaluateFnRef: {
+ current: ((cohort: string) => Promise) | null;
+ } = { current: null };
+
+ const TestConsumer = () => {
+ const { evaluateCohortEligibility } = useShieldSubscriptionContext();
+ // eslint-disable-next-line react-compiler/react-compiler
+ evaluateFnRef.current = evaluateCohortEligibility;
+ return Consumer
;
+ };
+
+ render(
+
+
+ ,
+ );
+
+ await evaluateFnRef.current?.('wallet_home');
+
+ await waitFor(() => {
+ expect(mockGetSubscriptionEligibility).toHaveBeenCalled();
+ });
+ });
+
+ it('accesses current values even with stable callback', async () => {
+ let isBasicFunctionalityEnabled = false;
+
+ jest.spyOn(redux, 'useSelector').mockImplementation((selector) => {
+ if (selector === selectors.getUseExternalServices) {
+ return isBasicFunctionalityEnabled;
+ }
+ if (selector === metamaskDucks.getIsUnlocked) {
+ return true;
+ }
+ if (selector === authSelectors.selectIsSignedIn) {
+ return true;
+ }
+ if (selector === subscriptionSelectors.getIsActiveShieldSubscription) {
+ return false;
+ }
+ if (
+ selector === subscriptionSelectors.getHasShieldEntryModalShownOnce
+ ) {
+ return false;
+ }
+ return false;
+ });
+
+ mockGetSubscriptionEligibility.mockResolvedValue({
+ canSubscribe: true,
+ canViewEntryModal: true,
+ cohorts: [],
+ assignedCohort: null,
+ hasAssignedCohortExpired: false,
+ modalType: 'entry',
+ });
+
+ const evaluateFnRef: {
+ current: ((cohort: string) => Promise) | null;
+ } = { current: null };
+
+ const TestConsumer = () => {
+ const { evaluateCohortEligibility } = useShieldSubscriptionContext();
+ // eslint-disable-next-line react-compiler/react-compiler
+ evaluateFnRef.current = evaluateCohortEligibility;
+ return Consumer
;
+ };
+
+ const { rerender } = render(
+
+
+ ,
+ );
+
+ await evaluateFnRef.current?.('wallet_home');
+
+ await waitFor(() => {
+ expect(mockGetSubscriptionEligibility).not.toHaveBeenCalled();
+ });
+
+ isBasicFunctionalityEnabled = true;
+
+ rerender(
+
+
+ ,
+ );
+
+ await evaluateFnRef.current?.('wallet_home');
+
+ await waitFor(
+ () => {
+ expect(mockGetSubscriptionEligibility).toHaveBeenCalled();
+ },
+ { timeout: 3000 },
+ );
+ });
+ });
+});
diff --git a/ui/contexts/shield/shield-subscription.tsx b/ui/contexts/shield/shield-subscription.tsx
index 32f325fb782e..2042441ab1eb 100644
--- a/ui/contexts/shield/shield-subscription.tsx
+++ b/ui/contexts/shield/shield-subscription.tsx
@@ -1,4 +1,10 @@
-import React, { useCallback, useContext, useEffect } from 'react';
+import React, {
+ useCallback,
+ useContext,
+ useEffect,
+ useMemo,
+ useRef,
+} from 'react';
import { useDispatch, useSelector } from 'react-redux';
import {
PRODUCT_TYPES,
@@ -123,20 +129,20 @@ export const ShieldSubscriptionProvider: React.FC = ({ children }) => {
);
/**
- * Evaluates cohort eligibility at a specific entrypoint.
- * Follows the flowchart logic for cohort assignment and modal display.
- *
- * Shield entry modal will be shown if:
- * - MetaMask Shield feature is enabled
- * - Basic functionality is enabled
- * - Subscription is not active
- * - User is signed in and unlocked
- * - User has not shown the shield entry modal before
- * - User's balance meets the minimum fiat balance threshold
- * - User meets cohort-specific eligibility criteria
+ * Ref to hold the latest implementation of evaluateCohortEligibility.
+ * This allows the stable callback to access current values without recreating.
*/
- const evaluateCohortEligibility = useCallback(
- async (entrypointCohort: string): Promise => {
+ const evaluateCohortEligibilityRef =
+ useRef<(entrypointCohort: string) => Promise>();
+
+ /**
+ * Update the ref with the latest implementation whenever dependencies change.
+ * This ensures the stable callback always has access to current values.
+ */
+ useEffect(() => {
+ evaluateCohortEligibilityRef.current = async (
+ entrypointCohort: string,
+ ): Promise => {
try {
if (!isMetaMaskShieldFeatureEnabled || !isBasicFunctionalityEnabled) {
return;
@@ -264,19 +270,41 @@ export const ShieldSubscriptionProvider: React.FC = ({ children }) => {
);
log.warn('[evaluateCohortEligibility] error', error);
}
+ };
+ }, [
+ dispatch,
+ isMetaMaskShieldFeatureEnabled,
+ isBasicFunctionalityEnabled,
+ isShieldSubscriptionActive,
+ isSignedIn,
+ isUnlocked,
+ hasShieldEntryModalShownOnce,
+ getShieldSubscriptionEligibility,
+ assignToCohort,
+ captureShieldEligibilityCohortEvent,
+ ]);
+
+ /**
+ * Evaluates cohort eligibility at a specific entrypoint.
+ * Follows the flowchart logic for cohort assignment and modal display.
+ *
+ * Shield entry modal will be shown if:
+ * - MetaMask Shield feature is enabled
+ * - Basic functionality is enabled
+ * - Subscription is not active
+ * - User is signed in and unlocked
+ * - User has not shown the shield entry modal before
+ * - User's balance meets the minimum fiat balance threshold
+ * - User meets cohort-specific eligibility criteria
+ *
+ * This callback remains stable across renders using the ref pattern,
+ * preventing unnecessary re-renders of consuming components.
+ */
+ const evaluateCohortEligibility = useCallback(
+ async (entrypointCohort: string): Promise => {
+ await evaluateCohortEligibilityRef.current?.(entrypointCohort);
},
- [
- dispatch,
- isMetaMaskShieldFeatureEnabled,
- isBasicFunctionalityEnabled,
- isShieldSubscriptionActive,
- isSignedIn,
- isUnlocked,
- hasShieldEntryModalShownOnce,
- getShieldSubscriptionEligibility,
- assignToCohort,
- captureShieldEligibilityCohortEvent,
- ],
+ [], // Empty deps = always stable!
);
useEffect(() => {
@@ -297,12 +325,17 @@ export const ShieldSubscriptionProvider: React.FC = ({ children }) => {
isBasicFunctionalityEnabled,
]);
+ /**
+ * Memoize the context value to prevent creating a new object reference
+ * on every render, which would cause unnecessary re-renders of consuming components.
+ */
+ const contextValue = useMemo(
+ () => ({ evaluateCohortEligibility }),
+ [evaluateCohortEligibility],
+ );
+
return (
-
+
{children}
);