Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 4 additions & 51 deletions packages/core/src/auth/sso/clients.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,25 @@ import {
SSOServiceException,
} from '@aws-sdk/client-sso'
import {
AuthorizationPendingException,
CreateTokenRequest,
RegisterClientRequest,
SSOOIDC,
SSOOIDCClient,
StartDeviceAuthorizationRequest,
} from '@aws-sdk/client-sso-oidc'
import { AsyncCollection } from '../../shared/utilities/asyncCollection'
import { pageableToCollection, partialClone } from '../../shared/utilities/collectionUtils'
import { pageableToCollection } from '../../shared/utilities/collectionUtils'
import { assertHasProps, isNonNullable, RequiredProps, selectFrom } from '../../shared/utilities/tsUtils'
import { getLogger } from '../../shared/logger/logger'
import { SsoAccessTokenProvider } from './ssoAccessTokenProvider'
import { AwsClientResponseError, isClientFault } from '../../shared/errors'
import { DevSettings } from '../../shared/settings'
import { SdkError } from '@aws-sdk/types'
import { HttpRequest, HttpResponse } from '@smithy/protocol-http'
import { StandardRetryStrategy, defaultRetryDecider } from '@smithy/middleware-retry'
import { AuthenticationFlow } from './model'
import { toSnakeCase } from '../../shared/utilities/textUtilities'
import { getUserAgent, withTelemetryContext } from '../../shared/telemetry/util'
import { defaultDeserializeMiddleware, finalizeLoggingMiddleware } from '../../shared/awsClientBuilderV3'

