diff --git a/.changeset/thin-jobs-grab.md b/.changeset/thin-jobs-grab.md new file mode 100644 index 0000000000..ac806dac83 --- /dev/null +++ b/.changeset/thin-jobs-grab.md @@ -0,0 +1,5 @@ +--- +'livekit-client': patch +--- + +Adds new OutgoingDataTrackManager to manage sending data track payloads diff --git a/src/index.ts b/src/index.ts index d8a0179e84..1c18cd5c28 100644 --- a/src/index.ts +++ b/src/index.ts @@ -15,7 +15,7 @@ import * as attributes from './room/attribute-typings'; // FIXME: remove this import in a follow up data track pull request. import './room/data-track/depacketizer'; // FIXME: remove this import in a follow up data track pull request. -import './room/data-track/packetizer'; +import './room/data-track/outgoing/OutgoingDataTrackManager'; import LocalParticipant from './room/participant/LocalParticipant'; import Participant, { ConnectionQuality, ParticipantKind } from './room/participant/Participant'; import type { ParticipantTrackPermission } from './room/participant/ParticipantTrackPermission'; diff --git a/src/room/data-track/depacketizer.test.ts b/src/room/data-track/depacketizer.test.ts index 9cf74123a8..9a35ec9069 100644 --- a/src/room/data-track/depacketizer.test.ts +++ b/src/room/data-track/depacketizer.test.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; -import { DataTrackDepacketizer } from './depacketizer'; +import DataTrackDepacketizer from './depacketizer'; import { DataTrackHandle } from './handle'; import { DataTrackPacket, DataTrackPacketHeader, FrameMarker } from './packet'; import { DataTrackTimestamp, WrapAroundUnsignedInt } from './utils'; diff --git a/src/room/data-track/depacketizer.ts b/src/room/data-track/depacketizer.ts index 9f88908b6b..730193afe0 100644 --- a/src/room/data-track/depacketizer.ts +++ b/src/room/data-track/depacketizer.ts @@ -85,7 +85,7 @@ type PushOptions = { errorOnPartialFrames: boolean; }; -export class DataTrackDepacketizer { +export default class DataTrackDepacketizer { /** Maximum number of packets to buffer per frame before dropping. */ static MAX_BUFFER_PACKETS = 128; diff --git a/src/room/data-track/e2ee.ts b/src/room/data-track/e2ee.ts new file mode 100644 index 0000000000..29064787ae --- /dev/null +++ b/src/room/data-track/e2ee.ts @@ -0,0 +1,14 @@ +export type EncryptedPayload = { + payload: Uint8Array; + iv: Uint8Array; // NOTE: should be 12 bytes long + keyIndex: number; +}; + +export type EncryptionProvider = { + // FIXME: add in explicit `Throws<..., EncryptionError>`? + encrypt(payload: Uint8Array): EncryptedPayload; +}; + +export type DecryptionProvider = { + decrypt(payload: Uint8Array, senderIdentity: string): Uint8Array; +}; diff --git a/src/room/data-track/frame.ts b/src/room/data-track/frame.ts index 0e3e2ed689..56dbee9b56 100644 --- a/src/room/data-track/frame.ts +++ b/src/room/data-track/frame.ts @@ -1,5 +1,5 @@ import { DataTrackExtensions } from './packet/extensions'; -import { DataTrackPacketizer } from './packetizer'; +import DataTrackPacketizer from './packetizer'; /** A pair of payload bytes and packet extensions which can be fed into a {@link DataTrackPacketizer}. */ export type DataTrackFrame = { diff --git a/src/room/data-track/handle.test.ts b/src/room/data-track/handle.test.ts index 81592e46e3..c486c4a8c0 100644 --- a/src/room/data-track/handle.test.ts +++ b/src/room/data-track/handle.test.ts @@ -4,7 +4,7 @@ import { DataTrackHandle } from './handle'; describe('DataTrackHandle', () => { it('should parse handle raw inputs', () => { - expect(DataTrackHandle.fromNumber(3).value).toEqual(3); + expect(DataTrackHandle.fromNumber(3)).toEqual(3); expect(() => DataTrackHandle.fromNumber(0)).toThrow('0x0 is a reserved value'); expect(() => DataTrackHandle.fromNumber(9999999)).toThrow( 'Value too large to be a valid track handle', diff --git a/src/room/data-track/handle.ts b/src/room/data-track/handle.ts index c02c6bc24c..58f7688026 100644 --- a/src/room/data-track/handle.ts +++ b/src/room/data-track/handle.ts @@ -41,10 +41,9 @@ export class DataTrackHandleError< } } -export class DataTrackHandle { - public value: number; - - static fromNumber( +export type DataTrackHandle = number; +export const DataTrackHandle = { + fromNumber( raw: number, ): Throws< DataTrackHandle, @@ -57,24 +56,20 @@ export class DataTrackHandle { if (raw > U16_MAX_SIZE) { throw DataTrackHandleError.tooLarge(); } - return new DataTrackHandle(raw); - } - - constructor(raw: number) { - this.value = raw; - } -} + return raw; + }, +}; /** Manage allocating new handles which don't conflict over the lifetime of the client. */ export class DataTrackHandleAllocator { - static value = 0; + value = 0; /** Returns a unique track handle for the next publication, if one can be obtained. */ - static get(): DataTrackHandle | null { + get(): DataTrackHandle | null { this.value += 1; if (this.value > U16_MAX_SIZE) { return null; } - return new DataTrackHandle(this.value); + return this.value; } } diff --git a/src/room/data-track/outgoing/OutgoingDataTrackManager.test.ts b/src/room/data-track/outgoing/OutgoingDataTrackManager.test.ts new file mode 100644 index 0000000000..cea8f6f8c7 --- /dev/null +++ b/src/room/data-track/outgoing/OutgoingDataTrackManager.test.ts @@ -0,0 +1,392 @@ +/* eslint-disable @typescript-eslint/no-unused-vars */ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { subscribeToEvents } from '../../../utils/subscribeToEvents'; +import { EncryptionProvider } from '../e2ee'; +import { DataTrackHandle } from '../handle'; +import { DataTrackPacket, FrameMarker } from '../packet'; +import OutgoingDataTrackManager, { + DataTrackOutgoingManagerCallbacks, + Descriptor, +} from './OutgoingDataTrackManager'; +import { DataTrackPublishError } from './errors'; + +/** A fake "encryption" provider used for test purposes. Adds a prefix to the payload. */ +const PrefixingEncryptionProvider: EncryptionProvider = { + encrypt(payload: Uint8Array) { + const prefix = new Uint8Array([0xde, 0xad, 0xbe, 0xef]); + + const output = new Uint8Array(prefix.length + payload.length); + output.set(prefix, 0); + output.set(payload, prefix.length); + + return { + payload: output, + iv: new Uint8Array(12), // Just leaving this empty, is this a bad idea? + keyIndex: 0, + }; + }, +}; + +describe('DataTrackOutgoingManager', () => { + it('should test track publishing (ok case)', async () => { + const manager = new OutgoingDataTrackManager(); + const managerEvents = subscribeToEvents(manager, [ + 'sfuPublishRequest', + ]); + + // 1. Publish a data track + const publishRequestPromise = manager.publishRequest({ name: 'test' }); + + // 2. This publish request should be sent along to the SFU + const sfuPublishEvent = await managerEvents.waitFor('sfuPublishRequest'); + expect(sfuPublishEvent.name).toStrictEqual('test'); + expect(sfuPublishEvent.usesE2ee).toStrictEqual(false); + const handle = sfuPublishEvent.handle; + + // 3. Respond to the SFU publish request with an OK response + manager.receivedSfuPublishResponse(handle, { + type: 'ok', + data: { + sid: 'bogus-sid', + pubHandle: sfuPublishEvent.handle, + name: 'test', + usesE2ee: false, + }, + }); + + // Make sure that the original input event resolves. + const localDataTrack = await publishRequestPromise; + expect(localDataTrack.isPublished()).toStrictEqual(true); + }); + + it('should test track publishing (error case)', async () => { + const manager = new OutgoingDataTrackManager(); + const managerEvents = subscribeToEvents(manager, [ + 'sfuPublishRequest', + ]); + + // 1. Publish a data track + const publishRequestPromise = manager.publishRequest({ name: 'test' }); + + // 2. This publish request should be sent along to the SFU + const sfuPublishEvent = await managerEvents.waitFor('sfuPublishRequest'); + + // 3. Respond to the SFU publish request with an ERROR response + manager.receivedSfuPublishResponse(sfuPublishEvent.handle, { + type: 'error', + error: DataTrackPublishError.limitReached(), + }); + + // Make sure that the rejection bubbles back to the caller + expect(publishRequestPromise).rejects.toThrowError('Data track publication limit reached'); + }); + + it('should test track publishing (cancellation half way through)', async () => { + const manager = new OutgoingDataTrackManager(); + const managerEvents = subscribeToEvents(manager, [ + 'sfuPublishRequest', + 'sfuUnpublishRequest', + ]); + + // 1. Publish a data track + const controller = new AbortController(); + const publishRequestPromise = manager.publishRequest({ name: 'test' }, controller.signal); + + // 2. This publish request should be sent along to the SFU + const sfuPublishEvent = await managerEvents.waitFor('sfuPublishRequest'); + expect(sfuPublishEvent.name).toStrictEqual('test'); + expect(sfuPublishEvent.usesE2ee).toStrictEqual(false); + const handle = sfuPublishEvent.handle; + + // 3. Explictly cancel the publish + controller.abort(); + + // 4. Make sure an unpublish event is sent so that the SFU cleans up things properly + // on its end as well + const sfuUnpublishEvent = await managerEvents.waitFor('sfuUnpublishRequest'); + expect(sfuUnpublishEvent.handle).toStrictEqual(handle); + + // 5. Make sure cancellation is bubbled up as an error to stop further execution + expect(publishRequestPromise).rejects.toStrictEqual(DataTrackPublishError.cancelled()); + }); + + it.each([ + // Single packet payload case + [ + new Uint8Array([0x01, 0x02, 0x03, 0x04, 0x05]), + [ + { + header: { + extensions: { + e2ee: null, + userTimestamp: null, + }, + frameNumber: 0, + marker: FrameMarker.Single, + sequence: 0, + timestamp: expect.anything(), + trackHandle: 5, + }, + payload: new Uint8Array([0x01, 0x02, 0x03, 0x04, 0x05]), + }, + ], + ], + + // Multi packet payload case + [ + new Uint8Array(24_000).fill(0xbe), + [ + { + header: { + extensions: { + e2ee: null, + userTimestamp: null, + }, + frameNumber: 0, + marker: FrameMarker.Start, + sequence: 0, + timestamp: expect.anything(), + trackHandle: 5, + }, + payload: new Uint8Array(15988 /* 16k mtu - 12 header bytes */).fill(0xbe), + }, + { + header: { + extensions: { + e2ee: null, + userTimestamp: null, + }, + frameNumber: 0, + marker: FrameMarker.Final, + sequence: 1, + timestamp: expect.anything(), + trackHandle: 5, + }, + payload: new Uint8Array(8012 /* 24k payload - (16k mtu - 12 header bytes) */).fill(0xbe), + }, + ], + ], + ])( + 'should test track payload sending', + async (inputBytes: Uint8Array, outputPacketsJson: Array) => { + // Create a manager prefilled with a descriptor + const manager = OutgoingDataTrackManager.withDescriptors( + new Map([ + [ + DataTrackHandle.fromNumber(5), + Descriptor.active( + { + sid: 'bogus-sid', + pubHandle: 5, + name: 'test', + usesE2ee: false, + }, + null, + ), + ], + ]), + ); + const managerEvents = subscribeToEvents(manager, [ + 'packetsAvailable', + ]); + + const localDataTrack = manager.createLocalDataTrack(5)!; + expect(localDataTrack).not.toStrictEqual(null); + + // Kick off sending the bytes... + localDataTrack.tryPush(inputBytes); + + // ... and make sure the corresponding events are emitted to tell the SFU to send the packets + for (const outputPacketJson of outputPacketsJson) { + const packetBytes = await managerEvents.waitFor('packetsAvailable'); + const [packet] = DataTrackPacket.fromBinary(packetBytes.bytes); + + expect(packet.toJSON()).toStrictEqual(outputPacketJson); + } + }, + ); + + it('should send e2ee encrypted datatrack payload', async () => { + const manager = new OutgoingDataTrackManager({ + encryptionProvider: PrefixingEncryptionProvider, + }); + const managerEvents = subscribeToEvents(manager, [ + 'sfuPublishRequest', + 'packetsAvailable', + ]); + + // 1. Publish a data track + const publishRequestPromise = manager.publishRequest({ name: 'test' }); + + // 2. This publish request should be sent along to the SFU + const sfuPublishEvent = await managerEvents.waitFor('sfuPublishRequest'); + expect(sfuPublishEvent.name).toStrictEqual('test'); + expect(sfuPublishEvent.usesE2ee).toStrictEqual(true); // NOTE: this is true, e2ee is enabled! + const handle = sfuPublishEvent.handle; + + // 3. Respond to the SFU publish request with an OK response + manager.receivedSfuPublishResponse(handle, { + type: 'ok', + data: { + sid: 'bogus-sid', + pubHandle: sfuPublishEvent.handle, + name: 'test', + usesE2ee: true, // NOTE: this is true, e2ee is enabled! + }, + }); + + // Get the connected local data track + const localDataTrack = await publishRequestPromise; + expect(localDataTrack.isPublished()).toStrictEqual(true); + + // Kick off sending the payload bytes + localDataTrack.tryPush(new Uint8Array([0x01, 0x02, 0x03, 0x04, 0x05])); + + // Make sure the packet that was sent was encrypted with the PrefixingEncryptionProvider + const packetBytes = await managerEvents.waitFor('packetsAvailable'); + const [packet] = DataTrackPacket.fromBinary(packetBytes.bytes); + + expect(packet.toJSON()).toStrictEqual({ + header: { + extensions: { + e2ee: { + iv: new Uint8Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + keyIndex: 0, + lengthBytes: 13, + tag: 1, + }, + userTimestamp: null, + }, + frameNumber: 0, + marker: 3, + sequence: 0, + timestamp: expect.anything(), + trackHandle: 1, + }, + payload: new Uint8Array([ + // Encryption added prefix + 0xde, 0xad, 0xbe, 0xef, + // Actual payload + 0x01, 0x02, 0x03, 0x04, 0x05, + ]), + }); + }); + + it('should test track unpublishing', async () => { + // Create a manager prefilled with a descriptor + const manager = OutgoingDataTrackManager.withDescriptors( + new Map([ + [ + DataTrackHandle.fromNumber(5), + Descriptor.active( + { + sid: 'bogus-sid', + pubHandle: 5, + name: 'test', + usesE2ee: false, + }, + null, + ), + ], + ]), + ); + const managerEvents = subscribeToEvents(manager, [ + 'sfuUnpublishRequest', + ]); + + // Make sure the descriptor is in there + expect(manager.getDescriptor(5)?.type).toStrictEqual('active'); + + // Unpublish data track + const unpublishRequestPromise = manager.unpublishRequest(DataTrackHandle.fromNumber(5)); + + const sfuUnpublishEvent = await managerEvents.waitFor('sfuUnpublishRequest'); + expect(sfuUnpublishEvent.handle).toStrictEqual(5); + + manager.receivedSfuUnpublishResponse(DataTrackHandle.fromNumber(5)); + + await unpublishRequestPromise; + + // Make sure data track is no longer + expect(manager.getDescriptor(5)).toStrictEqual(null); + }); + + it('should query currently active descriptors', async () => { + // Create a manager prefilled with a descriptor + const manager = OutgoingDataTrackManager.withDescriptors( + new Map([ + [ + DataTrackHandle.fromNumber(2), + Descriptor.active( + { + sid: 'bogus-sid-2', + pubHandle: 2, + name: 'twotwotwo', + usesE2ee: false, + }, + null, + ), + ], + [ + DataTrackHandle.fromNumber(6), + Descriptor.active( + { + sid: 'bogus-sid-6', + pubHandle: 6, + name: 'sixsixsix', + usesE2ee: false, + }, + null, + ), + ], + ]), + ); + + const result = await manager.queryPublished(); + + expect(result).toStrictEqual([ + { sid: 'bogus-sid-2', pubHandle: 2, name: 'twotwotwo', usesE2ee: false }, + { sid: 'bogus-sid-6', pubHandle: 6, name: 'sixsixsix', usesE2ee: false }, + ]); + }); + + it('should shutdown cleanly', async () => { + // Create a manager prefilled with a descriptor + const pendingDescriptor = Descriptor.pending(); + const manager = OutgoingDataTrackManager.withDescriptors( + new Map([ + [DataTrackHandle.fromNumber(2), pendingDescriptor], + [ + DataTrackHandle.fromNumber(6), + Descriptor.active( + { + sid: 'bogus-sid-6', + pubHandle: 6, + name: 'sixsixsix', + usesE2ee: false, + }, + null, + ), + ], + ]), + ); + const managerEvents = subscribeToEvents(manager, [ + 'sfuUnpublishRequest', + ]); + + // Shut down the manager + const shutdownPromise = manager.shutdown(); + + // The pending data track should be cancelled + expect(pendingDescriptor.completionFuture.promise).rejects.toThrowError('Room disconnected'); + + // And the active data track should be requested to be unpublished + const unpublishEvent = await managerEvents.waitFor('sfuUnpublishRequest'); + expect(unpublishEvent.handle).toStrictEqual(6); + + // Acknowledge that the unpublish has occurred + manager.receivedSfuUnpublishResponse(DataTrackHandle.fromNumber(6)); + + await shutdownPromise; + }); +}); diff --git a/src/room/data-track/outgoing/OutgoingDataTrackManager.ts b/src/room/data-track/outgoing/OutgoingDataTrackManager.ts new file mode 100644 index 0000000000..244616bbb5 --- /dev/null +++ b/src/room/data-track/outgoing/OutgoingDataTrackManager.ts @@ -0,0 +1,302 @@ +import { EventEmitter } from 'events'; +import type TypedEmitter from 'typed-emitter'; +import { LoggerNames, getLogger } from '../../../logger'; +import type { Throws } from '../../../utils/throws'; +import { Future } from '../../utils'; +import { type EncryptionProvider } from '../e2ee'; +import type { DataTrackFrame } from '../frame'; +import { DataTrackHandle, DataTrackHandleAllocator } from '../handle'; +import { DataTrackExtensions } from '../packet/extensions'; +import { type DataTrackInfo, LocalDataTrack } from '../track'; +import { + DataTrackPublishError, + DataTrackPublishErrorReason, + DataTrackPushFrameError, + DataTrackPushFrameErrorReason, +} from './errors'; +import DataTrackOutgoingPipeline from './pipeline'; +import { + type DataTrackOptions, + type OutputEventPacketsAvailable, + type OutputEventSfuPublishRequest, + type OutputEventSfuUnpublishRequest, + type SfuPublishResponseResult, +} from './types'; + +const log = getLogger(LoggerNames.DataTracks); + +export type PendingDescriptor = { + type: 'pending'; + completionFuture: Future< + LocalDataTrack, + | DataTrackPublishError + | DataTrackPublishError + | DataTrackPublishError + | DataTrackPublishError + | DataTrackPublishError + | DataTrackPublishError + >; +}; +export type ActiveDescriptor = { + type: 'active'; + info: DataTrackInfo; + + pipeline: DataTrackOutgoingPipeline; + + /** Resolves when the descriptor is unpublished. */ + unpublishingFuture: Future; +}; +export type Descriptor = PendingDescriptor | ActiveDescriptor; + +export const Descriptor = { + pending(): PendingDescriptor { + return { + type: 'pending', + completionFuture: new Future(), + }; + }, + active(info: DataTrackInfo, encryptionProvider: EncryptionProvider | null): ActiveDescriptor { + return { + type: 'active', + info, + pipeline: new DataTrackOutgoingPipeline({ info, encryptionProvider }), + unpublishingFuture: new Future(), + }; + }, +}; + +export type DataTrackOutgoingManagerCallbacks = { + /** Request sent to the SFU to publish a track. */ + sfuPublishRequest: (event: OutputEventSfuPublishRequest) => void; + /** Request sent to the SFU to unpublish a track. */ + sfuUnpublishRequest: (event: OutputEventSfuUnpublishRequest) => void; + /** Serialized packets are ready to be sent over the transport. */ + packetsAvailable: (event: OutputEventPacketsAvailable) => void; +}; + +type DataTrackLocalManagerOptions = { + /** + * Provider to use for encrypting outgoing frame payloads. + * + * If none, end-to-end encryption will be disabled for all published tracks. + */ + encryptionProvider?: EncryptionProvider; +}; + +/** How long to wait when attempting to publish before timing out. */ +const PUBLISH_TIMEOUT_MILLISECONDS = 10_000; + +export default class OutgoingDataTrackManager extends (EventEmitter as new () => TypedEmitter) { + private encryptionProvider: EncryptionProvider | null; + + private handleAllocator = new DataTrackHandleAllocator(); + + private descriptors = new Map(); + + constructor(options?: DataTrackLocalManagerOptions) { + super(); + this.encryptionProvider = options?.encryptionProvider ?? null; + } + + static withDescriptors(descriptors: Map) { + const manager = new OutgoingDataTrackManager(); + manager.descriptors = descriptors; + return manager; + } + + /** + * Used by attached {@link LocalDataTrack} instances to query their associated descriptor info. + * @internal + */ + getDescriptor(handle: DataTrackHandle) { + return this.descriptors.get(handle) ?? null; + } + + createLocalDataTrack(handle: DataTrackHandle) { + const descriptor = this.getDescriptor(handle); + if (descriptor?.type !== 'active') { + return null; + } + return new LocalDataTrack(descriptor.info, this); + } + + /** Used by attached {@link LocalDataTrack} instances to broadcast data track packets to other + * subscribers. + * @internal + */ + tryProcessAndSend( + handle: DataTrackHandle, + payload: Uint8Array, + ): Throws< + void, + | DataTrackPushFrameError + | DataTrackPushFrameError + > { + const descriptor = this.getDescriptor(handle); + if (descriptor?.type !== 'active') { + throw DataTrackPushFrameError.trackUnpublished(); + } + + const frame: DataTrackFrame = { + payload, + extensions: new DataTrackExtensions(), + }; + + try { + for (const packet of descriptor.pipeline.processFrame(frame)) { + this.emit('packetsAvailable', { bytes: packet.toBinary() }); + } + } catch (err) { + // NOTE: In the rust implementation this "dropped" error means something different (not enough room + // in the track mpsc channel) + throw DataTrackPushFrameError.dropped(err); + } + } + + /** Client requested to publish a track. */ + async publishRequest(options: DataTrackOptions, signal?: AbortSignal) { + const handle = this.handleAllocator.get(); + if (!handle) { + throw DataTrackPublishError.limitReached(); + } + + const timeoutSignal = AbortSignal.timeout(PUBLISH_TIMEOUT_MILLISECONDS); + const combinedSignal = signal ? AbortSignal.any([signal, timeoutSignal]) : timeoutSignal; + + if (this.descriptors.has(handle)) { + // @throws-transformer ignore - this should be treated as a "panic" and not be caught + throw new Error('Descriptor for handle already exists'); + } + + const descriptor = Descriptor.pending(); + this.descriptors.set(handle, descriptor); + + const onAbort = () => { + const existingDescriptor = this.descriptors.get(handle); + if (!existingDescriptor) { + log.warn(`No descriptor for ${handle}`); + return; + } + this.descriptors.delete(handle); + + // Let the SFU know that the publish has been cancelled + this.emit('sfuUnpublishRequest', { handle }); + + if (existingDescriptor.type === 'pending') { + existingDescriptor.completionFuture.reject?.( + timeoutSignal.aborted + ? DataTrackPublishError.timeout() + : // NOTE: the below cancelled case was introduced by web / there isn't a corresponding case in the rust version. + DataTrackPublishError.cancelled(), + ); + } + }; + combinedSignal.addEventListener('abort', onAbort); + + this.emit('sfuPublishRequest', { + handle, + name: options.name, + usesE2ee: this.encryptionProvider !== null, + }); + + const localDataTrack = await descriptor.completionFuture.promise; + combinedSignal.removeEventListener('abort', onAbort); + return localDataTrack; + } + + /** Get information about all currently published tracks. */ + async queryPublished() { + const descriptorInfos = Array.from(this.descriptors.values()) + .filter((descriptor): descriptor is ActiveDescriptor => descriptor.type === 'active') + .map((descriptor) => descriptor.info); + + return descriptorInfos; + } + + /** Client request to unpublish a track. */ + async unpublishRequest(handle: DataTrackHandle) { + const descriptor = this.descriptors.get(handle); + if (!descriptor) { + log.warn(`No descriptor for ${handle}`); + return; + } + if (descriptor.type !== 'active') { + log.warn(`Track ${handle} not active`); + return; + } + + this.emit('sfuUnpublishRequest', { handle }); + + await descriptor.unpublishingFuture.promise; + } + + /** SFU responded to a request to publish a data track. */ + receivedSfuPublishResponse(handle: DataTrackHandle, result: SfuPublishResponseResult) { + const descriptor = this.descriptors.get(handle); + if (!descriptor) { + log.warn(`No descriptor for ${handle}`); + return; + } + this.descriptors.delete(handle); + + if (descriptor.type !== 'pending') { + log.warn(`Track ${handle} already active`); + return; + } + + if (result.type === 'ok') { + const info = result.data; + + const encryptionProvider = info.usesE2ee ? this.encryptionProvider : null; + this.descriptors.set(info.pubHandle, Descriptor.active(info, encryptionProvider)); + + const localDataTrack = this.createLocalDataTrack(info.pubHandle); + if (!localDataTrack) { + // @throws-transformer ignore - this should be treated as a "panic" and not be caught + throw new Error( + 'DataTrackOutgoingManager.handleSfuPublishResponse: localDataTrack was not created after active descriptor stored.', + ); + } + + descriptor.completionFuture.resolve?.(localDataTrack); + } else { + descriptor.completionFuture.reject?.(result.error); + } + } + + /** SFU notification that a track has been unpublished. */ + receivedSfuUnpublishResponse(handle: DataTrackHandle) { + const descriptor = this.descriptors.get(handle); + if (!descriptor) { + log.warn(`No descriptor for ${handle}`); + return; + } + this.descriptors.delete(handle); + + if (descriptor.type !== 'active') { + log.warn(`Track ${handle} not active`); + return; + } + + descriptor.unpublishingFuture.resolve?.(); + } + + /** Shuts down the manager and all associated tracks. */ + async shutdown() { + for (const descriptor of this.descriptors.values()) { + switch (descriptor.type) { + case 'pending': + descriptor.completionFuture.reject?.(DataTrackPublishError.disconnected()); + break; + case 'active': + // Abandon any unpublishing descriptors that were in flight and assume they will get + // cleaned up automatically with the connection shutdown. + descriptor.unpublishingFuture.resolve?.(); + + await this.unpublishRequest(descriptor.info.pubHandle); + break; + } + } + this.descriptors.clear(); + } +} diff --git a/src/room/data-track/outgoing/errors.ts b/src/room/data-track/outgoing/errors.ts new file mode 100644 index 0000000000..2340c23593 --- /dev/null +++ b/src/room/data-track/outgoing/errors.ts @@ -0,0 +1,157 @@ +import { LivekitReasonedError } from '../../errors'; +import { DataTrackPacketizerError, DataTrackPacketizerReason } from '../packetizer'; + +export enum DataTrackPublishErrorReason { + /** + * Local participant does not have permission to publish data tracks. + * + * Ensure the participant's token contains the `canPublishData` grant. + */ + NotAllowed = 0, + + /** A track with the same name is already published by the local participant. */ + DuplicateName = 1, + + /** Request to publish the track took long to complete. */ + Timeout = 2, + + /** No additional data tracks can be published by the local participant. */ + LimitReached = 3, + + /** Cannot publish data track when the room is disconnected. */ + Disconnected = 4, + + // NOTE: this was introduced by web / there isn't a corresponding case in the rust version. + Cancelled = 5, +} + +export class DataTrackPublishError< + Reason extends DataTrackPublishErrorReason, +> extends LivekitReasonedError { + readonly name = 'DataTrackPublishError'; + + reason: Reason; + + reasonName: string; + + constructor(message: string, reason: Reason, options?: { cause?: unknown }) { + super(21, message, options); + this.reason = reason; + this.reasonName = DataTrackPublishErrorReason[reason]; + } + + static notAllowed() { + return new DataTrackPublishError( + 'Data track publishing unauthorized', + DataTrackPublishErrorReason.NotAllowed, + ); + } + + static duplicateName() { + return new DataTrackPublishError( + 'Track name already taken', + DataTrackPublishErrorReason.DuplicateName, + ); + } + + static timeout() { + return new DataTrackPublishError( + 'Publish data track timed-out', + DataTrackPublishErrorReason.Timeout, + ); + } + + static limitReached() { + return new DataTrackPublishError( + 'Data track publication limit reached', + DataTrackPublishErrorReason.LimitReached, + ); + } + + static disconnected() { + return new DataTrackPublishError('Room disconnected', DataTrackPublishErrorReason.Disconnected); + } + + // NOTE: this was introduced by web / there isn't a corresponding case in the rust version. + static cancelled() { + return new DataTrackPublishError( + 'Publish data track cancelled by caller', + DataTrackPublishErrorReason.Cancelled, + ); + } +} + +export enum DataTrackPushFrameErrorReason { + /** Track is no longer published. */ + TrackUnpublished = 0, + /** Frame was dropped. */ + // NOTE: this should become a web specific error, the rust version of this "dropped" error means + // something different and will be renamed to "QueueFull". + Dropped = 1, +} + +export class DataTrackPushFrameError< + Reason extends DataTrackPushFrameErrorReason, +> extends LivekitReasonedError { + readonly name = 'DataTrackPushFrameError'; + + reason: Reason; + + reasonName: string; + + constructor(message: string, reason: Reason, options?: { cause?: unknown }) { + super(22, message, options); + this.reason = reason; + this.reasonName = DataTrackPushFrameErrorReason[reason]; + } + + static trackUnpublished() { + return new DataTrackPushFrameError( + 'Track is no longer published', + DataTrackPushFrameErrorReason.TrackUnpublished, + ); + } + + static dropped(cause: unknown) { + return new DataTrackPushFrameError('Frame was dropped', DataTrackPushFrameErrorReason.Dropped, { + cause, + }); + } +} + +export enum DataTrackOutgoingPipelineErrorReason { + Packetizer = 0, + Encryption = 1, +} + +export class DataTrackOutgoingPipelineError< + Reason extends DataTrackOutgoingPipelineErrorReason, +> extends LivekitReasonedError { + readonly name = 'DataTrackOutgoingPipelineError'; + + reason: Reason; + + reasonName: string; + + constructor(message: string, reason: Reason, options?: { cause?: unknown }) { + super(21, message, options); + this.reason = reason; + this.reasonName = DataTrackOutgoingPipelineErrorReason[reason]; + } + + static packetizer(cause: DataTrackPacketizerError) { + return new DataTrackOutgoingPipelineError( + 'Error packetizing frame', + DataTrackOutgoingPipelineErrorReason.Packetizer, + { cause }, + ); + } + + static encryption(cause: unknown) { + return new DataTrackOutgoingPipelineError( + 'Error encrypting frame', + DataTrackOutgoingPipelineErrorReason.Encryption, + { cause }, + ); + } +} diff --git a/src/room/data-track/outgoing/pipeline.ts b/src/room/data-track/outgoing/pipeline.ts new file mode 100644 index 0000000000..f4f0c252af --- /dev/null +++ b/src/room/data-track/outgoing/pipeline.ts @@ -0,0 +1,76 @@ +import { type Throws } from '../../../utils/throws'; +import { type EncryptedPayload, type EncryptionProvider } from '../e2ee'; +import { type DataTrackFrame } from '../frame'; +import { DataTrackPacket } from '../packet'; +import { DataTrackE2eeExtension } from '../packet/extensions'; +import DataTrackPacketizer, { DataTrackPacketizerError } from '../packetizer'; +import type { DataTrackInfo } from '../track'; +import { DataTrackOutgoingPipelineError, DataTrackOutgoingPipelineErrorReason } from './errors'; + +type Options = { + info: DataTrackInfo; + encryptionProvider: EncryptionProvider | null; +}; + +/** Processes outgoing frames into final packets for distribution to the SFU. */ +export default class DataTrackOutgoingPipeline { + private encryptionProvider: EncryptionProvider | null; + + private packetizer: DataTrackPacketizer; + + /** Maximum transmission unit (MTU) of the transport. */ + private static TRANSPORT_MTU_BYTES = 16_000; + + constructor(options: Options) { + this.encryptionProvider = options.encryptionProvider; + this.packetizer = new DataTrackPacketizer( + options.info.pubHandle, + DataTrackOutgoingPipeline.TRANSPORT_MTU_BYTES, + ); + } + + *processFrame( + frame: DataTrackFrame, + ): Throws< + Generator, + | DataTrackOutgoingPipelineError + | DataTrackOutgoingPipelineError + > { + const encryptedFrame = this.encryptIfNeeded(frame); + + try { + yield* this.packetizer.packetize(encryptedFrame); + } catch (error) { + if (error instanceof DataTrackPacketizerError) { + throw DataTrackOutgoingPipelineError.packetizer(error); + } + throw error; + } + } + + encryptIfNeeded( + frame: DataTrackFrame, + ): Throws< + DataTrackFrame, + DataTrackOutgoingPipelineError + > { + if (!this.encryptionProvider) { + return frame; + } + + let encryptedResult: EncryptedPayload; + try { + encryptedResult = this.encryptionProvider.encrypt(frame.payload); + } catch (err) { + throw DataTrackOutgoingPipelineError.encryption(err); + } + + frame.payload = encryptedResult.payload; + frame.extensions.e2ee = new DataTrackE2eeExtension( + encryptedResult.keyIndex, + encryptedResult.iv, + ); + + return frame; + } +} diff --git a/src/room/data-track/outgoing/types.ts b/src/room/data-track/outgoing/types.ts new file mode 100644 index 0000000000..95492ffbd7 --- /dev/null +++ b/src/room/data-track/outgoing/types.ts @@ -0,0 +1,37 @@ +import { type DataTrackHandle } from '../handle'; +import { type DataTrackInfo } from '../track'; +import { type DataTrackPublishError, type DataTrackPublishErrorReason } from './errors'; + +/** Options for publishing a data track. */ +export type DataTrackOptions = { + name: string; +}; + +/** Encodes whether a data track publish request to the SFU has been successful or not. */ +export type SfuPublishResponseResult = + | { type: 'ok'; data: DataTrackInfo } + | { + type: 'error'; + error: + | DataTrackPublishError + | DataTrackPublishError + | DataTrackPublishError; + }; + +/** Request sent to the SFU to publish a track. */ +export type OutputEventSfuPublishRequest = { + handle: DataTrackHandle; + name: string; + usesE2ee: boolean; +}; + +/** Request sent to the SFU to unpublish a track. */ +export type OutputEventSfuUnpublishRequest = { + handle: DataTrackHandle; +}; + +/** Serialized packets are ready to be sent over the transport. */ +export type OutputEventPacketsAvailable = { + bytes: Uint8Array; + signal?: AbortSignal; +}; diff --git a/src/room/data-track/packet/index.ts b/src/room/data-track/packet/index.ts index c3b59a551f..74e4d89a19 100644 --- a/src/room/data-track/packet/index.ts +++ b/src/room/data-track/packet/index.ts @@ -126,7 +126,7 @@ export class DataTrackPacketHeader extends Serializable { dataView.setUint8(byteIndex, 0); // Reserved byteIndex += U8_LENGTH_BYTES; - dataView.setUint16(byteIndex, this.trackHandle.value); + dataView.setUint16(byteIndex, this.trackHandle); byteIndex += U16_LENGTH_BYTES; dataView.setUint16(byteIndex, this.sequence.value); byteIndex += U16_LENGTH_BYTES; @@ -277,7 +277,7 @@ export class DataTrackPacketHeader extends Serializable { toJSON() { return { marker: this.marker, - trackHandle: this.trackHandle.value, + trackHandle: this.trackHandle, sequence: this.sequence.value, frameNumber: this.frameNumber.value, timestamp: this.timestamp.asTicks(), diff --git a/src/room/data-track/packetizer.test.ts b/src/room/data-track/packetizer.test.ts index 6912a5dba7..bbd224cb24 100644 --- a/src/room/data-track/packetizer.test.ts +++ b/src/room/data-track/packetizer.test.ts @@ -4,7 +4,7 @@ import { DataTrackFrame } from './frame'; import { DataTrackHandle } from './handle'; import { FrameMarker } from './packet'; import { DataTrackExtensions } from './packet/extensions'; -import { DataTrackPacketizer } from './packetizer'; +import DataTrackPacketizer from './packetizer'; import { DataTrackTimestamp } from './utils'; describe('DataTrackPacketizer', () => { @@ -13,7 +13,7 @@ describe('DataTrackPacketizer', () => { const packets = Array.from( packetizer.packetize( { - payload: new Uint8Array(300).fill(0xbe).buffer, + payload: new Uint8Array(300).fill(0xbe), extensions: new DataTrackExtensions(), }, { now: DataTrackTimestamp.fromRtpTicks(1804548298) }, @@ -91,7 +91,7 @@ describe('DataTrackPacketizer', () => { const packetizer = new DataTrackPacketizer(DataTrackHandle.fromNumber(1), mtuSizeBytes); const frame: DataTrackFrame = { - payload: new Uint8Array(payloadSizeBytes).fill(0xab).buffer, + payload: new Uint8Array(payloadSizeBytes).fill(0xab), extensions: new DataTrackExtensions(), }; const packets = Array.from( diff --git a/src/room/data-track/packetizer.ts b/src/room/data-track/packetizer.ts index eba9346fcc..0e6a8fe7cd 100644 --- a/src/room/data-track/packetizer.ts +++ b/src/room/data-track/packetizer.ts @@ -40,7 +40,7 @@ export enum DataTrackPacketizerReason { /** A packetizer takes a {@link DataTrackFrame} as input and generates a series * of {@link DataTrackPacket}s for transmission to other clients over webrtc. */ -export class DataTrackPacketizer { +export default class DataTrackPacketizer { private handle: DataTrackHandle; private mtuSizeBytes: number; @@ -120,7 +120,11 @@ export class DataTrackPacketizer { // ... and the last packet will be as long as it needs to be to finish out the buffer. frame.payload.byteLength - indexBytes, ); - const packetPayload = new Uint8Array(frame.payload, indexBytes, packetPayloadLengthBytes); + const packetPayload = new Uint8Array( + frame.payload.buffer, + frame.payload.byteOffset + indexBytes, + packetPayloadLengthBytes, + ); yield new DataTrackPacket(packetHeader, packetPayload); } diff --git a/src/room/data-track/track.ts b/src/room/data-track/track.ts new file mode 100644 index 0000000000..5a2888ab5e --- /dev/null +++ b/src/room/data-track/track.ts @@ -0,0 +1,50 @@ +import type { DataTrackFrame } from './frame'; +import { type DataTrackHandle } from './handle'; +import type OutgoingDataTrackManager from './outgoing/OutgoingDataTrackManager'; + +export type DataTrackSid = string; + +/** Information about a published data track. */ +export type DataTrackInfo = { + sid: DataTrackSid; + pubHandle: DataTrackHandle; + name: String; + usesE2ee: boolean; +}; + +export class LocalDataTrack { + info: DataTrackInfo; + + protected manager: OutgoingDataTrackManager; + + constructor(info: DataTrackInfo, manager: OutgoingDataTrackManager) { + this.info = info; + this.manager = manager; + } + + /** The raw descriptor from the manager containing the internal state for this local track. */ + protected get descriptor() { + return this.manager.getDescriptor(this.info.pubHandle); + } + + isPublished() { + return this.descriptor?.type === 'active'; + } + + /** Try pushing a frame to subscribers of the track. + * + * Pushing a frame can fail for several reasons: + * + * - The track has been unpublished by the local participant or SFU + * - The room is no longer connected + */ + tryPush(payload: DataTrackFrame['payload']) { + try { + return this.manager.tryProcessAndSend(this.info.pubHandle, payload); + } catch (err) { + // NOTE: wrapping in the bare try/catch like this means that the Throws<...> type doesn't + // propegate upwards into the public interface. + throw err; + } + } +} diff --git a/src/room/participant/LocalParticipant.ts b/src/room/participant/LocalParticipant.ts index 495ccc80dd..c6b9432844 100644 --- a/src/room/participant/LocalParticipant.ts +++ b/src/room/participant/LocalParticipant.ts @@ -275,6 +275,8 @@ export default class LocalParticipant extends Participant { private handleClosing = () => { if (this.reconnectFuture) { + // @throws-transformer ignore - introduced due to adding Throws into Future, investigate this + // further this.reconnectFuture.promise.catch((e) => this.log.warn(e.message, this.logContext)); this.reconnectFuture?.reject?.(new Error('Got disconnected during reconnection attempt')); this.reconnectFuture = undefined; diff --git a/src/room/utils.ts b/src/room/utils.ts index 42e071ccbc..c7694b0e97 100644 --- a/src/room/utils.ts +++ b/src/room/utils.ts @@ -8,6 +8,7 @@ import { import TypedPromise from '../utils/TypedPromise'; import { getBrowser } from '../utils/browserParser'; import type { BrowserDetails } from '../utils/browserParser'; +import { type Throws } from '../utils/throws'; import { protocolVersion, version } from '../version'; import { type ConnectionError, ConnectionErrorReason } from './errors'; import type LocalParticipant from './participant/LocalParticipant'; @@ -457,8 +458,12 @@ export function getStereoAudioStreamTrack() { return stereoTrack; } +/** An object that represents a serialized version of a `new Promise((resolve, reject) => {})` + * constructor. Wait for a promise resolution with `await future.promise` and explicitly resolve or + * reject the inner promise with `future.resolve(...)` or `future.reject(...)`. + */ export class Future { - promise: Promise; + promise: Promise>; resolve?: (arg: T) => void; @@ -486,7 +491,7 @@ export class Future { }).finally(() => { this._isResolved = true; this.onFinally?.(); - }); + }) as Promise>; } } diff --git a/src/utils/subscribeToEvents.ts b/src/utils/subscribeToEvents.ts new file mode 100644 index 0000000000..8d9b623c06 --- /dev/null +++ b/src/utils/subscribeToEvents.ts @@ -0,0 +1,63 @@ +import { type EventMap } from 'typed-emitter'; +import type TypedEventEmitter from 'typed-emitter'; +import { Future } from '../room/utils'; + +/** A test helper to listen to events received by an event emitter and allow them to be imperatively + * queried after the fact. */ +export function subscribeToEvents< + Callbacks extends EventMap, + EventNames extends keyof Callbacks = keyof Callbacks, +>(eventEmitter: TypedEventEmitter, eventNames: Array) { + const nextEventListeners = new Map>>( + eventNames.map((eventName) => [eventName, []]), + ); + const buffers = new Map>( + eventNames.map((eventName) => [eventName, []]), + ); + + const eventHandlers = eventNames.map((eventName) => { + const onEvent = ((event: unknown) => { + const listeners = nextEventListeners.get(eventName)!; + if (listeners.length > 0) { + for (const listener of listeners) { + listener.resolve?.(event); + } + nextEventListeners.set(eventName, []); + } else { + buffers.get(eventName)!.push(event); + } + }) as Callbacks[keyof Callbacks]; + return [eventName, onEvent] as [keyof Callbacks, Callbacks[keyof Callbacks]]; + }); + for (const [eventName, onEvent] of eventHandlers) { + eventEmitter.on(eventName, onEvent); + } + + return { + /** Listen for the next occurrance of an event to be emitted, or return the last event that was + * buffered (but hasn't been processed yet). */ + async waitFor< + EventPayload extends Parameters[0], + EventName extends EventNames = EventNames, + >(eventName: EventName): Promise { + // If an event is already buffered which hasn't been processed yet, pull that off the buffer + // and use it. + const earliestBufferedEvent = buffers.get(eventName)!.shift(); + if (earliestBufferedEvent) { + return earliestBufferedEvent as EventPayload; + } + + // Otherwise wait for the next event to come in. + const future = new Future(); + nextEventListeners.get(eventName)!.push(future); + const nextEvent = await future.promise; + return nextEvent as EventPayload; + }, + /** Cleanup any lingering subscriptions. */ + unsubscribe: () => { + for (const [eventName, onEvent] of eventHandlers) { + eventEmitter.off(eventName, onEvent); + } + }, + }; +} diff --git a/throws-transformer/engine.ts b/throws-transformer/engine.ts index 8a68b66186..7a95e1dc81 100644 --- a/throws-transformer/engine.ts +++ b/throws-transformer/engine.ts @@ -8,8 +8,6 @@ */ import * as ts from "typescript"; -import * as path from "path"; -import { sync as globSync } from "glob"; // Symbol name for the Throws type brand const THROWS_BRAND = "__throws"; @@ -79,6 +77,26 @@ export function checkSourceFile( return results; } +function preceededByIgnoreComment(node: ts.Node, sourceFile: ts.SourceFile) { + const foundComments = ts.getLeadingCommentRanges(sourceFile.text, node.pos); + if (foundComments) { + const foundCommentsText = foundComments.map((info) => { + return sourceFile + .text + .slice(info.pos, info.end) + .replace(/^(\/\/|\/\*)\s*/ /* Remove leading comment prefix */, ''); + }); + + const isIgnoreComment = foundCommentsText.find((commentText) => { + return commentText.startsWith('@throws-transformer ignore'); + }); + + return isIgnoreComment; + } else { + return false; + } +} + function checkThrowStatement( node: ts.ThrowStatement, sourceFile: ts.SourceFile, @@ -106,22 +124,8 @@ function checkThrowStatement( // Check to see if there is a comment about the throw site starting with "@throws-transformer // ignore", and if so, disregard. - const foundComments = ts.getLeadingCommentRanges(sourceFile.text, node.pos); - if (foundComments) { - const foundCommentsText = foundComments.map((info) => { - return sourceFile - .text - .slice(info.pos, info.end) - .replace(/^(\/\/|\/\*)\s*/ /* Remove leading comment prefix */, ''); - }); - - const isIgnoreComment = foundCommentsText.find((commentText) => { - return commentText.startsWith('@throws-transformer ignore'); - }); - - if (isIgnoreComment) { - return null; - } + if (preceededByIgnoreComment(node, sourceFile)) { + return null; } const thrownType = checker.getTypeAtLocation(node.expression); @@ -331,8 +335,12 @@ function checkCallExpression( // Get the return type of the call const callType = checker.getTypeAtLocation(node); + const tryCatch = getContainingTryCatch(node); + // Extract error types - const errorTypes = extractThrowsErrorTypes(callType, checker); + const errorTypes = tryCatch ? ( + getTryCatchThrownErrors(tryCatch, sourceFile, checker) + ) : extractThrowsErrorTypes(callType, checker); if (errorTypes.length === 0) { return null; @@ -340,19 +348,20 @@ function checkCallExpression( // Check handling const containingFunction = getContainingFunction(node); - const tryCatch = getContainingTryCatch(node); const handledErrors = tryCatch ? getHandledErrorTypes(tryCatch, checker, node) : new Set(); - // If catch-all, everything is handled + // If the catch clause contains no throws all errors are being silenced + // ie, something like `try { /* code here */ } catch (err) {}` + // TODO: maybe log a warning here, this is probably bad at least in some cases? if (handledErrors === "all") { return null; } const propagatedErrors = containingFunction - ? getPropagatedErrorTypes(containingFunction, checker) + ? getPropagatedErrorTypes(node, containingFunction, sourceFile, checker) : new Set(); // Find unhandled @@ -365,6 +374,10 @@ function checkCallExpression( return null; } + if (preceededByIgnoreComment(node, sourceFile)) { + return null; + } + const start = node.getStart(); const length = node.getWidth(); const { line, character } = sourceFile.getLineAndCharacterOfPosition(start); @@ -559,6 +572,39 @@ function getContainingTryCatch(node: ts.Node): ts.TryStatement | null { return null; } +/** Get errors which the given try/catch passed itself throws within its catch block. */ +function getTryCatchThrownErrors(tryCatch: ts.TryStatement, sourceFile: ts.SourceFile, checker: ts.TypeChecker) { + const thrownErrorTypes: Array = []; + + if (!tryCatch.catchClause) { + return thrownErrorTypes; + } + + function visitThrowStatement(throwStmt: ts.ThrowStatement) { + const thrownErrorType = checker.getTypeAtLocation(throwStmt.expression); + if (!isAnyOrUnknownType(thrownErrorType)) { + if (thrownErrorType.isUnion()) { + for (const type of thrownErrorType.types.filter(t => !isAnyOrUnknownType(t))) { + thrownErrorTypes.push(type); + } + } else { + thrownErrorTypes.push(thrownErrorType); + } + } + } + + function visit(node: ts.Node): void { + if (ts.isThrowStatement(node) && checkThrowStatement(node, sourceFile, checker)) { + visitThrowStatement(node); + } + ts.forEachChild(node, visit); + } + + visit(tryCatch.catchClause); + + return thrownErrorTypes; +} + function isInTryBlock(node: ts.Node, tryStatement: ts.TryStatement): boolean { let current: ts.Node | undefined = node; @@ -661,7 +707,7 @@ function findNarrowedErrorTypes( * Analyze if-statements to find type narrowing branches that don't re-throw. * These represent error types that are handled. */ - function visitIfStatement(ifStmt: ts.IfStatement): void { + function visitIfStatement(ifStmt: ts.IfStatement) { // Get the type of the error variable after the type guard in the if condition // const condition = ifStmt.expression; @@ -817,13 +863,22 @@ function findInstanceofChecks( } function getPropagatedErrorTypes( + node: ts.Node, func: ts.FunctionLikeDeclaration, + sourceFile: ts.SourceFile, checker: ts.TypeChecker, ): Set { const propagated = new Set(); if (!func.type) { return propagated; } + // If `node` is in a try/catch, then the errors propegated are the errors that the catch itself throws + const tryCatch = getContainingTryCatch(node); + if (tryCatch?.catchClause) { + const thrownErrorTypes = getTryCatchThrownErrors(tryCatch, sourceFile, checker); + return new Set(thrownErrorTypes.map(e => checker.typeToString(e))); + } + const returnType = checker.getTypeFromTypeNode(func.type); const errorTypes = extractThrowsErrorTypes(returnType, checker);