Skip to content

Commit d1b17ef

Browse files
authored
fix(types)!: do not try to infer types of overloaded functions (#2)
BREAKING CHANGE: overloaded function may now require explicit type annotations
1 parent 14a24cc commit d1b17ef

File tree

6 files changed

+57
-118
lines changed

6 files changed

+57
-118
lines changed

README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,27 @@ expect(spy('hello')).toEqual('goodbye');
230230

231231
[asymmetric matchers]: https://vitest.dev/api/expect.html#expect-anything
232232

233+
#### Types of overloaded functions
234+
235+
Due to fundamental limitations in TypeScript, `when()` will always use the _last_ overload to infer function parameters and return types. You can use the `TFunc` type parameter of `when()` to manually select a different overload entry:
236+
237+
```ts
238+
function overloaded(): null;
239+
function overloaded(input: number): string;
240+
function overloaded(input?: number): string | null {
241+
// ...
242+
}
243+
244+
// Last entry: all good!
245+
when(overloaded).calledWith(42).thenReturn('hello');
246+
247+
// $ts-expect-error: first entry
248+
when(overloaded).calledWith().thenReturn(null);
249+
250+
// Manually specified: all good!
251+
when<() => null>(overloaded).calledWith().thenReturn(null);
252+
```
253+
233254
### `.thenReturn(value: TReturn)`
234255

235256
When the stubbing is satisfied, return `value`

src/behaviors.ts

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,18 @@
11
import { equals } from '@vitest/expect';
2-
import type {
3-
AnyFunction,
4-
AllParameters,
5-
ReturnTypeFromArgs,
6-
} from './types.ts';
2+
import type { AnyFunction } from './types.ts';
73

84
export const ONCE = Symbol('ONCE');
95

106
export type StubValue<TValue> = TValue | typeof ONCE;
117

128
export interface BehaviorStack<TFunc extends AnyFunction> {
139
use: (
14-
args: AllParameters<TFunc>
15-
) => BehaviorEntry<AllParameters<TFunc>> | undefined;
10+
args: Parameters<TFunc>
11+
) => BehaviorEntry<Parameters<TFunc>> | undefined;
1612

17-
bindArgs: <TArgs extends AllParameters<TFunc>>(
13+
bindArgs: <TArgs extends Parameters<TFunc>>(
1814
args: TArgs
19-
) => BoundBehaviorStack<ReturnTypeFromArgs<TFunc, TArgs>>;
15+
) => BoundBehaviorStack<ReturnType<TFunc>>;
2016
}
2117