export class OidcClient {
public constructor(
Expand Down Expand Up @@ -249,52 +248,6 @@ export class SsoClient {
}

function addLoggingMiddleware(client: SSOOIDCClient) {
client.middlewareStack.add(
(next, context) => (args) => {
if (HttpRequest.isInstance(args.request)) {
const { hostname, path } = args.request
const input = partialClone(
// TODO: Fix
args.input as unknown as Record<string, unknown>,
3,
['clientSecret', 'accessToken', 'refreshToken'],
'[omitted]'
)
getLogger().debug('API request (%s %s): %O', hostname, path, input)
}
return next(args)
},
{ step: 'finalizeRequest' }
)

client.middlewareStack.add(
(next, context) => async (args) => {
if (!HttpRequest.isInstance(args.request)) {
return next(args)
}

const { hostname, path } = args.request
const result = await next(args).catch((e) => {
if (e instanceof Error && !(e instanceof AuthorizationPendingException)) {
const err = { ...e }
delete err['stack']
getLogger().error('API response (%s %s): %O', hostname, path, err)
}
throw e
})
if (HttpResponse.isInstance(result.response)) {
const output = partialClone(
// TODO: Fix
result.output as unknown as Record<string, unknown>,
3,
['clientSecret', 'accessToken', 'refreshToken'],
'[omitted]'
)
getLogger().debug('API response (%s %s): %O', hostname, path, output)
}

return result
},
{ step: 'deserialize' }
)
client.middlewareStack.add(finalizeLoggingMiddleware, { step: 'finalizeRequest' })
client.middlewareStack.add(defaultDeserializeMiddleware, { step: 'deserialize' })
}
114 changes: 77 additions & 37 deletions packages/core/src/shared/awsClientBuilderV3.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ import { CredentialsShim } from '../auth/deprecated/loginManager'
import { AwsContext } from './awsContext'
import {
AwsCredentialIdentityProvider,
BuildHandlerArguments,
DeserializeHandlerArguments,
DeserializeHandlerOutput,
FinalizeHandlerArguments,
Logger,
RetryStrategyV2,
TokenIdentity,
Expand All @@ -18,10 +22,8 @@ import {
BuildHandler,
BuildMiddleware,
DeserializeHandler,
DeserializeMiddleware,
Handler,
FinalizeHandler,
FinalizeRequestMiddleware,
HandlerExecutionContext,
MetadataBearer,
MiddlewareStack,
Expand All @@ -42,6 +44,7 @@ import { partialClone } from './utilities/collectionUtils'
import { selectFrom } from './utilities/tsUtils'
import { once } from './utilities/functionUtils'
import { isWeb } from './extensionGlobals'
import { AuthorizationPendingException } from '@aws-sdk/client-sso-oidc'

export type AwsClientConstructor<C> = new (o: AwsClientOptions) => C
export type AwsCommandConstructor<CommandInput extends object, Command extends AwsCommand<CommandInput, object>> = new (
Expand Down Expand Up @@ -175,8 +178,8 @@ export class AWSClientBuilderV3 {
}

const service = new serviceOptions.serviceClient(opt)
service.middlewareStack.add(telemetryMiddleware, { step: 'deserialize' })
service.middlewareStack.add(loggingMiddleware, { step: 'finalizeRequest' })
service.middlewareStack.add(defaultDeserializeMiddleware, { step: 'deserialize' })
service.middlewareStack.add(finalizeLoggingMiddleware, { step: 'finalizeRequest' })
service.middlewareStack.add(getEndpointMiddleware(serviceOptions.settings), { step: 'build' })

if (keepAlive) {
Expand Down Expand Up @@ -211,65 +214,85 @@ export function recordErrorTelemetry(err: Error, serviceName?: string) {
})
}

function logAndThrow(e: any, serviceId: string, errorMessageAppend: string): never {
if (e instanceof Error) {
recordErrorTelemetry(e, serviceId)
getLogger().error('API Response %s: %O', errorMessageAppend, e)
}
throw e
export function defaultDeserializeMiddleware<Input extends object, Output extends object>(
next: DeserializeHandler<Input, Output>,
context: HandlerExecutionContext
) {
return async (args: DeserializeHandlerArguments<Input>) => onDeserialize(next, context, args)
}

const telemetryMiddleware: DeserializeMiddleware<any, any> =
(next: DeserializeHandler<any, any>, context: HandlerExecutionContext) => async (args: any) =>
emitOnRequest(next, context, args)

const loggingMiddleware: FinalizeRequestMiddleware<any, any> = (next: FinalizeHandler<any, any>) => async (args: any) =>
logOnRequest(next, args)
export function finalizeLoggingMiddleware<Input extends object, Output extends object>(
next: FinalizeHandler<Input, Output>
) {
return async (args: FinalizeHandlerArguments<Input>) => logOnFinalize(next, args)
}

function getEndpointMiddleware(settings: DevSettings = DevSettings.instance): BuildMiddleware<any, any> {
return (next: BuildHandler<any, any>, context: HandlerExecutionContext) => async (args: any) =>
overwriteEndpoint(next, context, settings, args)
function getEndpointMiddleware<Input extends object, Output extends object>(
settings: DevSettings = DevSettings.instance
): BuildMiddleware<Input, Output> {
return (next: BuildHandler<Input, Output>, context: HandlerExecutionContext) =>
async (args: BuildHandlerArguments<Input>) =>
overwriteEndpoint(next, context, settings, args)
}

const keepAliveMiddleware: BuildMiddleware<any, any> = (next: BuildHandler<any, any>) => async (args: any) =>
addKeepAliveHeader(next, args)
function keepAliveMiddleware<Input extends object, Output extends object>(next: BuildHandler<Input, Output>) {
return async (args: BuildHandlerArguments<Input>) => addKeepAliveHeader(next, args)
}

export async function emitOnRequest(next: DeserializeHandler<any, any>, context: HandlerExecutionContext, args: any) {
if (!HttpResponse.isInstance(args.request)) {
export async function onDeserialize<Input extends object, Output extends object>(
next: DeserializeHandler<Input, Output>,
context: HandlerExecutionContext,
args: DeserializeHandlerArguments<Input>
): Promise<DeserializeHandlerOutput<Output>> {
const request = args.request
if (!HttpRequest.isInstance(request)) {
return next(args)
}
const { hostname, path } = request
const serviceId = getServiceId(context as object)
const { hostname, path } = args.request
const logTail = `(${hostname} ${path})`
try {
const result = await next(args)
if (HttpResponse.isInstance(result.response)) {
// TODO: omit credentials / sensitive info from the telemetry.
const output = partialClone(result.output, 3)
const output = partialClone(result.output, 3, ['clientSecret', 'accessToken', 'refreshToken'], '[omitted]')
getLogger().debug(`API Response %s: %O`, logTail, output)
}
return result
} catch (e: any) {
logAndThrow(e, serviceId, logTail)
} catch (e: unknown) {
if (e instanceof Error && !(e instanceof AuthorizationPendingException)) {
const err = { ...e, name: e.name, mesage: e.message }
delete err['stack']
recordErrorTelemetry(err, serviceId)
getLogger().warn(`API Request %s resulted in error: %O`, logTail, err)
}
throw e
}
}

export async function logOnRequest(next: FinalizeHandler<any, any>, args: any) {
export function logOnFinalize<Input extends object, Output extends object>(
next: FinalizeHandler<Input, Output>,
args: FinalizeHandlerArguments<Input>
) {
const request = args.request
if (HttpRequest.isInstance(args.request)) {
if (HttpRequest.isInstance(request)) {
const { hostname, path } = request
// TODO: omit credentials / sensitive info from the logs.
const input = partialClone(args.input, 3)
getLogger().debug(`API Request (%s %s): %O`, hostname, path, input)
const input = partialClone(args.input, 3, ['clientSecret', 'accessToken', 'refreshToken'], '[omitted]')
getLogger().debug(
`API Request (%s %s):\n headers: %O\n input: %O`,
hostname,
path,
filterRequestHeaders(request),
input
)
}
return next(args)
}

export function overwriteEndpoint(
next: BuildHandler<any, any>,
export function overwriteEndpoint<Input extends object, Output extends object>(
next: BuildHandler<Input, Output>,
context: HandlerExecutionContext,
settings: DevSettings,
args: any
args: BuildHandlerArguments<Input>
) {
const request = args.request
if (HttpRequest.isInstance(request)) {
Expand All @@ -291,10 +314,27 @@ export function overwriteEndpoint(
* @param args
* @returns
*/
export function addKeepAliveHeader(next: BuildHandler<any, any>, args: any) {
export function addKeepAliveHeader<Input extends object, Output extends object>(
next: BuildHandler<Input, Output>,
args: BuildHandlerArguments<Input>
) {
const request = args.request
if (HttpRequest.isInstance(request)) {
request.headers['Connection'] = 'keep-alive'
}
return next(args)
}

function filterRequestHeaders(request: HttpRequest) {
const logHeaderNames = [
'x-amzn-requestid',
'x-amzn-trace-id',
'x-amzn-served-from',
'x-cache',
'x-amz-cf-id',
'x-amz-cf-pop',
'Connection',
'host',
]
return Object.fromEntries(Object.entries(request.headers).filter(([k, _]) => logHeaderNames.includes(k)))
}
58 changes: 55 additions & 3 deletions packages/core/src/shared/clients/clientWrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,28 @@ import globals from '../extensionGlobals'
import { AwsClient, AwsClientConstructor, AwsCommand, AwsCommandConstructor } from '../awsClientBuilderV3'
import { PaginationConfiguration, Paginator } from '@aws-sdk/types'
import { AsyncCollection, toCollection } from '../utilities/asyncCollection'
import { isDefined } from '../utilities/tsUtils'
import { hasKey, isDefined } from '../utilities/tsUtils'
import { PerfLog } from '../logger/perfLogger'
import { getLogger } from '../logger/logger'
import { truncateProps } from '../utilities/textUtilities'
import { ToolkitError } from '../errors'

type SDKPaginator<C, CommandInput extends object, CommandOutput extends object> = (
config: Omit<PaginationConfiguration, 'client'> & { client: C },
input: CommandInput,
...rest: any[]
) => Paginator<CommandOutput>

interface RequestOptions<Output extends object> {
/**
* Resolve this value if the request fails. If not present, will re-throw error.
*/
fallbackValue?: Output
/**
* Do not used cached client for the request.
*/
ignoreCache?: boolean
}
export abstract class ClientWrapper<C extends AwsClient> implements vscode.Disposable {
protected client?: C

Expand All @@ -34,10 +49,39 @@ export abstract class ClientWrapper<C extends AwsClient> implements vscode.Dispo
CommandOutput extends object,
CommandOptions extends CommandInput,
Command extends AwsCommand<CommandInput, CommandOutput>,
>(command: AwsCommandConstructor<CommandInput, Command>, commandOptions: CommandOptions): Promise<CommandOutput> {
return await this.getClient().send(new command(commandOptions))
>(
command: AwsCommandConstructor<CommandInput, Command>,
commandOptions: CommandOptions,
requestOptions?: RequestOptions<CommandOutput>
): Promise<CommandOutput> {
const action = 'API Request'
const perflog = new PerfLog(action)
return await this.getClient(requestOptions?.ignoreCache)
.send(new command(commandOptions))
.catch(async (e) => {
await this.onError(e)
const errWithoutStack = { ...e, name: e.name, message: e.message }
delete errWithoutStack['stack']
const timecost = perflog.elapsed().toFixed(1)
if (requestOptions?.fallbackValue) {
return requestOptions.fallbackValue
}
// Error is already logged in middleware before this, so we omit it here.
getLogger().error(
`${action} failed without fallback (time: %dms) \nparams: %O`,
timecost,
truncateProps(commandOptions, 20, ['nextToken'])
)
throw new ToolkitError(`${action}: ${errWithoutStack.message}`, {
code: extractCode(errWithoutStack),
cause: errWithoutStack,
})
})
}

// Intended to be overwritten by subclasses to implement custom error handling behavior.
protected onError(_: Error): void | Promise<void> {}

protected makePaginatedRequest<CommandInput extends object, CommandOutput extends object, Output extends object>(
paginator: SDKPaginator<C, CommandInput, CommandOutput>,
input: CommandInput,
Expand Down Expand Up @@ -65,3 +109,11 @@ export abstract class ClientWrapper<C extends AwsClient> implements vscode.Dispo
this.client?.destroy()
}
}

function extractCode(e: Error): string {
return hasKey(e, 'code') && typeof e['code'] === 'string'
? e.code
: hasKey(e, 'Code') && typeof e['Code'] === 'string'
? e.Code
: e.name
}
Loading
Loading