Skip to content

Commit 19b2f27

Browse files
committed
fix race-condition
1 parent b39ab4e commit 19b2f27

File tree

3 files changed

+132
-13
lines changed

3 files changed

+132
-13
lines changed

packages/core/resources/sagemaker_connect

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@ _get_ssm_session_info() {
99

1010
local url_to_get_session_info="http://localhost:${local_endpoint_port}/get_session?connection_identifier=${aws_resource_arn}&credentials_type=${credentials_type}"
1111

12+
# Generate unique temporary file name to avoid conflicts
13+
local temp_file="/tmp/ssm_session_response_$$_$(date +%s%N).json"
14+
1215
# Use curl with --write-out to capture HTTP status
13-
response=$(curl -s -w "%{http_code}" -o /tmp/ssm_session_response.json "$url_to_get_session_info")
16+
response=$(curl -s -w "%{http_code}" -o "$temp_file" "$url_to_get_session_info")
1417
http_status="${response: -3}"
15-
session_json=$(cat /tmp/ssm_session_response.json)
18+
session_json=$(cat "$temp_file")
1619

1720
# Clean up temporary file
18-
rm -f /tmp/ssm_session_response.json
21+
rm -f "$temp_file"
1922

2023
if [[ "$http_status" -ne 200 ]]; then
2124
echo "Error: Failed to get SSM session info. HTTP status: $http_status"
@@ -40,16 +43,21 @@ _get_ssm_session_info_async() {
4043
local url_base="http://localhost:${local_endpoint_port}/get_session_async"
4144
local url_to_get_session_info="${url_base}?connection_identifier=${aws_resource_arn}&credentials_type=${credentials_type}&request_id=${request_id}"
4245

46+
# Generate unique temporary file name to avoid conflicts
47+
local temp_file="/tmp/ssm_session_response_$$_$(date +%s%N).json"
48+
4349
local max_retries=60
4450
local retry_interval=5
4551
local attempt=1
4652

4753
while (( attempt <= max_retries )); do
48-
response=$(curl -s -w "%{http_code}" -o /tmp/ssm_session_response.json "$url_to_get_session_info")
54+
response=$(curl -s -w "%{http_code}" -o "$temp_file" "$url_to_get_session_info")
4955
http_status="${response: -3}"
50-
session_json=$(cat /tmp/ssm_session_response.json)
56+
session_json=$(cat "$temp_file")
5157

5258
if [[ "$http_status" -eq 200 ]]; then
59+
# Clean up temporary file on success
60+
rm -f "$temp_file"
5361
export SSM_SESSION_JSON="$session_json"
5462
return 0
5563
elif [[ "$http_status" -eq 202 || "$http_status" -eq 204 ]]; then
@@ -59,10 +67,14 @@ _get_ssm_session_info_async() {
5967
else
6068
echo "Error: Failed to get SSM session info. HTTP status: $http_status"
6169
echo "Response: $session_json"
70+
# Clean up temporary file on error
71+
rm -f "$temp_file"
6272
exit 1
6373
fi
6474
done
6575

76+
# Clean up temporary file on timeout
77+
rm -f "$temp_file"
6678
echo "Error: Timed out after $max_retries attempts waiting for session to be ready."
6779
exit 1
6880
}

packages/core/src/awsService/sagemaker/detached-server/utils.ts

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ export { open }
1818
export const mappingFilePath = join(os.homedir(), '.aws', '.sagemaker-space-profiles')
1919
const tempFilePath = `${mappingFilePath}.tmp`
2020

21+
// Simple file lock to prevent concurrent writes
22+
let isWriting = false
23+
const writeQueue: Array<() => Promise<void>> = []
24+
2125
/**
2226
* Reads the local endpoint info file (default or via env) and returns pid & port.
2327
* @throws Error if the file is missing, invalid JSON, or missing fields
@@ -100,14 +104,45 @@ export async function readMapping() {
100104
}
101105

102106
/**
103-
* Writes the mapping to a temp file and atomically renames it to the target path.
107+
* Processes the write queue to ensure only one write operation happens at a time.
104108
*/
105-
export async function writeMapping(mapping: SpaceMappings) {
109+
async function processWriteQueue() {
110+
if (isWriting || writeQueue.length === 0) {
111+
return
112+
}
113+
114+
isWriting = true
106115
try {
107-
const json = JSON.stringify(mapping, undefined, 2)
108-
await fs.writeFile(tempFilePath, json)
109-
await fs.rename(tempFilePath, mappingFilePath)
110-
} catch (err) {
111-
throw new Error(`Failed to write mapping file: ${err instanceof Error ? err.message : String(err)}`)
116+
while (writeQueue.length > 0) {
117+
const writeOperation = writeQueue.shift()!
118+
await writeOperation()
119+
}
120+
} finally {
121+
isWriting = false
112122
}
113123
}
124+
125+
/**
126+
* Writes the mapping to a temp file and atomically renames it to the target path.
127+
* Uses a queue to prevent race conditions when multiple requests try to write simultaneously.
128+
*/
129+
export async function writeMapping(mapping: SpaceMappings) {
130+
return new Promise<void>((resolve, reject) => {
131+
const writeOperation = async () => {
132+
try {
133+
// Generate unique temp file name to avoid conflicts
134+
const uniqueTempPath = `${tempFilePath}.${process.pid}.${Date.now()}`
135+
136+
const json = JSON.stringify(mapping, undefined, 2)
137+
await fs.writeFile(uniqueTempPath, json)
138+
await fs.rename(uniqueTempPath, mappingFilePath)
139+
resolve()
140+
} catch (err) {
141+
reject(new Error(`Failed to write mapping file: ${err instanceof Error ? err.message : String(err)}`))
142+
}
143+
}
144+
145+
writeQueue.push(writeOperation)
146+
processWriteQueue().catch(reject)
147+
})
148+
}

packages/core/src/test/awsService/sagemaker/detached-server/utils.test.ts

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
*/
55

66
import * as assert from 'assert'
7-
import { parseArn } from '../../../../awsService/sagemaker/detached-server/utils'
7+
import { parseArn, writeMapping, readMapping } from '../../../../awsService/sagemaker/detached-server/utils'
8+
import { promises as fs } from 'fs'
9+
import * as path from 'path'
10+
import * as os from 'os'
11+
import { SpaceMappings } from '../../../../awsService/sagemaker/types'
812

913
describe('parseArn', () => {
1014
it('parses a standard SageMaker ARN with forward slash', () => {
@@ -37,3 +41,71 @@ describe('parseArn', () => {
3741
assert.throws(() => parseArn(invalidArn), /Invalid SageMaker ARN format/)
3842
})
3943
})
44+
45+
describe('writeMapping', () => {
46+
let testDir: string
47+
48+
beforeEach(async () => {
49+
testDir = await fs.mkdtemp(path.join(os.tmpdir(), 'sagemaker-test-'))
50+
})
51+
52+
afterEach(async () => {
53+
await fs.rmdir(testDir, { recursive: true })
54+
})
55+
56+
it('handles concurrent writes without race conditions', async () => {
57+
const mapping1: SpaceMappings = {
58+
localCredential: {
59+
'space-1': { type: 'iam', profileName: 'profile1' },
60+
},
61+
}
62+
const mapping2: SpaceMappings = {
63+
localCredential: {
64+
'space-2': { type: 'iam', profileName: 'profile2' },
65+
},
66+
}
67+
const mapping3: SpaceMappings = {
68+
deepLink: {
69+
'space-3': {
70+
requests: {
71+
req1: {
72+
sessionId: 'session-456',
73+
url: 'wss://example3.com',
74+
token: 'token-456',
75+
},
76+
},
77+
refreshUrl: 'https://example3.com/refresh',
78+
},
79+
},
80+
}
81+
82+
const writePromises = [writeMapping(mapping1), writeMapping(mapping2), writeMapping(mapping3)]
83+
84+
await Promise.all(writePromises)
85+
86+
const finalContent = await readMapping()
87+
const possibleResults = [mapping1, mapping2, mapping3]
88+
const isValidResult = possibleResults.some(
89+
(expected) => JSON.stringify(finalContent) === JSON.stringify(expected)
90+
)
91+
assert.strictEqual(isValidResult, true, 'Final content should match one of the written mappings')
92+
})
93+
94+
it('queues multiple writes and processes them sequentially', async () => {
95+
const mappings = Array.from({ length: 5 }, (_, i) => ({
96+
localCredential: {
97+
[`space-${i}`]: { type: 'iam' as const, profileName: `profile-${i}` },
98+
},
99+
}))
100+
101+
const writePromises = mappings.map((mapping) => writeMapping(mapping))
102+
103+
await Promise.all(writePromises)
104+
105+
const finalContent = await readMapping()
106+
assert.strictEqual(typeof finalContent, 'object', 'Final content should be a valid object')
107+
108+
const isValidResult = mappings.some((mapping) => JSON.stringify(finalContent) === JSON.stringify(mapping))
109+
assert.strictEqual(isValidResult, true, 'Final content should match one of the written mappings')
110+
})
111+
})

0 commit comments

Comments
 (0)