Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,9 @@ export class OpenAI {
fetchOptions: MergedRequestInit | undefined;

private fetch: Fetch;
#encoder: Opts.RequestEncoder;
// Use a normal private field instead of JS #private to avoid
// brand-check crashes when methods are invoked across copies.
private encoder: Opts.RequestEncoder;
protected idempotencyHeader?: string;
protected _options: ClientOptions;

Expand Down Expand Up @@ -392,7 +394,7 @@ export class OpenAI {
this.fetchOptions = options.fetchOptions;
this.maxRetries = options.maxRetries ?? 2;
this.fetch = options.fetch ?? Shims.getDefaultFetch();
this.#encoder = Opts.FallbackEncoder;
this.encoder = Opts.FallbackEncoder;

this._options = options;

Expand Down Expand Up @@ -427,7 +429,7 @@ export class OpenAI {
/**
* Check whether the base URL is set to its default.
*/
#baseURLOverridden(): boolean {
private baseURLOverridden(): boolean {
return this.baseURL !== 'https://api.openai.com/v1';
}

Expand Down Expand Up @@ -494,7 +496,7 @@ export class OpenAI {
query: Record<string, unknown> | null | undefined,
defaultBaseURL?: string | undefined,
): string {
const baseURL = (!this.#baseURLOverridden() && defaultBaseURL) || this.baseURL;
const baseURL = (!this.baseURLOverridden() && defaultBaseURL) || this.baseURL;
const url =
isAbsoluteURL(path) ?
new URL(path)
Expand Down Expand Up @@ -960,7 +962,7 @@ export class OpenAI {
) {
return { bodyHeaders: undefined, body: Shims.ReadableStreamFrom(body as AsyncIterable<Uint8Array>) };
} else {
return this.#encoder({ body, headers });
return this.encoder({ body, headers });
}
}

Expand Down
16 changes: 12 additions & 4 deletions src/core/api-promise.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

import { type OpenAI } from '../client';
import { OpenAIError } from './error';

import { type PromiseOrValue } from '../internal/types';
import {
Expand All @@ -14,9 +15,12 @@ import {
* A subclass of `Promise` providing additional helper methods
* for interacting with the SDK.
*/
// Associate instance with client via a module WeakMap to avoid
// JS private-field brand checks across bundles/copies.
const apiPromiseClient = /* @__PURE__ */ new WeakMap<object, OpenAI>();

export class APIPromise<T> extends Promise<WithRequestID<T>> {
private parsedPromise: Promise<WithRequestID<T>> | undefined;
#client: OpenAI;

constructor(
client: OpenAI,
Expand All @@ -32,11 +36,13 @@ export class APIPromise<T> extends Promise<WithRequestID<T>> {
// to parse the response
resolve(null as any);
});
this.#client = client;
apiPromiseClient.set(this, client);
}

_thenUnwrap<U>(transform: (data: T, props: APIResponseProps) => U): APIPromise<U> {
return new APIPromise(this.#client, this.responsePromise, async (client, props) =>
const client = apiPromiseClient.get(this);
if (!client) throw new OpenAIError('Illegal invocation of APIPromise method');
return new APIPromise(client, this.responsePromise, async (client, props) =>
addRequestID(transform(await this.parseResponse(client, props), props), props.response),
);
}
Expand Down Expand Up @@ -75,8 +81,10 @@ export class APIPromise<T> extends Promise<WithRequestID<T>> {

private parse(): Promise<WithRequestID<T>> {
if (!this.parsedPromise) {
const client = apiPromiseClient.get(this);
if (!client) throw new OpenAIError('Illegal invocation of APIPromise method');
this.parsedPromise = this.responsePromise.then((data) =>
this.parseResponse(this.#client, data),
this.parseResponse(client, data),
) as any as Promise<WithRequestID<T>>;
}
return this.parsedPromise;
Expand Down
11 changes: 8 additions & 3 deletions src/core/pagination.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,18 @@ import { maybeObj } from '../internal/utils/values';

export type PageRequestOptions = Pick<FinalRequestOptions, 'query' | 'headers' | 'body' | 'path' | 'method'>;

// Associate pages with their client via a module WeakMap to avoid
// JS private-field brand checks across bundles/copies.
const pageClient = /* @__PURE__ */ new WeakMap<object, OpenAI>();

export abstract class AbstractPage<Item> implements AsyncIterable<Item> {
#client: OpenAI;
protected options: FinalRequestOptions;

protected response: Response;
protected body: unknown;

constructor(client: OpenAI, response: Response, body: unknown, options: FinalRequestOptions) {
this.#client = client;
pageClient.set(this, client);
this.options = options;
this.response = response;
this.body = body;
Expand All @@ -42,7 +45,9 @@ export abstract class AbstractPage<Item> implements AsyncIterable<Item> {
);
}

return await this.#client.requestAPIList(this.constructor as any, nextOptions);
const client = pageClient.get(this);
if (!client) throw new OpenAIError('Illegal invocation of Page method');
return await client.requestAPIList(this.constructor as any, nextOptions);
}

async *iterPages(): AsyncGenerator<this> {
Expand Down
16 changes: 10 additions & 6 deletions src/core/streaming.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@ export type ServerSentEvent = {
raw: string[];
};

// Associate Stream instances with their client via a module WeakMap to avoid
// JS private-field brand checks across bundles/copies.
const streamClient = /* @__PURE__ */ new WeakMap<object, OpenAI | undefined>();

export class Stream<Item> implements AsyncIterable<Item> {
controller: AbortController;
#client: OpenAI | undefined;

constructor(
private iterator: () => AsyncIterator<Item>,
controller: AbortController,
client?: OpenAI,
) {
this.controller = controller;
this.#client = client;
streamClient.set(this, client);
}

static fromSSEResponse<Item>(
Expand Down Expand Up @@ -75,8 +78,8 @@ export class Stream<Item> implements AsyncIterable<Item> {
try {
data = JSON.parse(sse.data);
} catch (e) {
console.error(`Could not parse message into JSON:`, sse.data);
console.error(`From chunk:`, sse.raw);
logger.error(`Could not parse message into JSON:`, sse.data);
logger.error(`From chunk:`, sse.raw);
throw e;
}
// TODO: Is this where the error should be thrown?
Expand Down Expand Up @@ -177,9 +180,10 @@ export class Stream<Item> implements AsyncIterable<Item> {
};
};

const client = streamClient.get(this);
return [
new Stream(() => teeIterator(left), this.controller, this.#client),
new Stream(() => teeIterator(right), this.controller, this.#client),
new Stream(() => teeIterator(left), this.controller, client),
new Stream(() => teeIterator(right), this.controller, client),
];
}

Expand Down
65 changes: 65 additions & 0 deletions tests/core/brand-guards.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import { APIPromise } from 'openai/core/api-promise';
import { AbstractPage, Page } from 'openai/core/pagination';
import { Stream } from 'openai/core/streaming';

const dummyResponseProps: any = { response: new Response(), options: {} };
const dummyParse = (_client: any, _props: any) => ({ data: null, response: new Response() });

describe('core brand-guard stability', () => {
test('APIPromise.then with mismatched this does not throw private-field TypeError', async () => {
const fake: any = Object.create(APIPromise.prototype);
fake.responsePromise = Promise.resolve(dummyResponseProps);
fake.parseResponse = dummyParse;

const call = () => (APIPromise.prototype.then as any).call(fake, () => {});
expect(call).toThrow(Error);
expect(call).not.toThrow(/private member/i);
});

test('APIPromise.catch with mismatched this does not throw private-field TypeError', async () => {
const fake: any = Object.create(APIPromise.prototype);
fake.responsePromise = Promise.resolve(dummyResponseProps);
fake.parseResponse = dummyParse;

const call = () => (APIPromise.prototype.catch as any).call(fake, () => {});
expect(call).toThrow(Error);
expect(call).not.toThrow(/private member/i);
});

test('APIPromise.finally with mismatched this does not throw private-field TypeError', async () => {
const fake: any = Object.create(APIPromise.prototype);
fake.responsePromise = Promise.resolve(dummyResponseProps);
fake.parseResponse = dummyParse;

const call = () => (APIPromise.prototype.finally as any).call(fake, () => {});
expect(call).toThrow(Error);
expect(call).not.toThrow(/private member/i);
});

test('AbstractPage.getNextPage with mismatched this does not throw private-field TypeError', async () => {
class TestPage<Item> extends Page<Item> {
override nextPageRequestOptions() {
return { path: '/v1/anything', method: 'get' } as any;
}
}

const fake: any = Object.create(TestPage.prototype);
fake.options = { path: '/v1/anything', method: 'get' };
fake.getPaginatedItems = () => [1];
fake.response = new Response();
fake.body = {};

const call = () => (AbstractPage.prototype.getNextPage as any).call(fake);
await expect(call()).rejects.toBeInstanceOf(Error);
await expect(call()).rejects.not.toThrow(/private member/i);
});

test('Stream.tee with mismatched this does not throw private-field TypeError', () => {
const fake: any = Object.create(Stream.prototype);
fake.controller = new AbortController();
fake.iterator = async function* () {};

const call = () => (Stream.prototype.tee as any).call(fake);
expect(call).not.toThrow(/private member/i);
});
});