1010import static org .mockito .ArgumentMatchers .eq ;
1111import static org .mockito .ArgumentMatchers .isA ;
1212import static org .mockito .Mockito .doAnswer ;
13+ import static org .mockito .Mockito .doReturn ;
1314import static org .mockito .Mockito .doThrow ;
1415import static org .mockito .Mockito .mock ;
16+ import static org .mockito .Mockito .spy ;
1517import static org .mockito .Mockito .verify ;
1618import static org .mockito .Mockito .when ;
1719import static org .opensearch .ml .common .settings .MLCommonsSettings .ML_COMMONS_ALLOW_MODEL_URL ;
@@ -182,25 +184,27 @@ public void setup() throws IOException {
182184 );
183185 when (clusterService .getClusterSettings ()).thenReturn (clusterSettings );
184186 when (clusterService .getSettings ()).thenReturn (settings );
185- transportRegisterModelAction = new TransportRegisterModelAction (
186- transportService ,
187- actionFilters ,
188- modelHelper ,
189- mlIndicesHandler ,
190- mlModelManager ,
191- mlTaskManager ,
192- clusterService ,
193- settings ,
194- threadPool ,
195- client ,
196- sdkClient ,
197- nodeFilter ,
198- mlTaskDispatcher ,
199- mlStats ,
200- modelAccessControlHelper ,
201- connectorAccessControlHelper ,
202- mlModelGroupManager ,
203- mlFeatureEnabledSetting
187+ transportRegisterModelAction = spy (
188+ new TransportRegisterModelAction (
189+ transportService ,
190+ actionFilters ,
191+ modelHelper ,
192+ mlIndicesHandler ,
193+ mlModelManager ,
194+ mlTaskManager ,
195+ clusterService ,
196+ settings ,
197+ threadPool ,
198+ client ,
199+ sdkClient ,
200+ nodeFilter ,
201+ mlTaskDispatcher ,
202+ mlStats ,
203+ modelAccessControlHelper ,
204+ connectorAccessControlHelper ,
205+ mlModelGroupManager ,
206+ mlFeatureEnabledSetting
207+ )
204208 );
205209 assertNotNull (transportRegisterModelAction );
206210
@@ -594,6 +598,79 @@ public void test_execute_registerRemoteModel_withInternalConnector_predictEndpoi
594598 );
595599 }
596600
601+ @ Test
602+ public void test_execute_registerRemoteModel_withUntrustedEndpoint () {
603+ // Create request and input mocks
604+ MLRegisterModelRequest request = mock (MLRegisterModelRequest .class );
605+ MLRegisterModelInput input = MLRegisterModelInput
606+ .builder ()
607+ .functionName (FunctionName .REMOTE )
608+ .isHidden (false )
609+ .modelName ("test-model" )
610+ .build ();
611+
612+ // Create a proper Connector instance instead of mocking it
613+ Connector connector = mock (Connector .class );
614+ when (connector .getActionEndpoint (anyString (), any (Map .class ))).thenReturn ("https://untrusted-endpoint.com" );
615+ // Set the connector on the input
616+ input .setConnector (connector );
617+
618+ when (request .getRegisterModelInput ()).thenReturn (input );
619+
620+ // Mock super admin check
621+ doReturn (false ).when (transportRegisterModelAction ).isSuperAdminUserWrapper (any (), any ());
622+
623+ // Mock model group validation
624+ SearchResponse searchResponse = mock (SearchResponse .class );
625+ SearchHits searchHits = new SearchHits (new SearchHit [0 ], new TotalHits (0L , TotalHits .Relation .EQUAL_TO ), 0.0f );
626+ when (searchResponse .getHits ()).thenReturn (searchHits );
627+
628+ doAnswer (invocation -> {
629+ ActionListener <SearchResponse > listener = invocation .getArgument (2 );
630+ listener .onResponse (searchResponse );
631+ return null ;
632+ }).when (mlModelGroupManager ).validateUniqueModelGroupName (any (), any (), any ());
633+
634+ // Mock connector validation to throw exception
635+ doThrow (new IllegalArgumentException ("The connector endpoint provided is not trusted" )).when (connector ).validateConnectorURL (any ());
636+
637+ // Execute
638+ transportRegisterModelAction .doExecute (task , request , actionListener );
639+
640+ // Verify
641+ ArgumentCaptor <Exception > argumentCaptor = ArgumentCaptor .forClass (Exception .class );
642+ verify (actionListener ).onFailure (argumentCaptor .capture ());
643+ assertTrue (argumentCaptor .getValue ().getMessage ().contains ("not trusted" ));
644+ }
645+
646+ @ Test
647+ public void test_execute_registerRemoteModel_withUntrustedEndpoint_hidden_model () {
648+ MLRegisterModelRequest request = mock (MLRegisterModelRequest .class );
649+ MLRegisterModelInput input = mock (MLRegisterModelInput .class );
650+ when (request .getRegisterModelInput ()).thenReturn (input );
651+ when (input .getModelName ()).thenReturn ("Test Model" );
652+ when (input .getVersion ()).thenReturn ("1" );
653+ when (input .getModelGroupId ()).thenReturn ("modelGroupID" );
654+ when (input .getFunctionName ()).thenReturn (FunctionName .REMOTE );
655+ when (input .getIsHidden ()).thenReturn (true );
656+
657+ // Create a proper Connector instance instead of mocking it
658+ Connector connector = mock (Connector .class );
659+ when (connector .getActionEndpoint (anyString (), any (Map .class ))).thenReturn ("https://untrusted-endpoint.com" );
660+ // Set the connector on the input
661+ when (input .getConnector ()).thenReturn (connector );
662+ MLCreateConnectorResponse mlCreateConnectorResponse = mock (MLCreateConnectorResponse .class );
663+ doAnswer (invocation -> {
664+ ActionListener <MLCreateConnectorResponse > listener = invocation .getArgument (2 );
665+ listener .onResponse (mlCreateConnectorResponse );
666+ return null ;
667+ }).when (client ).execute (eq (MLCreateConnectorAction .INSTANCE ), any (), isA (ActionListener .class ));
668+ MLRegisterModelResponse response = mock (MLRegisterModelResponse .class );
669+ transportRegisterModelAction .doExecute (task , request , actionListener );
670+ ArgumentCaptor <MLRegisterModelResponse > argumentCaptor = ArgumentCaptor .forClass (MLRegisterModelResponse .class );
671+ verify (mlModelManager ).registerMLRemoteModel (eq (sdkClient ), eq (input ), isA (MLTask .class ), eq (actionListener ));
672+ }
673+
597674 @ Test
598675 public void test_ModelNameAlreadyExists () throws IOException {
599676 when (node1 .getId ()).thenReturn ("NodeId1" );
@@ -647,6 +724,7 @@ public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException
647724 );
648725 }
649726
727+ @ Test
650728 public void test_FailureWhenSearchingModelGroupName () throws IOException {
651729 doAnswer (invocation -> {
652730 ActionListener <SearchResponse > listener = invocation .getArgument (2 );
@@ -661,6 +739,7 @@ public void test_FailureWhenSearchingModelGroupName() throws IOException {
661739 assertEquals ("Runtime exception" , argumentCaptor .getValue ().getMessage ());
662740 }
663741
742+ @ Test
664743 public void test_NoAccessWhenModelNameAlreadyExists () throws IOException {
665744
666745 SearchResponse searchResponse = createModelGroupSearchResponse (1 );
0 commit comments