2218
export interface BoundBehaviorStack<TReturn> {
@@ -43,7 +39,7 @@ export interface BehaviorOptions<TValue> {
4339
export const createBehaviorStack = <
4440
TFunc extends AnyFunction
4541
>(): BehaviorStack<TFunc> => {
46-
const behaviors: BehaviorEntry<AllParameters<TFunc>>[] = [];
42+
const behaviors: BehaviorEntry<Parameters<TFunc>>[] = [];
4743

4844
return {
4945
use: (args) => {

src/stubs.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import type { Mock as Spy } from 'vitest';
22
import { createBehaviorStack, type BehaviorStack } from './behaviors.ts';
33
import { NotAMockFunctionError } from './errors.ts';
4-
import type { AnyFunction, AllParameters } from './types.ts';
4+
import type { AnyFunction } from './types.ts';
55

66
const BEHAVIORS_KEY = Symbol('behaviors');
77

88
interface WhenStubImplementation<TFunc extends AnyFunction> {
9-
(...args: AllParameters<TFunc>): unknown;
9+
(...args: Parameters<TFunc>): unknown;
1010
[BEHAVIORS_KEY]: BehaviorStack<TFunc>;
1111
}
1212

@@ -25,7 +25,7 @@ export const configureStub = <TFunc extends AnyFunction>(
2525

2626
const behaviors = createBehaviorStack<TFunc>();
2727

28-
const implementation = (...args: AllParameters<TFunc>): unknown => {
28+
const implementation = (...args: Parameters<TFunc>): unknown => {
2929
const behavior = behaviors.use(args);
3030

3131
if (behavior?.throwError) {
@@ -48,15 +48,15 @@ export const configureStub = <TFunc extends AnyFunction>(
4848

4949
const validateSpy = <TFunc extends AnyFunction>(
5050
maybeSpy: unknown
51-
): Spy<AllParameters<TFunc>, unknown> => {
51+
): Spy<Parameters<TFunc>, unknown> => {
5252
if (
5353
typeof maybeSpy === 'function' &&
5454
'mockImplementation' in maybeSpy &&
5555
typeof maybeSpy.mockImplementation === 'function' &&
5656
'getMockImplementation' in maybeSpy &&
5757
typeof maybeSpy.getMockImplementation === 'function'
5858
) {
59-
return maybeSpy as Spy<AllParameters<TFunc>, unknown>;
59+
return maybeSpy as Spy<Parameters<TFunc>, unknown>;
6060
}
6161

6262
throw new NotAMockFunctionError(maybeSpy);

src/types.ts

Lines changed: 1 addition & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,4 @@
1-
/**
2-
* Get function arguments and return value types.
3-
*
4-
* Support for overloaded functions, thanks to @Shakeskeyboarde
5-
* https://github.com/microsoft/TypeScript/issues/14107#issuecomment-1146738780
6-
*/
7-
8-
import type { SpyInstance } from 'vitest';
1+
/** Common type definitions. */
92

103
/** Any function, for use in `extends` */
114
export type AnyFunction = (...args: never[]) => unknown;
12-
13-
/** Acceptable arguments for a function.*/
14-
export type AllParameters<TFunc extends AnyFunction> =
15-
TFunc extends SpyInstance<infer TArgs, unknown>
16-
? TArgs
17-
: Parameters<ToOverloads<TFunc>>;
18-
19-
/** The return type of a function, given the actual arguments used.*/
20-
export type ReturnTypeFromArgs<
21-
TFunc extends AnyFunction,
22-
TArgs extends unknown[]
23-
> = TFunc extends SpyInstance<unknown[], infer TReturn>
24-
? TReturn
25-
: ExtractReturn<ToOverloads<TFunc>, TArgs>;
26-
27-
/** Given a functions and actual arguments used, extract the return type. */
28-
type ExtractReturn<
29-
TFunc extends AnyFunction,
30-
TArgs extends unknown[]
31-
> = TFunc extends (...args: infer TFuncArgs) => infer TFuncReturn
32-
? TArgs extends TFuncArgs
33-
? TFuncReturn
34-
: never
35-
: never;
36-
37-
/** Transform an overloaded function into a union of functions. */
38-
type ToOverloads<TFunc extends AnyFunction> = Exclude<
39-
OverloadUnion<(() => never) & TFunc>,
40-
TFunc extends () => never ? never : () => never
41-
>;
42-
43-
/** Recursively extract functions from an overload into a union. */
44-
type OverloadUnion<TFunc, TPartialOverload = unknown> = TFunc extends (
45-
...args: infer TArgs
46-
) => infer TReturn
47-
? TPartialOverload extends TFunc
48-
? never
49-
:
50-
| OverloadUnion<
51-
TPartialOverload & TFunc,
52-
TPartialOverload &
53-
((...args: TArgs) => TReturn) &
54-
OverloadProps<TFunc>
55-
>
56-
| ((...args: TArgs) => TReturn)
57-
: never;
58-
59-
/** Properties attached to a function. */
60-
type OverloadProps<TFunc> = Pick<TFunc, keyof TFunc>;

src/vitest-when.ts

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
11
import { configureStub } from './stubs.ts';
22
import type { StubValue } from './behaviors.ts';
3-
import type {
4-
AnyFunction,
5-
AllParameters,
6-
ReturnTypeFromArgs,
7-
} from './types.ts';
3+
import type { AnyFunction } from './types.ts';
84

95
export { ONCE, type StubValue } from './behaviors.ts';
106
export * from './errors.ts';
117

128
export interface StubWrapper<TFunc extends AnyFunction> {
13-
calledWith<TArgs extends AllParameters<TFunc>>(
9+
calledWith<TArgs extends Parameters<TFunc>>(
1410
...args: TArgs
15-
): Stub<TArgs, ReturnTypeFromArgs<TFunc, TArgs>>;
11+
): Stub<TArgs, ReturnType<TFunc>>;
1612
}
1713

1814
export interface Stub<TArgs extends unknown[], TReturn> {

test/typing.test-d.ts

Lines changed: 21 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,12 @@ describe('vitest-when type signatures', () => {
3131
assertType<subject.Stub<[1], string>>(stub);
3232
});
3333

34-
it('should reject invalid usage of a simple function', () => {
35-
// @ts-expect-error: args missing
36-
subject.when(simple).calledWith();
34+
it('should handle a generic function', () => {
35+
const stub = subject.when(generic).calledWith(1);
3736

38-
// @ts-expect-error: args wrong type
39-
subject.when(simple).calledWith('hello');
37+
stub.thenReturn('hello');
4038

41-
// @ts-expect-error: return wrong type
42-
subject.when(simple).calledWith(1).thenReturn(42);
39+
assertType<subject.Stub<[number], string>>(stub);
4340
});
4441

4542
it('should handle an overloaded function using its last overload', () => {
@@ -50,30 +47,14 @@ describe('vitest-when type signatures', () => {
5047
assertType<subject.Stub<[1], string>>(stub);
5148
});
5249

53-
it('should handle an overloaded function using its first overload', () => {
54-
const stub = subject.when(overloaded).calledWith();
50+
it('should handle an overloaded function using an explicit type', () => {
51+
const stub = subject.when<() => null>(overloaded).calledWith();
5552

5653
stub.thenReturn(null);
5754

5855
assertType<subject.Stub<[], null>>(stub);
5956
});
6057

61-
it('should handle an very overloaded function using its first overload', () => {
62-
const stub = subject.when(veryOverloaded).calledWith();
63-
64-
stub.thenReturn(null);
65-
66-
assertType<subject.Stub<[], null>>(stub);
67-
});
68-
69-
it('should handle an overloaded function using its last overload', () => {
70-
const stub = subject.when(veryOverloaded).calledWith(1, 2, 3, 4);
71-
72-
stub.thenReturn(42);
73-
74-
assertType<subject.Stub<[1, 2, 3, 4], number>>(stub);
75-
});
76-
7758
it('should reject invalid usage of a simple function', () => {
7859
// @ts-expect-error: args missing
7960
subject.when(simple).calledWith();
@@ -84,6 +65,17 @@ describe('vitest-when type signatures', () => {
8465
// @ts-expect-error: return wrong type
8566
subject.when(simple).calledWith(1).thenReturn(42);
8667
});
68+
69+
it('should reject invalid usage of a generic function', () => {
70+
// @ts-expect-error: args missing
71+
subject.when(generic).calledWith();
72+
73+
// @ts-expect-error: args wrong type
74+
subject.when(generic<string>).calledWith(42);
75+
76+
// @ts-expect-error: return wrong type
77+
subject.when(generic).calledWith(1).thenReturn(42);
78+
});
8779
});
8880

8981
function untyped(...args: any[]): any {
@@ -94,22 +86,12 @@ function simple(input: number): string {
9486
throw new Error(`simple(${input})`);
9587
}
9688

89+
function generic<T>(input: T): string {
90+
throw new Error(`generic(${input})`);
91+
}
92+
9793
function overloaded(): null;
9894
function overloaded(input: number): string;
9995
function overloaded(input?: number): string | null {
10096
throw new Error(`overloaded(${input})`);
10197
}
102-
103-
function veryOverloaded(): null;
104-
function veryOverloaded(i1: number): string;
105-
function veryOverloaded(i1: number, i2: number): boolean;
106-
function veryOverloaded(i1: number, i2: number, i3: number): null;
107-
function veryOverloaded(i1: number, i2: number, i3: number, i4: number): number;
108-
function veryOverloaded(
109-
i1?: number,
110-
i2?: number,
111-
i3?: number,
112-
i4?: number
113-
): string | boolean | number | null {
114-
throw new Error(`veryOverloaded(${i1}, ${i2}, ${i3}, ${i4})`);
115-
}

0 commit comments

Comments
 (0)