diff --git a/packages/core/resources/sagemaker_connect b/packages/core/resources/sagemaker_connect index a1b7c9c0db0..19d0e1984cc 100755 --- a/packages/core/resources/sagemaker_connect +++ b/packages/core/resources/sagemaker_connect @@ -9,13 +9,16 @@ _get_ssm_session_info() { local url_to_get_session_info="http://localhost:${local_endpoint_port}/get_session?connection_identifier=${aws_resource_arn}&credentials_type=${credentials_type}" + # Generate unique temporary file name to avoid conflicts + local temp_file="/tmp/ssm_session_response_$$_$(date +%s%N).json" + # Use curl with --write-out to capture HTTP status - response=$(curl -s -w "%{http_code}" -o /tmp/ssm_session_response.json "$url_to_get_session_info") + response=$(curl -s -w "%{http_code}" -o "$temp_file" "$url_to_get_session_info") http_status="${response: -3}" - session_json=$(cat /tmp/ssm_session_response.json) + session_json=$(cat "$temp_file") # Clean up temporary file - rm -f /tmp/ssm_session_response.json + rm -f "$temp_file" if [[ "$http_status" -ne 200 ]]; then echo "Error: Failed to get SSM session info. HTTP status: $http_status" @@ -40,16 +43,21 @@ _get_ssm_session_info_async() { local url_base="http://localhost:${local_endpoint_port}/get_session_async" local url_to_get_session_info="${url_base}?connection_identifier=${aws_resource_arn}&credentials_type=${credentials_type}&request_id=${request_id}" + # Generate unique temporary file name to avoid conflicts + local temp_file="/tmp/ssm_session_response_$$_$(date +%s%N).json" + local max_retries=60 local retry_interval=5 local attempt=1 while (( attempt <= max_retries )); do - response=$(curl -s -w "%{http_code}" -o /tmp/ssm_session_response.json "$url_to_get_session_info") + response=$(curl -s -w "%{http_code}" -o "$temp_file" "$url_to_get_session_info") http_status="${response: -3}" - session_json=$(cat /tmp/ssm_session_response.json) + session_json=$(cat "$temp_file") if [[ "$http_status" -eq 200 ]]; then + # Clean up temporary file on success + rm -f "$temp_file" export SSM_SESSION_JSON="$session_json" return 0 elif [[ "$http_status" -eq 202 || "$http_status" -eq 204 ]]; then @@ -59,10 +67,14 @@ _get_ssm_session_info_async() { else echo "Error: Failed to get SSM session info. HTTP status: $http_status" echo "Response: $session_json" + # Clean up temporary file on error + rm -f "$temp_file" exit 1 fi done + # Clean up temporary file on timeout + rm -f "$temp_file" echo "Error: Timed out after $max_retries attempts waiting for session to be ready." exit 1 } diff --git a/packages/core/src/awsService/sagemaker/detached-server/routes/getSessionAsync.ts b/packages/core/src/awsService/sagemaker/detached-server/routes/getSessionAsync.ts index e59b1b9dd10..f8dad504067 100644 --- a/packages/core/src/awsService/sagemaker/detached-server/routes/getSessionAsync.ts +++ b/packages/core/src/awsService/sagemaker/detached-server/routes/getSessionAsync.ts @@ -50,7 +50,7 @@ export async function handleGetSessionAsync(req: IncomingMessage, res: ServerRes const refreshUrl = await store.getRefreshUrl(connectionIdentifier) const { spaceName } = parseArn(connectionIdentifier) - const url = `${refreshUrl}/${encodeURIComponent(spaceName)}?reconnect_identifier=${encodeURIComponent( + const url = `${refreshUrl}/${encodeURIComponent(spaceName)}?remote_access_token_refresh=true&reconnect_identifier=${encodeURIComponent( connectionIdentifier )}&reconnect_request_id=${encodeURIComponent(requestId)}&reconnect_callback_url=${encodeURIComponent( `http://localhost:${serverInfo.port}/refresh_token` diff --git a/packages/core/src/awsService/sagemaker/detached-server/utils.ts b/packages/core/src/awsService/sagemaker/detached-server/utils.ts index 9ac6b6b0303..de01041d4ad 100644 --- a/packages/core/src/awsService/sagemaker/detached-server/utils.ts +++ b/packages/core/src/awsService/sagemaker/detached-server/utils.ts @@ -18,6 +18,10 @@ export { open } export const mappingFilePath = join(os.homedir(), '.aws', '.sagemaker-space-profiles') const tempFilePath = `${mappingFilePath}.tmp` +// Simple file lock to prevent concurrent writes +let isWriting = false +const writeQueue: Array<() => Promise> = [] + /** * Reads the local endpoint info file (default or via env) and returns pid & port. * @throws Error if the file is missing, invalid JSON, or missing fields @@ -100,14 +104,48 @@ export async function readMapping() { } /** - * Writes the mapping to a temp file and atomically renames it to the target path. + * Processes the write queue to ensure only one write operation happens at a time. */ -export async function writeMapping(mapping: SpaceMappings) { +async function processWriteQueue() { + if (isWriting || writeQueue.length === 0) { + return + } + + isWriting = true try { - const json = JSON.stringify(mapping, undefined, 2) - await fs.writeFile(tempFilePath, json) - await fs.rename(tempFilePath, mappingFilePath) - } catch (err) { - throw new Error(`Failed to write mapping file: ${err instanceof Error ? err.message : String(err)}`) + while (writeQueue.length > 0) { + const writeOperation = writeQueue.shift()! + await writeOperation() + } + } finally { + isWriting = false } } + +/** + * Writes the mapping to a temp file and atomically renames it to the target path. + * Uses a queue to prevent race conditions when multiple requests try to write simultaneously. + */ +export async function writeMapping(mapping: SpaceMappings) { + return new Promise((resolve, reject) => { + const writeOperation = async () => { + try { + // Generate unique temp file name to avoid conflicts + const uniqueTempPath = `${tempFilePath}.${process.pid}.${Date.now()}` + + const json = JSON.stringify(mapping, undefined, 2) + await fs.writeFile(uniqueTempPath, json) + await fs.rename(uniqueTempPath, mappingFilePath) + resolve() + } catch (err) { + reject(new Error(`Failed to write mapping file: ${err instanceof Error ? err.message : String(err)}`)) + } + } + + writeQueue.push(writeOperation) + + // ProcessWriteQueue handles its own errors via individual operation callbacks + // eslint-disable-next-line @typescript-eslint/no-floating-promises + processWriteQueue() + }) +} diff --git a/packages/core/src/test/awsService/sagemaker/detached-server/utils.test.ts b/packages/core/src/test/awsService/sagemaker/detached-server/utils.test.ts index 1eeb5708d11..bc8d0a8867b 100644 --- a/packages/core/src/test/awsService/sagemaker/detached-server/utils.test.ts +++ b/packages/core/src/test/awsService/sagemaker/detached-server/utils.test.ts @@ -3,8 +3,13 @@ * SPDX-License-Identifier: Apache-2.0 */ +/* eslint-disable no-restricted-imports */ import * as assert from 'assert' -import { parseArn } from '../../../../awsService/sagemaker/detached-server/utils' +import { parseArn, writeMapping, readMapping } from '../../../../awsService/sagemaker/detached-server/utils' +import { promises as fs } from 'fs' +import * as path from 'path' +import * as os from 'os' +import { SpaceMappings } from '../../../../awsService/sagemaker/types' describe('parseArn', () => { it('parses a standard SageMaker ARN with forward slash', () => { @@ -37,3 +42,71 @@ describe('parseArn', () => { assert.throws(() => parseArn(invalidArn), /Invalid SageMaker ARN format/) }) }) + +describe('writeMapping', () => { + let testDir: string + + beforeEach(async () => { + testDir = await fs.mkdtemp(path.join(os.tmpdir(), 'sagemaker-test-')) + }) + + afterEach(async () => { + await fs.rmdir(testDir, { recursive: true }) + }) + + it('handles concurrent writes without race conditions', async () => { + const mapping1: SpaceMappings = { + localCredential: { + 'space-1': { type: 'iam', profileName: 'profile1' }, + }, + } + const mapping2: SpaceMappings = { + localCredential: { + 'space-2': { type: 'iam', profileName: 'profile2' }, + }, + } + const mapping3: SpaceMappings = { + deepLink: { + 'space-3': { + requests: { + req1: { + sessionId: 'session-456', + url: 'wss://example3.com', + token: 'token-456', + }, + }, + refreshUrl: 'https://example3.com/refresh', + }, + }, + } + + const writePromises = [writeMapping(mapping1), writeMapping(mapping2), writeMapping(mapping3)] + + await Promise.all(writePromises) + + const finalContent = await readMapping() + const possibleResults = [mapping1, mapping2, mapping3] + const isValidResult = possibleResults.some( + (expected) => JSON.stringify(finalContent) === JSON.stringify(expected) + ) + assert.strictEqual(isValidResult, true, 'Final content should match one of the written mappings') + }) + + it('queues multiple writes and processes them sequentially', async () => { + const mappings = Array.from({ length: 5 }, (_, i) => ({ + localCredential: { + [`space-${i}`]: { type: 'iam' as const, profileName: `profile-${i}` }, + }, + })) + + const writePromises = mappings.map((mapping) => writeMapping(mapping)) + + await Promise.all(writePromises) + + const finalContent = await readMapping() + assert.strictEqual(typeof finalContent, 'object', 'Final content should be a valid object') + + const isValidResult = mappings.some((mapping) => JSON.stringify(finalContent) === JSON.stringify(mapping)) + assert.strictEqual(isValidResult, true, 'Final content should match one of the written mappings') + }) +}) diff --git a/packages/toolkit/.changes/next-release/Bug Fix-c6e841d7-e0bb-474c-9540-f896746d26d4.json b/packages/toolkit/.changes/next-release/Bug Fix-c6e841d7-e0bb-474c-9540-f896746d26d4.json new file mode 100644 index 00000000000..18aea78cb6a --- /dev/null +++ b/packages/toolkit/.changes/next-release/Bug Fix-c6e841d7-e0bb-474c-9540-f896746d26d4.json @@ -0,0 +1,4 @@ +{ + "type": "Bug Fix", + "description": "SageMaker: Resolve race condition when reconnecting from multiple remote windows." +}