Skip to content

Commit a8725b4

Browse files
authored
feat: add support for SMUS Q CodeEditor client to send MD IDE origin (aws#2032)
* feat: add support for SMUS Q CodeEditor client to send MD IDE origin * refactor: move smus service name to constants file and typecase argument getClientName util
1 parent 36e161e commit a8725b4

File tree

5 files changed

+97
-7
lines changed

5 files changed

+97
-7
lines changed

server/aws-lsp-codewhisperer/src/language-server/agenticChat/agenticChatController.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ import {
124124
isUsageLimitError,
125125
isNullish,
126126
getOriginFromClientInfo,
127+
getClientName,
127128
sanitizeInput,
128129
sanitizeRequestInput,
129130
} from '../../shared/utils'
@@ -366,7 +367,7 @@ export class AgenticChatController implements ChatHandlers {
366367
this.#features.lsp
367368
)
368369
this.#mcpEventHandler = new McpEventHandler(features, telemetryService)
369-
this.#origin = getOriginFromClientInfo(this.#features.lsp.getClientInitializeParams()?.clientInfo?.name)
370+
this.#origin = getOriginFromClientInfo(getClientName(this.#features.lsp.getClientInitializeParams()))
370371
this.#activeUserTracker = ActiveUserTracker.getInstance(this.#features)
371372
}
372373

server/aws-lsp-codewhisperer/src/language-server/chat/chatSessionService.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import { AmazonQBaseServiceManager } from '../../shared/amazonQServiceManager/Ba
2727
import { loggingUtils } from '@aws/lsp-core'
2828
import { Logging } from '@aws/language-server-runtimes/server-interface'
2929
import { Features } from '../types'
30-
import { getOriginFromClientInfo, getRequestID, isUsageLimitError } from '../../shared/utils'
30+
import { getOriginFromClientInfo, getClientName, getRequestID, isUsageLimitError } from '../../shared/utils'
3131
import { enabledModelSelection } from '../../shared/utils'
3232

3333
export type ChatSessionServiceConfig = CodeWhispererStreamingClientConfig
@@ -138,7 +138,7 @@ export class ChatSessionService {
138138
this.#serviceManager = serviceManager
139139
this.#lsp = lsp
140140
this.#logging = logging
141-
this.#origin = getOriginFromClientInfo(this.#lsp?.getClientInitializeParams()?.clientInfo?.name)
141+
this.#origin = getOriginFromClientInfo(getClientName(this.#lsp?.getClientInitializeParams()))
142142
}
143143

