@@ -3,14 +3,14 @@ import {
33 ThrottlingException ,
44 ThrottlingExceptionReason ,
55} from '@amzn/codewhisperer-streaming'
6- import { CredentialsProvider , Position } from '@aws/language-server-runtimes/server-interface'
6+ import { CredentialsProvider , Position , InitializeParams } from '@aws/language-server-runtimes/server-interface'
77import * as assert from 'assert'
88import { AWSError } from 'aws-sdk'
99import { expect } from 'chai'
1010import * as sinon from 'sinon'
1111import * as os from 'os'
1212import * as path from 'path'
13- import { BUILDER_ID_START_URL } from './constants'
13+ import { BUILDER_ID_START_URL , SAGEMAKER_UNIFIED_STUDIO_SERVICE } from './constants'
1414import {
1515 getBearerTokenFromProvider ,
1616 getEndPositionForAcceptedSuggestion ,
@@ -24,6 +24,7 @@ import {
2424 getFileExtensionName ,
2525 listFilesWithGitignore ,
2626 getOriginFromClientInfo ,
27+ getClientName ,
2728 sanitizeInput ,
2829 sanitizeRequestInput ,
2930} from './utils'
@@ -73,12 +74,86 @@ describe('getBearerTokenFromProvider', () => {
7374 } )
7475} )
7576
77+ describe ( 'getClientName' , ( ) => {
78+ let originalEnv : string | undefined
79+
80+ beforeEach ( ( ) => {
81+ originalEnv = process . env . SERVICE_NAME
82+ } )
83+
84+ afterEach ( ( ) => {
85+ if ( originalEnv !== undefined ) {
86+ process . env . SERVICE_NAME = originalEnv
87+ } else {
88+ delete process . env . SERVICE_NAME
89+ }
90+ } )
91+
92+ it ( 'returns client name from initializationOptions path when SERVICE_NAME is SageMakerUnifiedStudio' , ( ) => {
93+ process . env . SERVICE_NAME = SAGEMAKER_UNIFIED_STUDIO_SERVICE
94+ const lspParams = {
95+ initializationOptions : {
96+ aws : {
97+ clientInfo : {
98+ name : 'AmazonQ-For-SMUS-CE-1.0.0' ,
99+ } ,
100+ } ,
101+ } ,
102+ clientInfo : {
103+ name : 'VSCode-Extension' ,
104+ } ,
105+ } as InitializeParams
106+
107+ const result = getClientName ( lspParams )
108+ assert . strictEqual ( result , 'AmazonQ-For-SMUS-CE-1.0.0' )
109+ } )
110+
111+ it ( 'returns client name from clientInfo path when SERVICE_NAME is not SageMakerUnifiedStudio' , ( ) => {
112+ process . env . SERVICE_NAME = 'SomeOtherService'
113+ const lspParams = {
114+ initializationOptions : {
115+ aws : {
116+ clientInfo : {
117+ name : 'AmazonQ-For-SMUS-CE-1.0.0' ,
118+ } ,
119+ } ,
120+ } ,
121+ clientInfo : {
122+ name : 'VSCode-Extension' ,
123+ } ,
124+ } as InitializeParams
125+
126+ const result = getClientName ( lspParams )
127+ assert . strictEqual ( result , 'VSCode-Extension' )
128+ } )
129+
130+ it ( 'returns undefined when lspParams is undefined' , ( ) => {
131+ const result = getClientName ( undefined )
132+ assert . strictEqual ( result , undefined )
133+ } )
134+ } )
135+
76136describe ( 'getOriginFromClientInfo' , ( ) => {
77- it ( 'returns MD_IDE for SMUS client name' , ( ) => {
137+ it ( 'returns MD_IDE for SMUS-IDE client name' , ( ) => {
78138 const result = getOriginFromClientInfo ( 'AmazonQ-For-SMUS-IDE-1.0.0' )
79139 assert . strictEqual ( result , 'MD_IDE' )
80140 } )
81141
142+ it ( 'returns MD_IDE for SMUS-CE client name' , ( ) => {
143+ const result = getOriginFromClientInfo ( 'AmazonQ-For-SMUS-CE-1.0.0' )
144+ assert . strictEqual ( result , 'MD_IDE' )
145+ } )
146+
147+ it ( 'returns MD_IDE for client names starting with SMUS-IDE prefix' , ( ) => {
148+ const result = getOriginFromClientInfo ( 'AmazonQ-For-SMUS-IDE' )
149+ assert . strictEqual ( result , 'MD_IDE' )
150+ } )
151+
152+ it ( 'returns MD_IDE for client names starting with SMUS-CE prefix' , ( ) => {
153+ const result = getOriginFromClientInfo ( 'AmazonQ-For-SMUS-CE' )
154+ assert . strictEqual ( result , 'MD_IDE' )
155+ } )
156+
82157 it ( 'returns IDE for non-SMUS client name' , ( ) => {
83158 const result = getOriginFromClientInfo ( 'VSCode-Extension' )
84159 assert . strictEqual ( result , 'IDE' )
@@ -93,6 +168,11 @@ describe('getOriginFromClientInfo', () => {
93168 const result = getOriginFromClientInfo ( '' )
94169 assert . strictEqual ( result , 'IDE' )
95170 } )
171+
172+ it ( 'returns IDE for client names that do not match SMUS patterns' , ( ) => {
173+ const result = getOriginFromClientInfo ( 'AmazonQ-For-Other-IDE' )
174+ assert . strictEqual ( result , 'IDE' )
175+ } )
96176} )
97177
98178describe ( 'getSsoConnectionType' , ( ) => {
0 commit comments