Skip to content

Commit 2eb0891

Browse files
authored
fix(sagemaker): Fix race-condition with multiple remote spaces trying to reconnect (#7684)
## Problem - When reconnecting to multiple SageMaker Spaces (either via deeplink or from within the VS Code extension), a **race condition** occurs when writing to shared temporary files. This can cause the local SageMaker server to crash due to concurrent access. - Need clearer error messaging when reconnection to a deeplinked space is attempted without an active Studio login. ## Solution - For connections initiated from the VS Code extension, we generate **unique temporary files** to read the response json. - For deeplink-based reconnections, we introduce a **queue** to process session requests sequentially. - Add `remote_access_token_refresh` flag to the refresh URL to enable the Studio server to return more specific error messages. --- - Treat all work as PUBLIC. Private `feature/x` branches will not be squash-merged at release time. - Your code changes must meet the guidelines in [CONTRIBUTING.md](https://github.com/aws/aws-toolkit-vscode/blob/master/CONTRIBUTING.md#guidelines). - License: I confirm that my contribution is made under the terms of the Apache 2.0 license.
1 parent b39ab4e commit 2eb0891

File tree

5 files changed

+141
-14
lines changed

5 files changed

+141
-14
lines changed

packages/core/resources/sagemaker_connect

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@ _get_ssm_session_info() {
99

1010
local url_to_get_session_info="http://localhost:${local_endpoint_port}/get_session?connection_identifier=${aws_resource_arn}&credentials_type=${credentials_type}"
1111

12+
# Generate unique temporary file name to avoid conflicts
13+
local temp_file="/tmp/ssm_session_response_$$_$(date +%s%N).json"
14+
1215
# Use curl with --write-out to capture HTTP status
13-
response=$(curl -s -w "%{http_code}" -o /tmp/ssm_session_response.json "$url_to_get_session_info")
16+
response=$(curl -s -w "%{http_code}" -o "$temp_file" "$url_to_get_session_info")
1417
http_status="${response: -3}"
15-
session_json=$(cat /tmp/ssm_session_response.json)
18+
session_json=$(cat "$temp_file")
1619

1720
# Clean up temporary file
18-
rm -f /tmp/ssm_session_response.json
21+
rm -f "$temp_file"
1922

2023
if [[ "$http_status" -ne 200 ]]; then
2124
echo "Error: Failed to get SSM session info. HTTP status: $http_status"
@@ -40,16 +43,21 @@ _get_ssm_session_info_async() {
4043
local url_base="http://localhost:${local_endpoint_port}/get_session_async"
4144
local url_to_get_session_info="${url_base}?connection_identifier=${aws_resource_arn}&credentials_type=${credentials_type}&request_id=${request_id}"
4245

46+
# Generate unique temporary file name to avoid conflicts
47+
local temp_file="/tmp/ssm_session_response_$$_$(date +%s%N).json"
48+
4349
local max_retries=60
4450
local retry_interval=5
4551
local attempt=1
4652

4753
while (( attempt <= max_retries )); do
48-
response=$(curl -s -w "%{http_code}" -o /tmp/ssm_session_response.json "$url_to_get_session_info")
54+
response=$(curl -s -w "%{http_code}" -o "$temp_file" "$url_to_get_session_info")
4955
http_status="${response: -3}"
50-
session_json=$(cat /tmp/ssm_session_response.json)
56+
session_json=$(cat "$temp_file")
5157

5258
if [[ "$http_status" -eq 200 ]]; then
59+
# Clean up temporary file on success
60+
rm -f "$temp_file"
5361
export SSM_SESSION_JSON="$session_json"
5462
return 0
5563
elif [[ "$http_status" -eq 202 || "$http_status" -eq 204 ]]; then
@@ -59,10 +67,14 @@ _get_ssm_session_info_async() {
5967
else
6068
echo "Error: Failed to get SSM session info. HTTP status: $http_status"
6169
echo "Response: $session_json"
70+
# Clean up temporary file on error
71+
rm -f "$temp_file"
6272
exit 1
6373
fi
6474
done
6575

76+
# Clean up temporary file on timeout
77+
rm -f "$temp_file"
6678
echo "Error: Timed out after $max_retries attempts waiting for session to be ready."
6779
exit 1
6880
}

packages/core/src/awsService/sagemaker/detached-server/routes/getSessionAsync.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ export async function handleGetSessionAsync(req: IncomingMessage, res: ServerRes
5050
const refreshUrl = await store.getRefreshUrl(connectionIdentifier)
5151
const { spaceName } = parseArn(connectionIdentifier)
5252

53-
const url = `${refreshUrl}/${encodeURIComponent(spaceName)}?reconnect_identifier=${encodeURIComponent(
53+
const url = `${refreshUrl}/${encodeURIComponent(spaceName)}?remote_access_token_refresh=true&reconnect_identifier=${encodeURIComponent(
5454
connectionIdentifier
5555
)}&reconnect_request_id=${encodeURIComponent(requestId)}&reconnect_callback_url=${encodeURIComponent(
5656
`http://localhost:${serverInfo.port}/refresh_token`

packages/core/src/awsService/sagemaker/detached-server/utils.ts

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ export { open }
1818
export const mappingFilePath = join(os.homedir(), '.aws', '.sagemaker-space-profiles')
1919
const tempFilePath = `${mappingFilePath}.tmp`
2020

21+
// Simple file lock to prevent concurrent writes
22+
let isWriting = false
23+
const writeQueue: Array<() => Promise<void>> = []
24+
2125
/**
2226
* Reads the local endpoint info file (default or via env) and returns pid & port.
2327
* @throws Error if the file is missing, invalid JSON, or missing fields
@@ -100,14 +104,48 @@ export async function readMapping() {
100104
}
101105

102106
/**
103-
* Writes the mapping to a temp file and atomically renames it to the target path.
107+
* Processes the write queue to ensure only one write operation happens at a time.
104108
*/
105-
export async function writeMapping(mapping: SpaceMappings) {
109+
async function processWriteQueue() {
110+
if (isWriting || writeQueue.length === 0) {
111+
return
112+
}
113+
114+
isWriting = true
106115
try {
107-
const json = JSON.stringify(mapping, undefined, 2)
108-
await fs.writeFile(tempFilePath, json)
109-
await fs.rename(tempFilePath, mappingFilePath)
110-
} catch (err) {
111-
throw new Error(`Failed to write mapping file: ${err instanceof Error ? err.message : String(err)}`)
116+
while (writeQueue.length > 0) {
117+
const writeOperation = writeQueue.shift()!
118+
await writeOperation()
119+
}
120+
} finally {
121+
isWriting = false
112122
}
113123
}
124+
125+
/**
126+
* Writes the mapping to a temp file and atomically renames it to the target path.
127+
* Uses a queue to prevent race conditions when multiple requests try to write simultaneously.
128+
*/
129+
export async function writeMapping(mapping: SpaceMappings) {
130+
return new Promise<void>((resolve, reject) => {
131+
const writeOperation = async () => {
132+
try {
133+
// Generate unique temp file name to avoid conflicts
134+
const uniqueTempPath = `${tempFilePath}.${process.pid}.${Date.now()}`
135+
136+
const json = JSON.stringify(mapping, undefined, 2)
137+
await fs.writeFile(uniqueTempPath, json)
138+
await fs.rename(uniqueTempPath, mappingFilePath)
139+
resolve()
140+
} catch (err) {
141+
reject(new Error(`Failed to write mapping file: ${err instanceof Error ? err.message : String(err)}`))
142+
}
143+
}
144+
145+
writeQueue.push(writeOperation)
146+
147+
// ProcessWriteQueue handles its own errors via individual operation callbacks
148+
// eslint-disable-next-line @typescript-eslint/no-floating-promises
149+
processWriteQueue()
150+
})
151+
}

packages/core/src/test/awsService/sagemaker/detached-server/utils.test.ts

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6+
/* eslint-disable no-restricted-imports */
67
import * as assert from 'assert'
7-
import { parseArn } from '../../../../awsService/sagemaker/detached-server/utils'
8+
import { parseArn, writeMapping, readMapping } from '../../../../awsService/sagemaker/detached-server/utils'
9+
import { promises as fs } from 'fs'
10+
import * as path from 'path'
11+
import * as os from 'os'
12+
import { SpaceMappings } from '../../../../awsService/sagemaker/types'
813

914
describe('parseArn', () => {
1015
it('parses a standard SageMaker ARN with forward slash', () => {
@@ -37,3 +42,71 @@ describe('parseArn', () => {
3742
assert.throws(() => parseArn(invalidArn), /Invalid SageMaker ARN format/)
3843
})
3944
})
45+
46+
describe('writeMapping', () => {
47+
let testDir: string
48+
49+
beforeEach(async () => {
50+
testDir = await fs.mkdtemp(path.join(os.tmpdir(), 'sagemaker-test-'))
51+
})
52+
53+
afterEach(async () => {
54+
await fs.rmdir(testDir, { recursive: true })
55+
})
56+
57+
it('handles concurrent writes without race conditions', async () => {
58+
const mapping1: SpaceMappings = {
59+
localCredential: {
60+
'space-1': { type: 'iam', profileName: 'profile1' },
61+
},
62+
}
63+
const mapping2: SpaceMappings = {
64+
localCredential: {
65+
'space-2': { type: 'iam', profileName: 'profile2' },
66+
},
67+
}
68+
const mapping3: SpaceMappings = {
69+
deepLink: {
70+
'space-3': {
71+
requests: {
72+
req1: {
73+
sessionId: 'session-456',
74+
url: 'wss://example3.com',
75+
token: 'token-456',
76+
},
77+
},
78+
refreshUrl: 'https://example3.com/refresh',
79+
},
80+
},
81+
}
82+
83+
const writePromises = [writeMapping(mapping1), writeMapping(mapping2), writeMapping(mapping3)]
84+
85+
await Promise.all(writePromises)
86+
87+
const finalContent = await readMapping()
88+
const possibleResults = [mapping1, mapping2, mapping3]
89+
const isValidResult = possibleResults.some(
90+
(expected) => JSON.stringify(finalContent) === JSON.stringify(expected)
91+
)
92+
assert.strictEqual(isValidResult, true, 'Final content should match one of the written mappings')
93+
})
94+
95+
it('queues multiple writes and processes them sequentially', async () => {
96+
const mappings = Array.from({ length: 5 }, (_, i) => ({
97+
localCredential: {
98+
[`space-${i}`]: { type: 'iam' as const, profileName: `profile-${i}` },
99+
},
100+
}))
101+
102+
const writePromises = mappings.map((mapping) => writeMapping(mapping))
103+
104+
await Promise.all(writePromises)
105+
106+
const finalContent = await readMapping()
107+
assert.strictEqual(typeof finalContent, 'object', 'Final content should be a valid object')
108+
109+
const isValidResult = mappings.some((mapping) => JSON.stringify(finalContent) === JSON.stringify(mapping))
110+
assert.strictEqual(isValidResult, true, 'Final content should match one of the written mappings')
111+
})
112+
})
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "Bug Fix",
3+
"description": "SageMaker: Resolve race condition when reconnecting from multiple remote windows."
4+
}

0 commit comments

Comments
 (0)