Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/137220.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 137220
summary: Skip dataframes when disabled
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -1527,7 +1527,7 @@ public List<RestHandler> getRestHandlers(
restHandlers.add(new RestDeleteTrainedModelAliasAction());
restHandlers.add(new RestPutTrainedModelDefinitionPartAction());
restHandlers.add(new RestInferTrainedModelAction());
restHandlers.add(new RestCatTrainedModelsAction());
restHandlers.add(new RestCatTrainedModelsAction(machineLearningExtension.get().isDataFrameAnalyticsEnabled()));
if (machineLearningExtension.get().isDataFrameAnalyticsEnabled()) {
restHandlers.add(new RestGetDataFrameAnalyticsAction());
restHandlers.add(new RestGetDataFrameAnalyticsStatsAction());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@
@ServerlessScope(Scope.PUBLIC)
public class RestCatTrainedModelsAction extends AbstractCatAction {

private final boolean areDataFrameAnalyticsEnabled;

public RestCatTrainedModelsAction(boolean areDataFrameAnalyticsEnabled) {
this.areDataFrameAnalyticsEnabled = areDataFrameAnalyticsEnabled;
}

@Override
public List<Route> routes() {
return List.of(
Expand Down Expand Up @@ -122,14 +128,18 @@ private void getDerivedData(
listeners.acquire(response -> trainedModelsStats = response.getResources().results())
);

final var dataFrameAnalyticsRequest = new GetDataFrameAnalyticsAction.Request(requestIdPattern);
dataFrameAnalyticsRequest.setAllowNoResources(true);
dataFrameAnalyticsRequest.setPageParams(new PageParams(0, potentialAnalyticsIds.size()));
client.execute(
GetDataFrameAnalyticsAction.INSTANCE,
dataFrameAnalyticsRequest,
listeners.acquire(response -> dataFrameAnalytics = response.getResources().results())
);
if (areDataFrameAnalyticsEnabled) {
final var dataFrameAnalyticsRequest = new GetDataFrameAnalyticsAction.Request(requestIdPattern);
dataFrameAnalyticsRequest.setAllowNoResources(true);
dataFrameAnalyticsRequest.setPageParams(new PageParams(0, potentialAnalyticsIds.size()));
client.execute(
GetDataFrameAnalyticsAction.INSTANCE,
dataFrameAnalyticsRequest,
listeners.acquire(response -> dataFrameAnalytics = response.getResources().results())
);
} else {
dataFrameAnalytics = List.of();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,25 @@

package org.elasticsearch.xpack.ml.rest.cat;

import org.elasticsearch.common.Strings;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestRequest;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsActionResponseTests;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
import org.junit.Before;

import java.util.List;

import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;

public class RestCatTrainedModelsActionTests extends ESTestCase {

private RestCatTrainedModelsAction action;

@Before
public void setUpAction() {
action = new RestCatTrainedModelsAction();
}

public void testBuildTableAccumulatedStats() {
var action = new RestCatTrainedModelsAction(true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to add a unit test for the case where areDataFrameAnalyticsEnabled is false and we don't expect to see the dataframe information in the table?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but the existing functionality isn't tested so adding the tests for both true/false would take a few weeks of calendar time, which I assume is okay since this bug had been open for months. Alternatively, the linked integration test covers the false case whereas the existing integration tests cover the true case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you not just basically copy this test case where areDataFrameAnalyticsEnabled is true, switch it to false, then assert that we see __none__ as the data_frame.id value?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That boolean is used within getDerivedData which is called before the method that this code tests, which is buildTable. The test only calls buildTable, so we'd only be verifying that an absent list sets __none__ which is a test we should have but isn't really relevant to this change (happy to add it, though)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aah, I see, thanks for clarifying. Yeah, no need to add that test in this PR, since it's basically unrelated.


// GetTrainedModelsStatsActionResponseTests
var deployment1 = new GetTrainedModelsStatsAction.Response.TrainedModelStats(
"id1",
Expand All @@ -48,10 +45,13 @@ public void testBuildTableAccumulatedStats() {
null
);

var configs = List.of(TrainedModelConfigTests.createTestInstance("id1").build());
var dataframeConfig = DataFrameAnalyticsConfigTests.createRandom("dataframe1");
var configs = List.of(
TrainedModelConfigTests.createTestInstance(deployment1.getModelId()).setTags(List.of(dataframeConfig.getId())).build()
);

var table = action.buildTable(new FakeRestRequest(), List.of(deployment1, deployment2), configs, List.of());
assertThat(table.getRows().get(0).get(0).value, is("id1"));
var table = action.buildTable(new FakeRestRequest(), List.of(deployment1, deployment2), configs, List.of(dataframeConfig));
assertThat(table.getRows().get(0).get(0).value, is(deployment1.getModelId()));
// pipeline count
assertThat(table.getRows().get(0).get(9).value, is(4));
// ingest count
Expand Down Expand Up @@ -82,5 +82,12 @@ public void testBuildTableAccumulatedStats() {
.ingestFailedCount()
)
);
assertThat(table.getRows().get(0).get(14).value, is(dataframeConfig.getId()));
assertThat(table.getRows().get(0).get(15).value, is(dataframeConfig.getCreateTime()));
assertThat(table.getRows().get(0).get(16).value, is(Strings.arrayToCommaDelimitedString(dataframeConfig.getSource().getIndex())));
assertThat(
table.getRows().get(0).get(17).value,
dataframeConfig.getAnalysis() == null ? nullValue() : is(dataframeConfig.getAnalysis().getWriteableName())
);
}
}