Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/client-side-encryption/auto_encrypter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ export class AutoEncrypter {
context.ns = ns;
context.document = cmd;

const stateMachine = new StateMachine({
const stateMachine = new StateMachine(this, {
promoteValues: false,
promoteLongs: false,
proxyOptions: this._proxyOptions,
Expand All @@ -436,7 +436,7 @@ export class AutoEncrypter {

context.id = this._contextCounter++;

const stateMachine = new StateMachine({
const stateMachine = new StateMachine(this, {
...options,
proxyOptions: this._proxyOptions,
tlsOptions: this._tlsOptions,
Expand Down
8 changes: 4 additions & 4 deletions src/client-side-encryption/client_encryption.ts
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ export class ClientEncryption {
keyMaterial
});

const stateMachine = new StateMachine({
const stateMachine = new StateMachine(this, {
proxyOptions: this._proxyOptions,
tlsOptions: this._tlsOptions,
socketOptions: autoSelectSocketOptions(this._client.s.options)
Expand Down Expand Up @@ -295,7 +295,7 @@ export class ClientEncryption {
}
const filterBson = serialize(filter);
const context = this._mongoCrypt.makeRewrapManyDataKeyContext(filterBson, keyEncryptionKeyBson);
const stateMachine = new StateMachine({
const stateMachine = new StateMachine(this, {
proxyOptions: this._proxyOptions,
tlsOptions: this._tlsOptions,
socketOptions: autoSelectSocketOptions(this._client.s.options)
Expand Down Expand Up @@ -699,7 +699,7 @@ export class ClientEncryption {
const valueBuffer = serialize({ v: value });
const context = this._mongoCrypt.makeExplicitDecryptionContext(valueBuffer);

const stateMachine = new StateMachine({
const stateMachine = new StateMachine(this, {
proxyOptions: this._proxyOptions,
tlsOptions: this._tlsOptions,
socketOptions: autoSelectSocketOptions(this._client.s.options)
Expand Down Expand Up @@ -783,7 +783,7 @@ export class ClientEncryption {
}

const valueBuffer = serialize({ v: value });
const stateMachine = new StateMachine({
const stateMachine = new StateMachine(this, {
proxyOptions: this._proxyOptions,
tlsOptions: this._tlsOptions,
socketOptions: autoSelectSocketOptions(this._client.s.options)
Expand Down
14 changes: 9 additions & 5 deletions src/client-side-encryption/state_machine.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import * as fs from 'fs/promises';
import { type MongoCryptContext, type MongoCryptKMSRequest } from 'mongodb-client-encryption';
import * as net from 'net';
import * as tls from 'tls';
Expand All @@ -14,7 +13,7 @@ import { type ProxyOptions } from '../cmap/connection';
import { CursorTimeoutContext } from '../cursor/abstract_cursor';
import { getSocks, type SocksLib } from '../deps';
import { MongoOperationTimeoutError } from '../error';
import { type MongoClient, type MongoClientOptions } from '../mongo_client';
import { type IO, type MongoClient, type MongoClientOptions } from '../mongo_client';
import { type Abortable } from '../mongo_types';
import { type CollectionInfo } from '../operations/list_collections';
import { Timeout, type TimeoutContext, TimeoutError } from '../timeout';
Expand Down Expand Up @@ -186,10 +185,15 @@ export type StateMachineOptions = {
*/
// TODO(DRIVERS-2671): clarify CSOT behavior for FLE APIs
export class StateMachine {
private parent: { _client: { io: IO } };

constructor(
parent: { _client: { io: IO } },
private options: StateMachineOptions,
private bsonOptions = pluckBSONSerializeOptions(options)
) {}
) {
this.parent = parent;
}

/**
* Executes the state machine according to the specification
Expand Down Expand Up @@ -524,11 +528,11 @@ export class StateMachine {
options: tls.ConnectionOptions
): Promise<void> {
if (tlsOptions.tlsCertificateKeyFile) {
const cert = await fs.readFile(tlsOptions.tlsCertificateKeyFile);
const cert = await this.parent._client.io.fs.readFile(tlsOptions.tlsCertificateKeyFile);
options.cert = options.key = cert;
}
if (tlsOptions.tlsCAFile) {
options.ca = await fs.readFile(tlsOptions.tlsCAFile);
options.ca = await this.parent._client.io.fs.readFile(tlsOptions.tlsCAFile);
}
if (tlsOptions.tlsCertificateKeyFilePassword) {
options.passphrase = tlsOptions.tlsCertificateKeyFilePassword;
Expand Down
11 changes: 6 additions & 5 deletions src/cmap/auth/mongodb_oidc.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type { Document } from '../../bson';
import { MongoInvalidArgumentError, MongoMissingCredentialsError } from '../../error';
import { type MongoClient } from '../../mongo_client';
import type { HandshakeDocument } from '../connect';
import type { Connection } from '../connection';
import { type AuthContext, AuthProvider } from './auth_provider';
Expand Down Expand Up @@ -115,11 +116,11 @@ export interface Workflow {
}

/** @internal */
export const OIDC_WORKFLOWS: Map<EnvironmentName, () => Workflow> = new Map();
OIDC_WORKFLOWS.set('test', () => new TokenMachineWorkflow(new TokenCache()));
OIDC_WORKFLOWS.set('azure', () => new AzureMachineWorkflow(new TokenCache()));
OIDC_WORKFLOWS.set('gcp', () => new GCPMachineWorkflow(new TokenCache()));
OIDC_WORKFLOWS.set('k8s', () => new K8SMachineWorkflow(new TokenCache()));
export const OIDC_WORKFLOWS: Map<EnvironmentName, (client: MongoClient) => Workflow> = new Map();
OIDC_WORKFLOWS.set('test', client => new TokenMachineWorkflow(client, new TokenCache()));
OIDC_WORKFLOWS.set('azure', client => new AzureMachineWorkflow(client, new TokenCache()));
OIDC_WORKFLOWS.set('gcp', client => new GCPMachineWorkflow(client, new TokenCache()));
OIDC_WORKFLOWS.set('k8s', client => new K8SMachineWorkflow(client, new TokenCache()));

/**
* OIDC auth provider.
Expand Down
8 changes: 0 additions & 8 deletions src/cmap/auth/mongodb_oidc/azure_machine_workflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import { MongoAzureError } from '../../../error';
import { get } from '../../../utils';
import type { MongoCredentials } from '../mongo_credentials';
import { type AccessToken, MachineWorkflow } from './machine_workflow';
import { type TokenCache } from './token_cache';

/** Azure request headers. */
const AZURE_HEADERS = Object.freeze({ Metadata: 'true', Accept: 'application/json' });
Expand All @@ -22,13 +21,6 @@ const TOKEN_RESOURCE_MISSING_ERROR =
* @internal
*/
export class AzureMachineWorkflow extends MachineWorkflow {
/**
* Instantiate the machine workflow.
*/
constructor(cache: TokenCache) {
super(cache);
}

/**
* Get the token from the environment.
*/
Expand Down
8 changes: 0 additions & 8 deletions src/cmap/auth/mongodb_oidc/gcp_machine_workflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import { MongoGCPError } from '../../../error';
import { get } from '../../../utils';
import { type MongoCredentials } from '../mongo_credentials';
import { type AccessToken, MachineWorkflow } from './machine_workflow';
import { type TokenCache } from './token_cache';

/** GCP base URL. */
const GCP_BASE_URL =
Expand All @@ -16,13 +15,6 @@ const TOKEN_RESOURCE_MISSING_ERROR =
'TOKEN_RESOURCE must be set in the auth mechanism properties when ENVIRONMENT is gcp.';

export class GCPMachineWorkflow extends MachineWorkflow {
/**
* Instantiate the machine workflow.
*/
constructor(cache: TokenCache) {
super(cache);
}

/**
* Get the token from the environment.
*/
Expand Down
16 changes: 4 additions & 12 deletions src/cmap/auth/mongodb_oidc/k8s_machine_workflow.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { readFile } from 'fs/promises';

import { type MongoClient } from '../../../mongo_client';
import { type MongoCredentials } from '../mongo_credentials';
import { type AccessToken, MachineWorkflow } from './machine_workflow';
import { type TokenCache } from './token_cache';

/** The fallback file name */
const FALLBACK_FILENAME = '/var/run/secrets/kubernetes.io/serviceaccount/token';
Expand All @@ -13,17 +12,10 @@ const AZURE_FILENAME = 'AZURE_FEDERATED_TOKEN_FILE';
const AWS_FILENAME = 'AWS_WEB_IDENTITY_TOKEN_FILE';

export class K8SMachineWorkflow extends MachineWorkflow {
/**
* Instantiate the machine workflow.
*/
constructor(cache: TokenCache) {
super(cache);
}

/**
* Get the token from the environment.
*/
async getToken(): Promise<AccessToken> {
async getToken(_credentials: MongoCredentials, client: MongoClient): Promise<AccessToken> {
let filename: string;
if (process.env[AZURE_FILENAME]) {
filename = process.env[AZURE_FILENAME];
Expand All @@ -32,7 +24,7 @@ export class K8SMachineWorkflow extends MachineWorkflow {
} else {
filename = FALLBACK_FILENAME;
}
const token = await readFile(filename, 'utf8');
const token = await client.io.fs.readFile(filename, { encoding: 'utf8' });
return { access_token: token };
}
}
16 changes: 11 additions & 5 deletions src/cmap/auth/mongodb_oidc/machine_workflow.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { setTimeout } from 'timers/promises';

import { type Document } from '../../../bson';
import { type IO } from '../../../mongo_client';
import { ns } from '../../../utils';
import type { Connection } from '../../connection';
import type { MongoCredentials } from '../mongo_credentials';
Expand All @@ -21,7 +22,10 @@ export interface AccessToken {
}

/** @internal */
export type OIDCTokenFunction = (credentials: MongoCredentials) => Promise<AccessToken>;
export type OIDCTokenFunction = (
credentials: MongoCredentials,
client: { io: IO }
) => Promise<AccessToken>;

/**
* Common behaviour for OIDC machine workflows.
Expand All @@ -31,11 +35,13 @@ export abstract class MachineWorkflow implements Workflow {
cache: TokenCache;
callback: OIDCTokenFunction;
lastExecutionTime: number;
client: { io: IO };

/**
* Instantiate the machine workflow.
*/
constructor(cache: TokenCache) {
constructor(client: { io: IO }, cache: TokenCache) {
this.client = client;
this.cache = cache;
this.callback = this.withLock(this.getToken.bind(this));
this.lastExecutionTime = Date.now() - THROTTLE_MS;
Expand Down Expand Up @@ -101,7 +107,7 @@ export abstract class MachineWorkflow implements Workflow {
}
return token;
} else {
const token = await this.callback(credentials);
const token = await this.callback(credentials, connection.client);
this.cache.put({ accessToken: token.access_token, expiresInSeconds: token.expires_in });
// Put the access token on the connection as well.
connection.accessToken = token.access_token;
Expand Down Expand Up @@ -129,7 +135,7 @@ export abstract class MachineWorkflow implements Workflow {
await setTimeout(THROTTLE_MS - difference);
}
this.lastExecutionTime = Date.now();
return await callback(credentials);
return await callback(credentials, this.client);
});
return await lock;
};
Expand All @@ -138,5 +144,5 @@ export abstract class MachineWorkflow implements Workflow {
/**
* Get the token from the environment or endpoint.
*/
abstract getToken(credentials: MongoCredentials): Promise<AccessToken>;
abstract getToken(credentials: MongoCredentials, client: { io: IO }): Promise<AccessToken>;
}
16 changes: 4 additions & 12 deletions src/cmap/auth/mongodb_oidc/token_machine_workflow.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import * as fs from 'fs';

import { MongoAWSError } from '../../../error';
import { type MongoClient } from '../../../mongo_client';
import { type MongoCredentials } from '../mongo_credentials';
import { type AccessToken, MachineWorkflow } from './machine_workflow';
import { type TokenCache } from './token_cache';

/** Error for when the token is missing in the environment. */
const TOKEN_MISSING_ERROR = 'OIDC_TOKEN_FILE must be set in the environment.';
Expand All @@ -13,22 +12,15 @@ const TOKEN_MISSING_ERROR = 'OIDC_TOKEN_FILE must be set in the environment.';
* @internal
*/
export class TokenMachineWorkflow extends MachineWorkflow {
/**
* Instantiate the machine workflow.
*/
constructor(cache: TokenCache) {
super(cache);
}

/**
* Get the token from the environment.
*/
async getToken(): Promise<AccessToken> {
async getToken(_: MongoCredentials, client: MongoClient): Promise<AccessToken> {
const tokenFile = process.env.OIDC_TOKEN_FILE;
if (!tokenFile) {
throw new MongoAWSError(TOKEN_MISSING_ERROR);
}
const token = await fs.promises.readFile(tokenFile, 'utf8');
const token = await client.io.fs.readFile(tokenFile, { encoding: 'utf8' });
return { access_token: token };
}
}
16 changes: 12 additions & 4 deletions src/cmap/connect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
MongoRuntimeError,
needsRetryableWriteLabel
} from '../error';
import { type IO } from '../mongo_client';
import { HostAddress, ns, promiseWithResolvers } from '../utils';
import { AuthContext } from './auth/auth_provider';
import { AuthMechanism } from './auth/providers';
Expand All @@ -35,11 +36,14 @@ import {
/** @public */
export type Stream = Socket | TLSSocket;

export async function connect(options: ConnectionOptions): Promise<Connection> {
export async function connect(
parent: { client: { io: IO } },
options: ConnectionOptions
): Promise<Connection> {
let connection: Connection | null = null;
try {
const socket = await makeSocket(options);
connection = makeConnection(options, socket);
connection = makeConnection(parent, options, socket);
await performInitialHandshake(connection, options);
return connection;
} catch (error) {
Expand All @@ -48,13 +52,17 @@ export async function connect(options: ConnectionOptions): Promise<Connection> {
}
}

export function makeConnection(options: ConnectionOptions, socket: Stream): Connection {
export function makeConnection(
parent: { client: { io: IO } },
options: ConnectionOptions,
socket: Stream
): Connection {
let ConnectionType = options.connectionType ?? Connection;
if (options.autoEncrypter) {
ConnectionType = CryptoConnection;
}

return new ConnectionType(socket, options);
return new ConnectionType(parent, socket, options);
}

function checkSupportedServer(hello: Document, options: ConnectionOptions) {
Expand Down
Loading