Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions packages/core/resources/sagemaker_connect
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ _get_ssm_session_info() {

local url_to_get_session_info="http://localhost:${local_endpoint_port}/get_session?connection_identifier=${aws_resource_arn}&credentials_type=${credentials_type}"

# Generate unique temporary file name to avoid conflicts
local temp_file="/tmp/ssm_session_response_$$_$(date +%s%N).json"

# Use curl with --write-out to capture HTTP status
response=$(curl -s -w "%{http_code}" -o /tmp/ssm_session_response.json "$url_to_get_session_info")
response=$(curl -s -w "%{http_code}" -o "$temp_file" "$url_to_get_session_info")
http_status="${response: -3}"
session_json=$(cat /tmp/ssm_session_response.json)
session_json=$(cat "$temp_file")

# Clean up temporary file
rm -f /tmp/ssm_session_response.json
rm -f "$temp_file"

if [[ "$http_status" -ne 200 ]]; then
echo "Error: Failed to get SSM session info. HTTP status: $http_status"
Expand All @@ -40,16 +43,21 @@ _get_ssm_session_info_async() {
local url_base="http://localhost:${local_endpoint_port}/get_session_async"
local url_to_get_session_info="${url_base}?connection_identifier=${aws_resource_arn}&credentials_type=${credentials_type}&request_id=${request_id}"

# Generate unique temporary file name to avoid conflicts
local temp_file="/tmp/ssm_session_response_$$_$(date +%s%N).json"

local max_retries=60
local retry_interval=5
local attempt=1

while (( attempt <= max_retries )); do
response=$(curl -s -w "%{http_code}" -o /tmp/ssm_session_response.json "$url_to_get_session_info")
response=$(curl -s -w "%{http_code}" -o "$temp_file" "$url_to_get_session_info")
http_status="${response: -3}"
session_json=$(cat /tmp/ssm_session_response.json)
session_json=$(cat "$temp_file")

if [[ "$http_status" -eq 200 ]]; then
# Clean up temporary file on success
rm -f "$temp_file"
export SSM_SESSION_JSON="$session_json"
return 0
elif [[ "$http_status" -eq 202 || "$http_status" -eq 204 ]]; then
Expand All @@ -59,10 +67,14 @@ _get_ssm_session_info_async() {
else
echo "Error: Failed to get SSM session info. HTTP status: $http_status"
echo "Response: $session_json"
# Clean up temporary file on error
rm -f "$temp_file"
exit 1
fi
done

# Clean up temporary file on timeout
rm -f "$temp_file"
echo "Error: Timed out after $max_retries attempts waiting for session to be ready."
exit 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export async function handleGetSessionAsync(req: IncomingMessage, res: ServerRes
const refreshUrl = await store.getRefreshUrl(connectionIdentifier)
const { spaceName } = parseArn(connectionIdentifier)

const url = `${refreshUrl}/${encodeURIComponent(spaceName)}?reconnect_identifier=${encodeURIComponent(
const url = `${refreshUrl}/${encodeURIComponent(spaceName)}?remote_access_token_refresh=true&reconnect_identifier=${encodeURIComponent(
connectionIdentifier
)}&reconnect_request_id=${encodeURIComponent(requestId)}&reconnect_callback_url=${encodeURIComponent(
`http://localhost:${serverInfo.port}/refresh_token`
Expand Down
52 changes: 45 additions & 7 deletions packages/core/src/awsService/sagemaker/detached-server/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ export { open }
export const mappingFilePath = join(os.homedir(), '.aws', '.sagemaker-space-profiles')
const tempFilePath = `${mappingFilePath}.tmp`

// Simple file lock to prevent concurrent writes
let isWriting = false
const writeQueue: Array<() => Promise<void>> = []

/**
* Reads the local endpoint info file (default or via env) and returns pid & port.
* @throws Error if the file is missing, invalid JSON, or missing fields
Expand Down Expand Up @@ -100,14 +104,48 @@ export async function readMapping() {
}

/**
* Writes the mapping to a temp file and atomically renames it to the target path.
* Processes the write queue to ensure only one write operation happens at a time.
*/
export async function writeMapping(mapping: SpaceMappings) {
async function processWriteQueue() {
if (isWriting || writeQueue.length === 0) {
return
}

isWriting = true
try {
const json = JSON.stringify(mapping, undefined, 2)
await fs.writeFile(tempFilePath, json)
await fs.rename(tempFilePath, mappingFilePath)
} catch (err) {
throw new Error(`Failed to write mapping file: ${err instanceof Error ? err.message : String(err)}`)
while (writeQueue.length > 0) {
const writeOperation = writeQueue.shift()!
await writeOperation()
}
} finally {
isWriting = false
}
}

/**
* Writes the mapping to a temp file and atomically renames it to the target path.
* Uses a queue to prevent race conditions when multiple requests try to write simultaneously.
*/
export async function writeMapping(mapping: SpaceMappings) {
return new Promise<void>((resolve, reject) => {
const writeOperation = async () => {
try {
// Generate unique temp file name to avoid conflicts
const uniqueTempPath = `${tempFilePath}.${process.pid}.${Date.now()}`

const json = JSON.stringify(mapping, undefined, 2)
await fs.writeFile(uniqueTempPath, json)
await fs.rename(uniqueTempPath, mappingFilePath)
resolve()
} catch (err) {
reject(new Error(`Failed to write mapping file: ${err instanceof Error ? err.message : String(err)}`))
}
}

writeQueue.push(writeOperation)

// ProcessWriteQueue handles its own errors via individual operation callbacks
// eslint-disable-next-line @typescript-eslint/no-floating-promises
processWriteQueue()
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
* SPDX-License-Identifier: Apache-2.0
*/

/* eslint-disable no-restricted-imports */
import * as assert from 'assert'
import { parseArn } from '../../../../awsService/sagemaker/detached-server/utils'
import { parseArn, writeMapping, readMapping } from '../../../../awsService/sagemaker/detached-server/utils'
import { promises as fs } from 'fs'
import * as path from 'path'
import * as os from 'os'
import { SpaceMappings } from '../../../../awsService/sagemaker/types'

describe('parseArn', () => {
it('parses a standard SageMaker ARN with forward slash', () => {
Expand Down Expand Up @@ -37,3 +42,71 @@ describe('parseArn', () => {
assert.throws(() => parseArn(invalidArn), /Invalid SageMaker ARN format/)
})
})

describe('writeMapping', () => {
let testDir: string

beforeEach(async () => {
testDir = await fs.mkdtemp(path.join(os.tmpdir(), 'sagemaker-test-'))
})

afterEach(async () => {
await fs.rmdir(testDir, { recursive: true })
})

it('handles concurrent writes without race conditions', async () => {
const mapping1: SpaceMappings = {
localCredential: {
'space-1': { type: 'iam', profileName: 'profile1' },
},
}
const mapping2: SpaceMappings = {
localCredential: {
'space-2': { type: 'iam', profileName: 'profile2' },
},
}
const mapping3: SpaceMappings = {
deepLink: {
'space-3': {
requests: {
req1: {
sessionId: 'session-456',
url: 'wss://example3.com',
token: 'token-456',
},
},
refreshUrl: 'https://example3.com/refresh',
},
},
}

const writePromises = [writeMapping(mapping1), writeMapping(mapping2), writeMapping(mapping3)]

await Promise.all(writePromises)

const finalContent = await readMapping()
const possibleResults = [mapping1, mapping2, mapping3]
const isValidResult = possibleResults.some(
(expected) => JSON.stringify(finalContent) === JSON.stringify(expected)
)
assert.strictEqual(isValidResult, true, 'Final content should match one of the written mappings')
})

it('queues multiple writes and processes them sequentially', async () => {
const mappings = Array.from({ length: 5 }, (_, i) => ({
localCredential: {
[`space-${i}`]: { type: 'iam' as const, profileName: `profile-${i}` },
},
}))

const writePromises = mappings.map((mapping) => writeMapping(mapping))

await Promise.all(writePromises)

const finalContent = await readMapping()
assert.strictEqual(typeof finalContent, 'object', 'Final content should be a valid object')

const isValidResult = mappings.some((mapping) => JSON.stringify(finalContent) === JSON.stringify(mapping))
assert.strictEqual(isValidResult, true, 'Final content should match one of the written mappings')
})
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "Bug Fix",
"description": "SageMaker: Resolve race condition when reconnecting from multiple remote windows."
}
Loading