66package org .opensearch .ml .engine .algorithms .remote ;
77
88import static org .junit .Assert .assertEquals ;
9+ import static org .junit .Assert .assertNull ;
910import static org .mockito .ArgumentMatchers .any ;
1011import static org .mockito .Mockito .doThrow ;
1112import static org .mockito .Mockito .mock ;
12- import static org .mockito .Mockito .spy ;
1313import static org .mockito .Mockito .verify ;
1414import static org .mockito .Mockito .when ;
1515
3939import org .opensearch .ml .engine .MLEngineClassLoader ;
4040import org .opensearch .ml .engine .MLStaticMockBase ;
4141import org .opensearch .ml .engine .encryptor .Encryptor ;
42- import org .opensearch .ml .engine .encryptor .EncryptorImpl ;
4342
4443import com .google .common .collect .ImmutableMap ;
4544
@@ -64,7 +63,9 @@ public class RemoteModelTest extends MLStaticMockBase {
6463 public void setUp () {
6564 MockitoAnnotations .openMocks (this );
6665 remoteModel = new RemoteModel ();
67- encryptor = spy (new EncryptorImpl (null , "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=" ));
66+
67+ encryptor = mock (Encryptor .class );
68+ when (encryptor .decrypt (any (), any ())).thenReturn ("test_api_key" );
6869 }
6970
7071 @ Test
@@ -189,7 +190,7 @@ public void initModel_NullHeader() {
189190 when (mlModel .getConnector ()).thenReturn (connector );
190191 remoteModel .initModel (mlModel , ImmutableMap .of (), encryptor );
191192 Map <String , String > decryptedHeaders = connector .getDecryptedHeaders ();
192- Assert . assertNull (decryptedHeaders );
193+ assertNull (decryptedHeaders );
193194 }
194195
195196 @ Test
@@ -200,12 +201,59 @@ public void initModel_WithHeader() {
200201 Map <String , String > decryptedHeaders = connector .getDecryptedHeaders ();
201202 RemoteConnectorExecutor executor = remoteModel .getConnectorExecutor ();
202203 Assert .assertNotNull (executor );
203- Assert . assertNull (decryptedHeaders );
204+ assertNull (decryptedHeaders );
204205 Assert .assertNotNull (executor .getConnector ().getDecryptedHeaders ());
205206 assertEquals (1 , executor .getConnector ().getDecryptedHeaders ().size ());
206207 assertEquals ("Bearer test_api_key" , executor .getConnector ().getDecryptedHeaders ().get ("Authorization" ));
207208 remoteModel .close ();
208- Assert .assertNull (remoteModel .getConnectorExecutor ());
209+ assertNull (remoteModel .getConnectorExecutor ());
210+ }
211+
212+ @ Test
213+ public void initModel_setsTenantIdOnClonedConnector_whenMissing () {
214+ Connector connector = createConnector (ImmutableMap .of ("Authorization" , "Bearer ${credential.key}" ));
215+ when (mlModel .getConnector ()).thenReturn (connector );
216+ when (mlModel .getTenantId ()).thenReturn ("tenantId" );
217+ remoteModel .initModel (mlModel , ImmutableMap .of (), encryptor );
218+ RemoteConnectorExecutor executor = remoteModel .getConnectorExecutor ();
219+ remoteModel .close ();
220+ assertNull (connector .getTenantId ());
221+ assertEquals ("tenantId" , executor .getConnector ().getTenantId ());
222+ }
223+
224+ @ Test
225+ public void initModel_bothTenantIdsNull () {
226+ Connector connector = createConnector (ImmutableMap .of ("Authorization" , "Bearer ${credential.key}" ));
227+ when (mlModel .getConnector ()).thenReturn (connector );
228+ when (mlModel .getTenantId ()).thenReturn (null );
229+ remoteModel .initModel (mlModel , ImmutableMap .of (), encryptor );
230+ RemoteConnectorExecutor executor = remoteModel .getConnectorExecutor ();
231+ assertNull (connector .getTenantId ());
232+ assertNull (executor .getConnector ().getTenantId ());
233+ }
234+
235+ @ Test
236+ public void initModel_connectorHasTenantId () {
237+ Connector connector = createConnector (ImmutableMap .of ("Authorization" , "Bearer ${credential.key}" ));
238+ connector .setTenantId ("connectorTenantId" );
239+ when (mlModel .getConnector ()).thenReturn (connector );
240+ when (mlModel .getTenantId ()).thenReturn (null );
241+ remoteModel .initModel (mlModel , ImmutableMap .of (), encryptor );
242+ RemoteConnectorExecutor executor = remoteModel .getConnectorExecutor ();
243+ assertEquals ("connectorTenantId" , connector .getTenantId ());
244+ assertEquals ("connectorTenantId" , executor .getConnector ().getTenantId ());
245+ }
246+
247+ @ Test
248+ public void initModel_bothHaveTenantIds () {
249+ Connector connector = createConnector (ImmutableMap .of ("Authorization" , "Bearer ${credential.key}" ));
250+ connector .setTenantId ("connectorTenantId" );
251+ when (mlModel .getConnector ()).thenReturn (connector );
252+ when (mlModel .getTenantId ()).thenReturn ("modelTenantId" );
253+ remoteModel .initModel (mlModel , ImmutableMap .of (), encryptor );
254+ RemoteConnectorExecutor executor = remoteModel .getConnectorExecutor ();
255+ assertEquals ("connectorTenantId" , connector .getTenantId ());
256+ assertEquals ("connectorTenantId" , executor .getConnector ().getTenantId ());
209257 }
210258
211259 private Connector createConnector (Map <String , String > headers ) {
@@ -222,7 +270,7 @@ private Connector createConnector(Map<String, String> headers) {
222270 .name ("test connector" )
223271 .protocol (ConnectorProtocols .HTTP )
224272 .version ("1" )
225- .credential (ImmutableMap .of ("key" , encryptor . encrypt ( "test_api_key" , null ) ))
273+ .credential (ImmutableMap .of ("key" , "dummy-encrypted-value" ))
226274 .actions (Arrays .asList (predictAction ))
227275 .build ();
228276 return connector ;
0 commit comments