77
88package org .elasticsearch .xpack .ml .rest .cat ;
99
10+ import org .elasticsearch .common .Strings ;
1011import org .elasticsearch .test .ESTestCase ;
1112import org .elasticsearch .test .rest .FakeRestRequest ;
1213import org .elasticsearch .xpack .core .ml .action .GetTrainedModelsStatsAction ;
1314import org .elasticsearch .xpack .core .ml .action .GetTrainedModelsStatsActionResponseTests ;
15+ import org .elasticsearch .xpack .core .ml .dataframe .DataFrameAnalyticsConfigTests ;
1416import org .elasticsearch .xpack .core .ml .inference .TrainedModelConfigTests ;
1517import org .elasticsearch .xpack .core .ml .inference .trainedmodel .TrainedModelSizeStats ;
16- import org .junit .Before ;
1718
1819import java .util .List ;
1920
2021import static org .hamcrest .Matchers .is ;
22+ import static org .hamcrest .Matchers .nullValue ;
2123
2224public class RestCatTrainedModelsActionTests extends ESTestCase {
2325
24- private RestCatTrainedModelsAction action ;
25-
26- @ Before
27- public void setUpAction () {
28- action = new RestCatTrainedModelsAction ();
29- }
30-
3126 public void testBuildTableAccumulatedStats () {
27+ var action = new RestCatTrainedModelsAction (true );
28+
3229 // GetTrainedModelsStatsActionResponseTests
3330 var deployment1 = new GetTrainedModelsStatsAction .Response .TrainedModelStats (
3431 "id1" ,
@@ -48,10 +45,13 @@ public void testBuildTableAccumulatedStats() {
4845 null
4946 );
5047
51- var configs = List .of (TrainedModelConfigTests .createTestInstance ("id1" ).build ());
48+ var dataframeConfig = DataFrameAnalyticsConfigTests .createRandom ("dataframe1" );
49+ var configs = List .of (
50+ TrainedModelConfigTests .createTestInstance (deployment1 .getModelId ()).setTags (List .of (dataframeConfig .getId ())).build ()
51+ );
5252
53- var table = action .buildTable (new FakeRestRequest (), List .of (deployment1 , deployment2 ), configs , List .of ());
54- assertThat (table .getRows ().get (0 ).get (0 ).value , is ("id1" ));
53+ var table = action .buildTable (new FakeRestRequest (), List .of (deployment1 , deployment2 ), configs , List .of (dataframeConfig ));
54+ assertThat (table .getRows ().get (0 ).get (0 ).value , is (deployment1 . getModelId () ));
5555 // pipeline count
5656 assertThat (table .getRows ().get (0 ).get (9 ).value , is (4 ));
5757 // ingest count
@@ -82,5 +82,12 @@ public void testBuildTableAccumulatedStats() {
8282 .ingestFailedCount ()
8383 )
8484 );
85+ assertThat (table .getRows ().get (0 ).get (14 ).value , is (dataframeConfig .getId ()));
86+ assertThat (table .getRows ().get (0 ).get (15 ).value , is (dataframeConfig .getCreateTime ()));
87+ assertThat (table .getRows ().get (0 ).get (16 ).value , is (Strings .arrayToCommaDelimitedString (dataframeConfig .getSource ().getIndex ())));
88+ assertThat (
89+ table .getRows ().get (0 ).get (17 ).value ,
90+ dataframeConfig .getAnalysis () == null ? nullValue () : is (dataframeConfig .getAnalysis ().getWriteableName ())
91+ );
8592 }
8693}
0 commit comments