diff --git a/src/internal/portal/__tests__/portal.test.tsx b/src/internal/portal/__tests__/portal.test.tsx index b5c8b04..a00ee22 100644 --- a/src/internal/portal/__tests__/portal.test.tsx +++ b/src/internal/portal/__tests__/portal.test.tsx @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 import React, { useState } from 'react'; -import { act, fireEvent, render, screen } from '@testing-library/react'; +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'; import { warnOnce } from '../../logging'; import Portal, { PortalProps } from '../index'; @@ -108,6 +108,30 @@ describe('Portal', () => { expect(document.body.contains(container)).toBe(false); }); + test('should support aborting async container setup', async () => { + const container = document.createElement('div'); + const onAbort = jest.fn(); + const onContinue = jest.fn(); + const getContainer: PortalProps['getContainer'] = async ({ abortSignal }) => { + abortSignal.addEventListener('abort', onAbort); + await Promise.resolve(); + onContinue(abortSignal.aborted); + return container; + }; + const removeContainer = jest.fn(); + const { unmount } = renderPortal({ + children:

Hello!

, + getContainer, + removeContainer, + }); + unmount(); + await waitFor(() => { + expect(onContinue).not.toHaveBeenCalled(); + expect(onAbort).toHaveBeenCalled(); + expect(removeContainer).toHaveBeenCalledWith(null); + }); + }); + test('allows conditional change of getContainer/removeContainer', async () => { function MovablePortal({ getContainer, removeContainer }: Pick) { const [visible, setVisible] = useState(false); diff --git a/src/internal/portal/index.tsx b/src/internal/portal/index.tsx index 3582ea9..bc47c5c 100644 --- a/src/internal/portal/index.tsx +++ b/src/internal/portal/index.tsx @@ -8,8 +8,8 @@ import { warnOnce } from '../logging'; export interface PortalProps { container?: null | Element; - getContainer?: () => Promise; - removeContainer?: (container: HTMLElement) => void; + getContainer?: (options: { abortSignal: AbortSignal }) => Promise; + removeContainer?: (container: HTMLElement | null) => void; children: React.ReactNode; } @@ -23,13 +23,17 @@ function manageDefaultContainer(setState: React.Dispatch Promise, - removeContainer: (container: HTMLElement) => void, + getContainer: (options: { abortSignal: AbortSignal }) => Promise, + removeContainer: (container: HTMLElement | null) => void, setState: React.Dispatch> ) { - let newContainer: HTMLElement; - getContainer().then( + let newContainer: HTMLElement | null = null; + const abortController = new AbortController(); + getContainer({ abortSignal: abortController.signal }).then( container => { + if (abortController.signal.aborted) { + return; + } newContainer = container; setState(container); }, @@ -38,6 +42,7 @@ function manageAsyncContainer( } ); return () => { + abortController.abort(); removeContainer(newContainer); }; }