Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions packages/core/resources/sagemaker_connect
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ _get_ssm_session_info() {

local url_to_get_session_info="http://localhost:${local_endpoint_port}/get_session?connection_identifier=${aws_resource_arn}&credentials_type=${credentials_type}"

# Generate unique temporary file name to avoid conflicts
local temp_file="/tmp/ssm_session_response_$$_$(date +%s%N).json"

# Use curl with --write-out to capture HTTP status
response=$(curl -s -w "%{http_code}" -o /tmp/ssm_session_response.json "$url_to_get_session_info")
response=$(curl -s -w "%{http_code}" -o "$temp_file" "$url_to_get_session_info")
http_status="${response: -3}"
session_json=$(cat /tmp/ssm_session_response.json)
session_json=$(cat "$temp_file")

# Clean up temporary file
rm -f /tmp/ssm_session_response.json
rm -f "$temp_file"

if [[ "$http_status" -ne 200 ]]; then
echo "Error: Failed to get SSM session info. HTTP status: $http_status"
Expand All @@ -40,16 +43,21 @@ _get_ssm_session_info_async() {
local url_base="http://localhost:${local_endpoint_port}/get_session_async"
local url_to_get_session_info="${url_base}?connection_identifier=${aws_resource_arn}&credentials_type=${credentials_type}&request_id=${request_id}"

# Generate unique temporary file name to avoid conflicts
local temp_file="/tmp/ssm_session_response_$$_$(date +%s%N).json"

local max_retries=60
local retry_interval=5
local attempt=1

while (( attempt <= max_retries )); do
response=$(curl -s -w "%{http_code}" -o /tmp/ssm_session_response.json "$url_to_get_session_info")
response=$(curl -s -w "%{http_code}" -o "$temp_file" "$url_to_get_session_info")
http_status="${response: -3}"
session_json=$(cat /tmp/ssm_session_response.json)
session_json=$(cat "$temp_file")

if [[ "$http_status" -eq 200 ]]; then
# Clean up temporary file on success
rm -f "$temp_file"
export SSM_SESSION_JSON="$session_json"
return 0
elif [[ "$http_status" -eq 202 || "$http_status" -eq 204 ]]; then
Expand All @@ -59,10 +67,14 @@ _get_ssm_session_info_async() {
else
echo "Error: Failed to get SSM session info. HTTP status: $http_status"
echo "Response: $session_json"
# Clean up temporary file on error
rm -f "$temp_file"
exit 1
fi
done

# Clean up temporary file on timeout
rm -f "$temp_file"
echo "Error: Timed out after $max_retries attempts waiting for session to be ready."
exit 1
}
Expand Down
49 changes: 42 additions & 7 deletions packages/core/src/awsService/sagemaker/detached-server/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ export { open }
export const mappingFilePath = join(os.homedir(), '.aws', '.sagemaker-space-profiles')
const tempFilePath = `${mappingFilePath}.tmp`

// Simple file lock to prevent concurrent writes
let isWriting = false
const writeQueue: Array<() => Promise<void>> = []

/**
* Reads the local endpoint info file (default or via env) and returns pid & port.
* @throws Error if the file is missing, invalid JSON, or missing fields
Expand Down Expand Up @@ -100,14 +104,45 @@ export async function readMapping() {
}

/**
* Writes the mapping to a temp file and atomically renames it to the target path.
* Processes the write queue to ensure only one write operation happens at a time.
*/
export async function writeMapping(mapping: SpaceMappings) {
async function processWriteQueue() {
if (isWriting || writeQueue.length === 0) {
return
}

isWriting = true
try {
const json = JSON.stringify(mapping, undefined, 2)
await fs.writeFile(tempFilePath, json)
await fs.rename(tempFilePath, mappingFilePath)
} catch (err) {
throw new Error(`Failed to write mapping file: ${err instanceof Error ? err.message : String(err)}`)
while (writeQueue.length > 0) {
const writeOperation = writeQueue.shift()!
await writeOperation()
}
} finally {
isWriting = false
}
}

/**
* Writes the mapping to a temp file and atomically renames it to the target path.
* Uses a queue to prevent race conditions when multiple requests try to write simultaneously.
*/
export async function writeMapping(mapping: SpaceMappings) {
return new Promise<void>((resolve, reject) => {
const writeOperation = async () => {
try {
// Generate unique temp file name to avoid conflicts
const uniqueTempPath = `${tempFilePath}.${process.pid}.${Date.now()}`

const json = JSON.stringify(mapping, undefined, 2)
await fs.writeFile(uniqueTempPath, json)
await fs.rename(uniqueTempPath, mappingFilePath)
resolve()
} catch (err) {
reject(new Error(`Failed to write mapping file: ${err instanceof Error ? err.message : String(err)}`))
}
}

writeQueue.push(writeOperation)
processWriteQueue().catch(reject)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

processWriteQueue().catch this catch cannot catch anything because inside processWriteQueue it does not throw.

Copy link
Contributor Author

@aws-asolidu aws-asolidu Jul 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I’ll remove this. Since processWriteQueue uses a try/finally block there's nothing to catch, and writeOperation already handles errors and calls reject.

})
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
*/

import * as assert from 'assert'
import { parseArn } from '../../../../awsService/sagemaker/detached-server/utils'
import { parseArn, writeMapping, readMapping } from '../../../../awsService/sagemaker/detached-server/utils'
import { promises as fs } from 'fs'

Check failure on line 8 in packages/core/src/test/awsService/sagemaker/detached-server/utils.test.ts

View workflow job for this annotation

GitHub Actions / lint (18.x, stable)

'fs' import is restricted from being used. Avoid node:fs and use shared/fs/fs.ts when possible
import * as path from 'path'
import * as os from 'os'
import { SpaceMappings } from '../../../../awsService/sagemaker/types'

describe('parseArn', () => {
it('parses a standard SageMaker ARN with forward slash', () => {
Expand Down Expand Up @@ -37,3 +41,71 @@
assert.throws(() => parseArn(invalidArn), /Invalid SageMaker ARN format/)
})
})

describe('writeMapping', () => {
let testDir: string

beforeEach(async () => {
testDir = await fs.mkdtemp(path.join(os.tmpdir(), 'sagemaker-test-'))
})

afterEach(async () => {
await fs.rmdir(testDir, { recursive: true })
})

it('handles concurrent writes without race conditions', async () => {
const mapping1: SpaceMappings = {
localCredential: {
'space-1': { type: 'iam', profileName: 'profile1' },
},
}
const mapping2: SpaceMappings = {
localCredential: {
'space-2': { type: 'iam', profileName: 'profile2' },
},
}
const mapping3: SpaceMappings = {
deepLink: {
'space-3': {
requests: {
req1: {
sessionId: 'session-456',
url: 'wss://example3.com',
token: 'token-456',
},
},
refreshUrl: 'https://example3.com/refresh',
},
},
}

const writePromises = [writeMapping(mapping1), writeMapping(mapping2), writeMapping(mapping3)]

await Promise.all(writePromises)

const finalContent = await readMapping()
const possibleResults = [mapping1, mapping2, mapping3]
const isValidResult = possibleResults.some(
(expected) => JSON.stringify(finalContent) === JSON.stringify(expected)
)
assert.strictEqual(isValidResult, true, 'Final content should match one of the written mappings')
})

it('queues multiple writes and processes them sequentially', async () => {
const mappings = Array.from({ length: 5 }, (_, i) => ({
localCredential: {
[`space-${i}`]: { type: 'iam' as const, profileName: `profile-${i}` },
},
}))

const writePromises = mappings.map((mapping) => writeMapping(mapping))

await Promise.all(writePromises)

const finalContent = await readMapping()
assert.strictEqual(typeof finalContent, 'object', 'Final content should be a valid object')

const isValidResult = mappings.some((mapping) => JSON.stringify(finalContent) === JSON.stringify(mapping))
assert.strictEqual(isValidResult, true, 'Final content should match one of the written mappings')
})
})
Loading