144144
public async sendMessage(request: SendMessageCommandInput): Promise<SendMessageCommandOutput> {

server/aws-lsp-codewhisperer/src/shared/constants.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ export const AWS_Q_ENDPOINT_URL_ENV_VAR = 'AWS_Q_ENDPOINT_URL'
1818
export const Q_CONFIGURATION_SECTION = 'aws.q'
1919
export const CODE_WHISPERER_CONFIGURATION_SECTION = 'aws.codeWhisperer'
2020

21+
export const SAGEMAKER_UNIFIED_STUDIO_SERVICE = 'SageMakerUnifiedStudio'
22+
2123
/**
2224
* Names of directories relevant to the crash reporting functionality.
2325
*

server/aws-lsp-codewhisperer/src/shared/utils.test.ts

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@ import {
33
ThrottlingException,
44
ThrottlingExceptionReason,
55
} from '@amzn/codewhisperer-streaming'
6-
import { CredentialsProvider, Position } from '@aws/language-server-runtimes/server-interface'
6+
import { CredentialsProvider, Position, InitializeParams } from '@aws/language-server-runtimes/server-interface'
77
import * as assert from 'assert'
88
import { AWSError } from 'aws-sdk'
99
import { expect } from 'chai'
1010
import * as sinon from 'sinon'
1111
import * as os from 'os'
1212
import * as path from 'path'
13-
import { BUILDER_ID_START_URL } from './constants'
13+
import { BUILDER_ID_START_URL, SAGEMAKER_UNIFIED_STUDIO_SERVICE } from './constants'
1414
import {
1515
getBearerTokenFromProvider,
1616
getEndPositionForAcceptedSuggestion,
@@ -24,6 +24,7 @@ import {
2424
getFileExtensionName,
2525
listFilesWithGitignore,
2626
getOriginFromClientInfo,
27+
getClientName,
2728
sanitizeInput,
2829
sanitizeRequestInput,
2930
} from './utils'
@@ -73,12 +74,86 @@ describe('getBearerTokenFromProvider', () => {
7374
})
7475
})
7576

77+
describe('getClientName', () => {
78+
let originalEnv: string | undefined
79+
80+
beforeEach(() => {
81+
originalEnv = process.env.SERVICE_NAME
82+
})
83+
84+
afterEach(() => {
85+
if (originalEnv !== undefined) {
86+
process.env.SERVICE_NAME = originalEnv
87+
} else {
88+
delete process.env.SERVICE_NAME
89+
}
90+
})
91+
92+
it('returns client name from initializationOptions path when SERVICE_NAME is SageMakerUnifiedStudio', () => {
93+
process.env.SERVICE_NAME = SAGEMAKER_UNIFIED_STUDIO_SERVICE
94+
const lspParams = {
95+
initializationOptions: {
96+
aws: {
97+
clientInfo: {
98+
name: 'AmazonQ-For-SMUS-CE-1.0.0',
99+
},
100+
},
101+
},
102+
clientInfo: {
103+
name: 'VSCode-Extension',
104+
},
105+
} as InitializeParams
106+
107+
const result = getClientName(lspParams)
108+
assert.strictEqual(result, 'AmazonQ-For-SMUS-CE-1.0.0')
109+
})
110+
111+
it('returns client name from clientInfo path when SERVICE_NAME is not SageMakerUnifiedStudio', () => {
112+
process.env.SERVICE_NAME = 'SomeOtherService'
113+
const lspParams = {
114+
initializationOptions: {
115+
aws: {
116+
clientInfo: {
117+
name: 'AmazonQ-For-SMUS-CE-1.0.0',
118+
},
119+
},
120+
},
121+
clientInfo: {
122+
name: 'VSCode-Extension',
123+
},
124+
} as InitializeParams
125+
126+
const result = getClientName(lspParams)
127+
assert.strictEqual(result, 'VSCode-Extension')
128+
})
129+
130+
it('returns undefined when lspParams is undefined', () => {
131+
const result = getClientName(undefined)
132+
assert.strictEqual(result, undefined)
133+
})
134+
})
135+
76136
describe('getOriginFromClientInfo', () => {
77-
it('returns MD_IDE for SMUS client name', () => {
137+
it('returns MD_IDE for SMUS-IDE client name', () => {
78138
const result = getOriginFromClientInfo('AmazonQ-For-SMUS-IDE-1.0.0')
79139
assert.strictEqual(result, 'MD_IDE')
80140
})
81141

142+
it('returns MD_IDE for SMUS-CE client name', () => {
143+
const result = getOriginFromClientInfo('AmazonQ-For-SMUS-CE-1.0.0')
144+
assert.strictEqual(result, 'MD_IDE')
145+
})
146+
147+
it('returns MD_IDE for client names starting with SMUS-IDE prefix', () => {
148+
const result = getOriginFromClientInfo('AmazonQ-For-SMUS-IDE')
149+
assert.strictEqual(result, 'MD_IDE')
150+
})
151+
152+
it('returns MD_IDE for client names starting with SMUS-CE prefix', () => {
153+
const result = getOriginFromClientInfo('AmazonQ-For-SMUS-CE')
154+
assert.strictEqual(result, 'MD_IDE')
155+
})
156+
82157
it('returns IDE for non-SMUS client name', () => {
83158
const result = getOriginFromClientInfo('VSCode-Extension')
84159
assert.strictEqual(result, 'IDE')
@@ -93,6 +168,11 @@ describe('getOriginFromClientInfo', () => {
93168
const result = getOriginFromClientInfo('')
94169
assert.strictEqual(result, 'IDE')
95170
})
171+
172+
it('returns IDE for client names that do not match SMUS patterns', () => {
173+
const result = getOriginFromClientInfo('AmazonQ-For-Other-IDE')
174+
assert.strictEqual(result, 'IDE')
175+
})
96176
})
97177

98178
describe('getSsoConnectionType', () => {

server/aws-lsp-codewhisperer/src/shared/utils.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import {
1414
crashMonitoringDirName,
1515
driveLetterRegex,
1616
MISSING_BEARER_TOKEN_ERROR,
17+
SAGEMAKER_UNIFIED_STUDIO_SERVICE,
1718
} from './constants'
1819
import {
1920
CodeWhispererStreamingServiceException,
@@ -373,8 +374,14 @@ export function getBearerTokenFromProvider(credentialsProvider: CredentialsProvi
373374
return credentials.token
374375
}
375376

377+
export function getClientName(lspParams: InitializeParams | undefined): string | undefined {
378+
return process.env.SERVICE_NAME === SAGEMAKER_UNIFIED_STUDIO_SERVICE
379+
? lspParams?.initializationOptions?.aws?.clientInfo?.name
380+
: lspParams?.clientInfo?.name
381+
}
382+
376383
export function getOriginFromClientInfo(clientName: string | undefined): Origin {
377-
if (clientName?.startsWith('AmazonQ-For-SMUS-IDE')) {
384+
if (clientName?.startsWith('AmazonQ-For-SMUS-IDE') || clientName?.startsWith('AmazonQ-For-SMUS-CE')) {
378385
return 'MD_IDE'
379386
}
380387
return 'IDE'

0 commit comments

Comments
 (0)