Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,572 changes: 1,572 additions & 0 deletions package-lock.json

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,9 @@
"@aws-sdk/client-iot": "~3.693.0",
"@aws-sdk/client-iotsecuretunneling": "~3.693.0",
"@aws-sdk/client-lambda": "<3.731.0",
"@aws-sdk/client-redshift": "~3.693.0",
"@aws-sdk/client-redshift-data": "~3.693.0",
"@aws-sdk/client-redshift-serverless": "~3.693.0",
"@aws-sdk/client-s3": "<3.731.0",
"@aws-sdk/client-s3-control": "^3.830.0",
"@aws-sdk/client-sagemaker": "<3.696.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import * as vscode from 'vscode'
import { DefaultRedshiftClient } from '../../../shared/clients/redshiftClient'
import { ConnectionParams } from '../models/models'
import { RedshiftData } from 'aws-sdk'
import { ColumnMetadata, Field } from '@aws-sdk/client-redshift-data'
import { telemetry } from '../../../shared/telemetry/telemetry'

export class RedshiftNotebookController {
Expand Down Expand Up @@ -79,8 +79,8 @@ export class RedshiftNotebookController {
}

let executionId: string | undefined
let columnMetadata: RedshiftData.ColumnMetadataList | undefined
const records: RedshiftData.SqlRecords = []
let columnMetadata: ColumnMetadata[] | undefined
const records: Field[][] = []
let nextToken: string | undefined
// get all the pages of the result
do {
Expand All @@ -90,7 +90,7 @@ export class RedshiftNotebookController {
nextToken,
executionId
)
if (result) {
if (result && result.statementResultResponse.Records) {
nextToken = result.statementResultResponse.NextToken
executionId = result.executionId
columnMetadata = result.statementResultResponse.ColumnMetadata
Expand All @@ -116,7 +116,7 @@ export class RedshiftNotebookController {
})
}

public getAsTable(connectionParams: ConnectionParams, columns: string[], records: RedshiftData.SqlRecords) {
public getAsTable(connectionParams: ConnectionParams, columns: string[], records: Field[][]) {
if (!records || records.length === 0) {
return '<p>No records to display<p>'
}
Expand Down
122 changes: 76 additions & 46 deletions packages/core/src/shared/clients/redshiftClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,43 @@
* SPDX-License-Identifier: Apache-2.0
*/

import { Redshift, RedshiftServerless, RedshiftData } from 'aws-sdk'
import globals from '../extensionGlobals'
import { ClusterCredentials, ClustersMessage, GetClusterCredentialsMessage } from 'aws-sdk/clients/redshift'
import {
GetCredentialsRequest,
GetCredentialsResponse,
ListWorkgroupsResponse,
} from 'aws-sdk/clients/redshiftserverless'
ClusterCredentials,
ClustersMessage,
DescribeClustersCommand,
DescribeClustersMessage,
GetClusterCredentialsCommand,
GetClusterCredentialsMessage,
RedshiftClient,
} from '@aws-sdk/client-redshift'
import {
DescribeStatementCommand,
DescribeStatementRequest,
ExecuteStatementCommand,
GetStatementResultCommand,
GetStatementResultRequest,
GetStatementResultResponse,
ListDatabasesCommand,
ListDatabasesRequest,
ListDatabasesResponse,
ListSchemasCommand,
ListSchemasRequest,
ListSchemasResponse,
ListTablesCommand,
ListTablesRequest,
ListTablesResponse,
} from 'aws-sdk/clients/redshiftdata'
RedshiftDataClient,
} from '@aws-sdk/client-redshift-data'
import {
GetCredentialsCommand,
GetCredentialsRequest,
GetCredentialsResponse,
ListWorkgroupsCommand,
ListWorkgroupsRequest,
ListWorkgroupsResponse,
RedshiftServerlessClient,
} from '@aws-sdk/client-redshift-serverless'
import globals from '../extensionGlobals'
import { ConnectionParams, ConnectionType, RedshiftWarehouseType } from '../../awsService/redshift/models/models'
import { sleep } from '../utilities/timeoutUtils'
import { SecretsManagerClient } from './secretsManagerClient'
Expand All @@ -37,21 +58,21 @@ export class DefaultRedshiftClient {
public readonly regionCode: string,
private readonly redshiftDataClientProvider: (
regionCode: string
) => Promise<RedshiftData> = createRedshiftDataClient,
private readonly redshiftClientProvider: (regionCode: string) => Promise<Redshift> = createRedshiftSdkClient,
) => RedshiftDataClient = createRedshiftDataClient,
private readonly redshiftClientProvider: (regionCode: string) => RedshiftClient = createRedshiftSdkClient,
private readonly redshiftServerlessClientProvider: (
regionCode: string
) => Promise<RedshiftServerless> = createRedshiftServerlessSdkClient
) => RedshiftServerlessClient = createRedshiftServerlessSdkClient
) {}

