diff --git a/ui/contexts/shield/shield-subscription.test.tsx b/ui/contexts/shield/shield-subscription.test.tsx new file mode 100644 index 000000000000..a8186fb49b76 --- /dev/null +++ b/ui/contexts/shield/shield-subscription.test.tsx @@ -0,0 +1,281 @@ +import React, { useEffect, useRef } from 'react'; +import { render, waitFor } from '@testing-library/react'; +import '@testing-library/jest-dom'; +import * as redux from 'react-redux'; +import * as useSubscription from '../../hooks/subscription/useSubscription'; +import * as useSubscriptionMetrics from '../../hooks/shield/metrics/useSubscriptionMetrics'; +import * as selectors from '../../selectors'; +import * as authSelectors from '../../selectors/identity/authentication'; +import * as subscriptionSelectors from '../../selectors/subscription'; +import * as metamaskDucks from '../../ducks/metamask/metamask'; +import * as environment from '../../../shared/modules/environment'; +import { + ShieldSubscriptionProvider, + useShieldSubscriptionContext, +} from './shield-subscription'; + +jest.mock('../../hooks/subscription/useSubscription'); +jest.mock('../../hooks/shield/metrics/useSubscriptionMetrics'); +jest.mock('../../store/actions', () => ({ + assignUserToCohort: jest.fn(), + setPendingShieldCohort: jest.fn(), + setShowShieldEntryModalOnce: jest.fn(), + subscriptionsStartPolling: jest.fn(), +})); + +describe('ShieldSubscriptionProvider', () => { + const mockDispatch = jest.fn(); + const mockGetSubscriptionEligibility = jest.fn(); + const mockCaptureShieldEligibilityCohortEvent = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + + // Mock Redux hooks + jest.spyOn(redux, 'useDispatch').mockReturnValue(mockDispatch); + jest.spyOn(redux, 'useSelector').mockImplementation((selector) => { + if (selector === selectors.getUseExternalServices) { + return true; + } + if (selector === metamaskDucks.getIsUnlocked) { + return true; + } + if (selector === authSelectors.selectIsSignedIn) { + return true; + } + if (selector === subscriptionSelectors.getIsActiveShieldSubscription) { + return false; + } + if (selector === subscriptionSelectors.getHasShieldEntryModalShownOnce) { + return false; + } + return false; + }); + + // Mock environment + jest + .spyOn(environment, 'getIsMetaMaskShieldFeatureEnabled') + .mockReturnValue(true); + + // Mock hooks + jest.spyOn(useSubscription, 'useSubscriptionEligibility').mockReturnValue({ + getSubscriptionEligibility: mockGetSubscriptionEligibility, + }); + + jest + .spyOn(useSubscriptionMetrics, 'useSubscriptionMetrics') + .mockReturnValue({ + captureShieldEligibilityCohortEvent: + mockCaptureShieldEligibilityCohortEvent, + } as unknown as ReturnType< + typeof useSubscriptionMetrics.useSubscriptionMetrics + >); + }); + + it('renders children correctly', () => { + const { getByTestId } = render( + +
Child Component
+
, + ); + + expect(getByTestId('child')).toBeInTheDocument(); + }); + + it('provides context with evaluateCohortEligibility function', () => { + const TestConsumer = () => { + const { evaluateCohortEligibility } = useShieldSubscriptionContext(); + return ( +
+ {typeof evaluateCohortEligibility === 'function' ? 'function' : 'not'} +
+ ); + }; + + const { getByTestId } = render( + + + , + ); + + expect(getByTestId('consumer')).toHaveTextContent('function'); + }); + + describe('Memoization', () => { + it('provides stable evaluateCohortEligibility callback across re-renders', () => { + const callbacks: ((cohort: string) => Promise)[] = []; + + const TestConsumer = () => { + const { evaluateCohortEligibility } = useShieldSubscriptionContext(); + const renderCount = useRef(0); + + useEffect(() => { + callbacks.push(evaluateCohortEligibility); + renderCount.current += 1; + }); + + return
Consumer
; + }; + + const { rerender } = render( + + + , + ); + + // Force re-render + rerender( + + + , + ); + + // The callback should be the same reference across renders + expect(callbacks.length).toBeGreaterThanOrEqual(2); + expect(callbacks[0]).toBe(callbacks[1]); + }); + + it('provides stable context value object across re-renders', () => { + const contexts: { + evaluateCohortEligibility: (cohort: string) => Promise; + }[] = []; + + const TestConsumer = () => { + const context = useShieldSubscriptionContext(); + + useEffect(() => { + contexts.push(context); + }); + + return
Consumer
; + }; + + const { rerender } = render( + + + , + ); + + // Force re-render + rerender( + + + , + ); + + // The context object should be the same reference across renders + expect(contexts.length).toBeGreaterThanOrEqual(2); + expect(contexts[0]).toBe(contexts[1]); + }); + }); + + describe('evaluateCohortEligibility', () => { + it('can be called successfully', async () => { + mockGetSubscriptionEligibility.mockResolvedValue({ + canSubscribe: true, + canViewEntryModal: true, + cohorts: [], + assignedCohort: null, + hasAssignedCohortExpired: false, + modalType: 'entry', + }); + + const evaluateFnRef: { + current: ((cohort: string) => Promise) | null; + } = { current: null }; + + const TestConsumer = () => { + const { evaluateCohortEligibility } = useShieldSubscriptionContext(); + // eslint-disable-next-line react-compiler/react-compiler + evaluateFnRef.current = evaluateCohortEligibility; + return
Consumer
; + }; + + render( + + + , + ); + + await evaluateFnRef.current?.('wallet_home'); + + await waitFor(() => { + expect(mockGetSubscriptionEligibility).toHaveBeenCalled(); + }); + }); + + it('accesses current values even with stable callback', async () => { + let isBasicFunctionalityEnabled = false; + + jest.spyOn(redux, 'useSelector').mockImplementation((selector) => { + if (selector === selectors.getUseExternalServices) { + return isBasicFunctionalityEnabled; + } + if (selector === metamaskDucks.getIsUnlocked) { + return true; + } + if (selector === authSelectors.selectIsSignedIn) { + return true; + } + if (selector === subscriptionSelectors.getIsActiveShieldSubscription) { + return false; + } + if ( + selector === subscriptionSelectors.getHasShieldEntryModalShownOnce + ) { + return false; + } + return false; + }); + + mockGetSubscriptionEligibility.mockResolvedValue({ + canSubscribe: true, + canViewEntryModal: true, + cohorts: [], + assignedCohort: null, + hasAssignedCohortExpired: false, + modalType: 'entry', + }); + + const evaluateFnRef: { + current: ((cohort: string) => Promise) | null; + } = { current: null }; + + const TestConsumer = () => { + const { evaluateCohortEligibility } = useShieldSubscriptionContext(); + // eslint-disable-next-line react-compiler/react-compiler + evaluateFnRef.current = evaluateCohortEligibility; + return
Consumer
; + }; + + const { rerender } = render( + + + , + ); + + await evaluateFnRef.current?.('wallet_home'); + + await waitFor(() => { + expect(mockGetSubscriptionEligibility).not.toHaveBeenCalled(); + }); + + isBasicFunctionalityEnabled = true; + + rerender( + + + , + ); + + await evaluateFnRef.current?.('wallet_home'); + + await waitFor( + () => { + expect(mockGetSubscriptionEligibility).toHaveBeenCalled(); + }, + { timeout: 3000 }, + ); + }); + }); +}); diff --git a/ui/contexts/shield/shield-subscription.tsx b/ui/contexts/shield/shield-subscription.tsx index 32f325fb782e..2042441ab1eb 100644 --- a/ui/contexts/shield/shield-subscription.tsx +++ b/ui/contexts/shield/shield-subscription.tsx @@ -1,4 +1,10 @@ -import React, { useCallback, useContext, useEffect } from 'react'; +import React, { + useCallback, + useContext, + useEffect, + useMemo, + useRef, +} from 'react'; import { useDispatch, useSelector } from 'react-redux'; import { PRODUCT_TYPES, @@ -123,20 +129,20 @@ export const ShieldSubscriptionProvider: React.FC = ({ children }) => { ); /** - * Evaluates cohort eligibility at a specific entrypoint. - * Follows the flowchart logic for cohort assignment and modal display. - * - * Shield entry modal will be shown if: - * - MetaMask Shield feature is enabled - * - Basic functionality is enabled - * - Subscription is not active - * - User is signed in and unlocked - * - User has not shown the shield entry modal before - * - User's balance meets the minimum fiat balance threshold - * - User meets cohort-specific eligibility criteria + * Ref to hold the latest implementation of evaluateCohortEligibility. + * This allows the stable callback to access current values without recreating. */ - const evaluateCohortEligibility = useCallback( - async (entrypointCohort: string): Promise => { + const evaluateCohortEligibilityRef = + useRef<(entrypointCohort: string) => Promise>(); + + /** + * Update the ref with the latest implementation whenever dependencies change. + * This ensures the stable callback always has access to current values. + */ + useEffect(() => { + evaluateCohortEligibilityRef.current = async ( + entrypointCohort: string, + ): Promise => { try { if (!isMetaMaskShieldFeatureEnabled || !isBasicFunctionalityEnabled) { return; @@ -264,19 +270,41 @@ export const ShieldSubscriptionProvider: React.FC = ({ children }) => { ); log.warn('[evaluateCohortEligibility] error', error); } + }; + }, [ + dispatch, + isMetaMaskShieldFeatureEnabled, + isBasicFunctionalityEnabled, + isShieldSubscriptionActive, + isSignedIn, + isUnlocked, + hasShieldEntryModalShownOnce, + getShieldSubscriptionEligibility, + assignToCohort, + captureShieldEligibilityCohortEvent, + ]); + + /** + * Evaluates cohort eligibility at a specific entrypoint. + * Follows the flowchart logic for cohort assignment and modal display. + * + * Shield entry modal will be shown if: + * - MetaMask Shield feature is enabled + * - Basic functionality is enabled + * - Subscription is not active + * - User is signed in and unlocked + * - User has not shown the shield entry modal before + * - User's balance meets the minimum fiat balance threshold + * - User meets cohort-specific eligibility criteria + * + * This callback remains stable across renders using the ref pattern, + * preventing unnecessary re-renders of consuming components. + */ + const evaluateCohortEligibility = useCallback( + async (entrypointCohort: string): Promise => { + await evaluateCohortEligibilityRef.current?.(entrypointCohort); }, - [ - dispatch, - isMetaMaskShieldFeatureEnabled, - isBasicFunctionalityEnabled, - isShieldSubscriptionActive, - isSignedIn, - isUnlocked, - hasShieldEntryModalShownOnce, - getShieldSubscriptionEligibility, - assignToCohort, - captureShieldEligibilityCohortEvent, - ], + [], // Empty deps = always stable! ); useEffect(() => { @@ -297,12 +325,17 @@ export const ShieldSubscriptionProvider: React.FC = ({ children }) => { isBasicFunctionalityEnabled, ]); + /** + * Memoize the context value to prevent creating a new object reference + * on every render, which would cause unnecessary re-renders of consuming components. + */ + const contextValue = useMemo( + () => ({ evaluateCohortEligibility }), + [evaluateCohortEligibility], + ); + return ( - + {children} );