diff --git a/src/execution/__tests__/mapAsyncIterable-test.ts b/src/execution/__tests__/mapAsyncIterable-test.ts index ac3753c9b4..8237582c58 100644 --- a/src/execution/__tests__/mapAsyncIterable-test.ts +++ b/src/execution/__tests__/mapAsyncIterable-test.ts @@ -264,6 +264,48 @@ describe('mapAsyncIterable', () => { await expectPromise(thrown).toRejectWith(message); }); + it('close source when mapped iterable is thrown even when the underlying source does not implement a throw method', async () => { + const items = [1, 2, 3]; + let returned = false; + const iterable: AsyncIterableIterator = { + [Symbol.asyncIterator]() { + return this; + }, + next() { + if (returned) { + return Promise.resolve({ done: true, value: undefined }); + } + const value = items[0]; + items.shift(); + return Promise.resolve({ + done: items.length === 0, + value, + }); + }, + return: () => { + returned = true; + return Promise.resolve({ done: true, value: undefined }); + }, + }; + + const doubles = mapAsyncIterable(iterable, (x) => x + x); + + expect(await doubles.next()).to.deep.equal({ value: 2, done: false }); + expect(await doubles.next()).to.deep.equal({ value: 4, done: false }); + + // Throw error + const message = 'allows throwing errors when mapping async iterable'; + const thrown = doubles.throw(new Error(message)); + await expectPromise(thrown).toRejectWith(message); + + // Returns early when throwing errors through async iterable + expect(returned).to.equal(true); + expect(await doubles.next()).to.deep.equal({ + value: undefined, + done: true, + }); + }); + it('passes through caught errors through async generators', async () => { async function* source() { try { diff --git a/src/execution/mapAsyncIterable.ts b/src/execution/mapAsyncIterable.ts index d85c7c4959..e3511bc4c4 100644 --- a/src/execution/mapAsyncIterable.ts +++ b/src/execution/mapAsyncIterable.ts @@ -30,17 +30,19 @@ export function mapAsyncIterable( try { return { value: await callback(value), done: false }; } catch (error) { - /* c8 ignore start */ - // FIXME: add test case - if (typeof iterator.return === 'function') { - try { - await iterator.return(); - } catch (_e) { - /* ignore error */ - } - } + await returnIgnoringErrors(); throw error; - /* c8 ignore stop */ + } + } + + async function returnIgnoringErrors(): Promise { + if (typeof iterator.return === 'function') { + try { + await iterator.return(); /* c8 ignore start */ + } catch (_error) { + // FIXME: add test case + /* ignore error */ + } /* c8 ignore stop */ } } @@ -58,6 +60,11 @@ export function mapAsyncIterable( if (typeof iterator.throw === 'function') { return mapResult(iterator.throw(error)); } + + if (typeof iterator.return === 'function') { + await returnIgnoringErrors(); + } + throw error; }, [Symbol.asyncIterator]() {