// eslint-disable-next-line require-yield
public async describeProvisionedClusters(nextToken?: string): Promise<ClustersMessage> {
const redshiftClient = await this.redshiftClientProvider(this.regionCode)
const request: Redshift.DescribeClustersMessage = {
const redshiftClient = this.redshiftClientProvider(this.regionCode)
const request: DescribeClustersMessage = {
Marker: nextToken,
MaxRecords: 20,
}
const response = await redshiftClient.describeClusters(request).promise()
const response = await redshiftClient.send(new DescribeClustersCommand(request))
if (response.Clusters) {
response.Clusters = response.Clusters.filter(
(cluster) => cluster.ClusterAvailabilityStatus?.toLowerCase() === 'available'
Expand All @@ -61,12 +82,12 @@ export class DefaultRedshiftClient {
}

public async listServerlessWorkgroups(nextToken?: string): Promise<ListWorkgroupsResponse> {
const redshiftServerlessClient = await this.redshiftServerlessClientProvider(this.regionCode)
const request: RedshiftServerless.ListWorkgroupsRequest = {
const redshiftServerlessClient = this.redshiftServerlessClientProvider(this.regionCode)
const request: ListWorkgroupsRequest = {
nextToken: nextToken,
maxResults: 20,
}
const response = await redshiftServerlessClient.listWorkgroups(request).promise()
const response = await redshiftServerlessClient.send(new ListWorkgroupsCommand(request))
if (response.workgroups) {
response.workgroups = response.workgroups.filter(
(workgroup) => workgroup.status?.toLowerCase() === 'available'
Expand All @@ -76,10 +97,10 @@ export class DefaultRedshiftClient {
}

public async listDatabases(connectionParams: ConnectionParams, nextToken?: string): Promise<ListDatabasesResponse> {
const redshiftDataClient = await this.redshiftDataClientProvider(this.regionCode)
const redshiftDataClient = this.redshiftDataClientProvider(this.regionCode)
const warehouseType = connectionParams.warehouseType
const warehouseIdentifier = connectionParams.warehouseIdentifier
const input: RedshiftData.ListDatabasesRequest = {
const input: ListDatabasesRequest = {
ClusterIdentifier: warehouseType === RedshiftWarehouseType.PROVISIONED ? warehouseIdentifier : undefined,
Database: connectionParams.database,
DbUser:
Expand All @@ -94,13 +115,13 @@ export class DefaultRedshiftClient {
? connectionParams.secret
: undefined,
}
return redshiftDataClient.listDatabases(input).promise()
return redshiftDataClient.send(new ListDatabasesCommand(input))
}
public async listSchemas(connectionParams: ConnectionParams, nextToken?: string): Promise<ListSchemasResponse> {
const redshiftDataClient = await this.redshiftDataClientProvider(this.regionCode)
const redshiftDataClient = this.redshiftDataClientProvider(this.regionCode)
const warehouseType = connectionParams.warehouseType
const warehouseIdentifier = connectionParams.warehouseIdentifier
const input: RedshiftData.ListSchemasRequest = {
const input: ListSchemasRequest = {
ClusterIdentifier: warehouseType === RedshiftWarehouseType.PROVISIONED ? warehouseIdentifier : undefined,
Database: connectionParams.database,
DbUser:
Expand All @@ -114,18 +135,18 @@ export class DefaultRedshiftClient {
? connectionParams.secret
: undefined,
}
return redshiftDataClient.listSchemas(input).promise()
return redshiftDataClient.send(new ListSchemasCommand(input))
}

public async listTables(
connectionParams: ConnectionParams,
schemaName: string,
nextToken?: string
): Promise<ListTablesResponse> {
const redshiftDataClient = await this.redshiftDataClientProvider(this.regionCode)
const redshiftDataClient = this.redshiftDataClientProvider(this.regionCode)
const warehouseType = connectionParams.warehouseType
const warehouseIdentifier = connectionParams.warehouseIdentifier
const input: RedshiftData.ListTablesRequest = {
const input: ListTablesRequest = {
ClusterIdentifier: warehouseType === RedshiftWarehouseType.PROVISIONED ? warehouseIdentifier : undefined,
DbUser:
connectionParams.username && connectionParams.connectionType !== ConnectionType.DatabaseUser
Expand All @@ -140,7 +161,7 @@ export class DefaultRedshiftClient {
? connectionParams.secret
: undefined,
}
const ListTablesResponse = redshiftDataClient.listTables(input).promise()
const ListTablesResponse = redshiftDataClient.send(new ListTablesCommand(input))
return ListTablesResponse
}

Expand All @@ -150,11 +171,11 @@ export class DefaultRedshiftClient {
nextToken?: string,
executionId?: string
): Promise<ExecuteQueryResponse | undefined> {
const redshiftData = await this.redshiftDataClientProvider(this.regionCode)
const redshiftData = this.redshiftDataClientProvider(this.regionCode)
// if executionId is not passed in, that means that we're executing and retrieving the results of the query for the first time.
if (!executionId) {
const execution = await redshiftData
.executeStatement({
const execution = await redshiftData.send(
new ExecuteStatementCommand({
ClusterIdentifier:
connectionParams.warehouseType === RedshiftWarehouseType.PROVISIONED
? connectionParams.warehouseIdentifier
Expand All @@ -174,15 +195,15 @@ export class DefaultRedshiftClient {
? connectionParams.secret
: undefined,
})
.promise()
)

executionId = execution.Id
type Status = 'RUNNING' | 'FAILED' | 'FINISHED'
let status: Status = 'RUNNING'
while (status === 'RUNNING') {
const describeStatementResponse = await redshiftData
.describeStatement({ Id: executionId } as DescribeStatementRequest)
.promise()
const describeStatementResponse = await redshiftData.send(
new DescribeStatementCommand({ Id: executionId } as DescribeStatementRequest)
)
if (describeStatementResponse.Status === 'FAILED' || describeStatementResponse.Status === 'FINISHED') {
status = describeStatementResponse.Status
if (status === 'FAILED') {
Expand All @@ -198,9 +219,9 @@ export class DefaultRedshiftClient {
}
}
}
const result = await redshiftData
.getStatementResult({ Id: executionId, NextToken: nextToken } as GetStatementResultRequest)
.promise()
const result = await redshiftData.send(
new GetStatementResultCommand({ Id: executionId, NextToken: nextToken } as GetStatementResultRequest)
)

return { statementResultResponse: result, executionId: executionId } as ExecuteQueryResponse
}
Expand All @@ -210,20 +231,20 @@ export class DefaultRedshiftClient {
connectionParams: ConnectionParams
): Promise<ClusterCredentials | GetCredentialsResponse> {
if (warehouseType === RedshiftWarehouseType.PROVISIONED) {
const redshiftClient = await this.redshiftClientProvider(this.regionCode)
const redshiftClient = this.redshiftClientProvider(this.regionCode)
const getClusterCredentialsRequest: GetClusterCredentialsMessage = {
DbUser: connectionParams.username!,
DbName: connectionParams.database,
ClusterIdentifier: connectionParams.warehouseIdentifier,
}
return redshiftClient.getClusterCredentials(getClusterCredentialsRequest).promise()
return redshiftClient.send(new GetClusterCredentialsCommand(getClusterCredentialsRequest))
} else {
const redshiftServerless = await this.redshiftServerlessClientProvider(this.regionCode)
const redshiftServerless = this.redshiftServerlessClientProvider(this.regionCode)
const getCredentialsRequest: GetCredentialsRequest = {
dbName: connectionParams.database,
workgroupName: connectionParams.warehouseIdentifier,
}
return redshiftServerless.getCredentials(getCredentialsRequest).promise()
return redshiftServerless.send(new GetCredentialsCommand(getCredentialsRequest))
}
}
public genUniqueId(connectionParams: ConnectionParams): string {
Expand Down Expand Up @@ -258,13 +279,22 @@ export class DefaultRedshiftClient {
}
}

async function createRedshiftSdkClient(regionCode: string): Promise<Redshift> {
return await globals.sdkClientBuilder.createAwsService(Redshift, { computeChecksums: true }, regionCode)
function createRedshiftSdkClient(regionCode: string): RedshiftClient {
return globals.sdkClientBuilderV3.createAwsService({
serviceClient: RedshiftClient,
clientOptions: { region: regionCode },
})
}

async function createRedshiftServerlessSdkClient(regionCode: string): Promise<RedshiftServerless> {
return await globals.sdkClientBuilder.createAwsService(RedshiftServerless, { computeChecksums: true }, regionCode)
function createRedshiftServerlessSdkClient(regionCode: string): RedshiftServerlessClient {
return globals.sdkClientBuilderV3.createAwsService({
serviceClient: RedshiftServerlessClient,
clientOptions: { region: regionCode },
})
}
async function createRedshiftDataClient(regionCode: string): Promise<RedshiftData> {
return await globals.sdkClientBuilder.createAwsService(RedshiftData, { computeChecksums: true }, regionCode)
function createRedshiftDataClient(regionCode: string): RedshiftDataClient {
return globals.sdkClientBuilderV3.createAwsService({
serviceClient: RedshiftDataClient,
clientOptions: { region: regionCode },
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
* SPDX-License-Identifier: Apache-2.0
*/

import sinon = require('sinon')
import { mockClient } from 'aws-sdk-client-mock'
import { RedshiftDatabaseNode } from '../../../../awsService/redshift/explorer/redshiftDatabaseNode'
import { RedshiftData } from 'aws-sdk'
import { RedshiftDataClient, ListSchemasCommand } from '@aws-sdk/client-redshift-data'
import { DefaultRedshiftClient } from '../../../../shared/clients/redshiftClient'
import { ConnectionParams, ConnectionType, RedshiftWarehouseType } from '../../../../awsService/redshift/models/models'
import assert = require('assert')
Expand All @@ -14,44 +14,37 @@ import { AWSTreeNodeBase } from '../../../../shared/treeview/nodes/awsTreeNodeBa
import { MoreResultsNode } from '../../../../awsexplorer/moreResultsNode'

describe('RedshiftDatabaseNode', function () {
const sandbox = sinon.createSandbox()
const mockRedshiftData = <RedshiftData>{}
const redshiftClient = new DefaultRedshiftClient('us-east-1', async () => mockRedshiftData, undefined, undefined)
const mockRedshiftData = mockClient(RedshiftDataClient)
const redshiftClient = new DefaultRedshiftClient('us-east-1', () => mockRedshiftData as any, undefined, undefined)
const connectionParams = new ConnectionParams(
ConnectionType.TempCreds,
'testDb1',
'warehouseId',
RedshiftWarehouseType.PROVISIONED
)
let listSchemasStub: sinon.SinonStub

describe('getChildren', function () {
beforeEach(function () {
listSchemasStub = sandbox.stub()
mockRedshiftData.listSchemas = listSchemasStub
})

afterEach(function () {
sandbox.reset()
mockRedshiftData.reset()
})

it('loads schemas successfully', async () => {
const node = new RedshiftDatabaseNode('testDB1', redshiftClient, connectionParams)
listSchemasStub.returns({ promise: () => Promise.resolve({ Schemas: ['schema1'] }) })
mockRedshiftData.on(ListSchemasCommand).resolves({ Schemas: ['schema1'] })
const childNodes = await node.getChildren()
verifyChildNodes(childNodes, false)
})

it('loads schemas and shows load more node when there are more schemas', async () => {
const node = new RedshiftDatabaseNode('testDB1', redshiftClient, connectionParams)
listSchemasStub.returns({ promise: () => Promise.resolve({ Schemas: ['schema1'], NextToken: 'next' }) })
mockRedshiftData.on(ListSchemasCommand).resolves({ Schemas: ['schema1'], NextToken: 'next' })
const childNodes = await node.getChildren()
verifyChildNodes(childNodes, true)
})

it('shows error node when listSchema fails', async () => {
const node = new RedshiftDatabaseNode('testDB1', redshiftClient, connectionParams)
listSchemasStub.returns({ promise: () => Promise.reject('Failed') })
mockRedshiftData.on(ListSchemasCommand).rejects('Failed')
const childNodes = await node.getChildren()
assert.strictEqual(childNodes.length, 1)
assert.strictEqual(childNodes[0].contextValue, 'awsErrorNode')
Expand Down
Loading
Loading