Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions muted-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,6 @@ tests:
- class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT
method: test {p0=nodes.stats/11_indices_metrics/indices mappings exact count test for indices level}
issue: https://github.com/elastic/elasticsearch/issues/120950
- class: org.elasticsearch.xpack.ml.integration.PyTorchModelIT
issue: https://github.com/elastic/elasticsearch/issues/121165
- class: org.elasticsearch.test.rest.yaml.CcsCommonYamlTestSuiteIT
issue: https://github.com/elastic/elasticsearch/issues/121407
- class: org.elasticsearch.analysis.common.CommonAnalysisClientYamlTestSuiteIT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
Expand Down Expand Up @@ -79,9 +77,6 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene

private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentClusterService.class);

private static final TransportVersion RENAME_ALLOCATION_TO_ASSIGNMENT_TRANSPORT_VERSION = TransportVersions.V_8_3_0;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These version checks are redundant in 9.0 and 9.1. The 8.x backports will need to keep them, however.

public static final TransportVersion DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION = TransportVersions.V_8_4_0;

private final ClusterService clusterService;
private final ThreadPool threadPool;
private final NodeLoadDetector nodeLoadDetector;
Expand Down Expand Up @@ -169,14 +164,6 @@ public void clusterChanged(ClusterChangedEvent event) {
return;
}

if (eventStateMinTransportVersionIsBeforeDistributedModelAllocationTransportVersion(event)) {
// we should not try to rebalance assignments while there may be nodes running on a version
// prior to introducing distributed model allocation.
// But we should remove routing to removed or shutting down nodes.
removeRoutingToRemovedOrShuttingDownNodes(event);
return;
}

if (event.nodesAdded()) {
logMlNodeHeterogeneity();
}
Expand All @@ -203,10 +190,6 @@ public void clusterChanged(ClusterChangedEvent event) {
}
}

boolean eventStateMinTransportVersionIsBeforeDistributedModelAllocationTransportVersion(ClusterChangedEvent event) {
return event.state().getMinTransportVersion().before(DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION);
}

boolean eventStateHasGlobalBlockStateNotRecoveredBlock(ClusterChangedEvent event) {
return event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK);
}
Expand Down Expand Up @@ -400,18 +383,6 @@ public void createNewModelAssignment(
CreateTrainedModelAssignmentAction.Request request,
ActionListener<TrainedModelAssignment> listener
) {
if (clusterService.state().getMinTransportVersion().before(DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION)) {
listener.onFailure(
new ElasticsearchStatusException(
"cannot create new assignment [{}] for model [{}] while cluster upgrade is in progress",
RestStatus.CONFLICT,
request.getTaskParams().getDeploymentId(),
request.getTaskParams().getModelId()
)
);
return;
}

if (MlMetadata.getMlMetadata(clusterService.state()).isResetMode()) {
listener.onFailure(
new ElasticsearchStatusException(
Expand Down Expand Up @@ -522,13 +493,11 @@ private static ClusterState update(ClusterState currentState, TrainedModelAssign

private static ClusterState forceUpdate(ClusterState currentState, TrainedModelAssignmentMetadata.Builder modelAssignments) {
logger.debug(() -> format("updated assignments: %s", modelAssignments.build()));

ProjectMetadata.Builder builder = ProjectMetadata.builder(currentState.metadata().getProject());
if (currentState.getMinTransportVersion().onOrAfter(RENAME_ALLOCATION_TO_ASSIGNMENT_TRANSPORT_VERSION)) {
builder.putCustom(TrainedModelAssignmentMetadata.NAME, modelAssignments.build())
.removeCustom(TrainedModelAssignmentMetadata.DEPRECATED_NAME);
} else {
builder.putCustom(TrainedModelAssignmentMetadata.DEPRECATED_NAME, modelAssignments.buildOld());
}
builder.putCustom(TrainedModelAssignmentMetadata.NAME, modelAssignments.build())
.removeCustom(TrainedModelAssignmentMetadata.DEPRECATED_NAME);

return ClusterState.builder(currentState).putProjectMetadata(builder).build();
}

Expand Down Expand Up @@ -844,7 +813,7 @@ private void updateDeployment(
}
boolean hasUpdates = hasUpdates(numberOfAllocations, adaptiveAllocationsSettingsUpdates, existingAssignment);
if (hasUpdates == false) {
logger.info("no updates");
logger.debug("no updates to be made for deployment [{}]", deploymentId);
listener.onResponse(existingAssignment);
return;
}
Expand All @@ -858,27 +827,17 @@ private void updateDeployment(
);
return;
}
if (clusterState.getMinTransportVersion().before(DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION)) {
listener.onFailure(
new ElasticsearchStatusException(
"cannot update deployment with model id [{}] while cluster upgrade is in progress.",
RestStatus.CONFLICT,
deploymentId
)
);
return;
}

ActionListener<ClusterState> updatedStateListener = ActionListener.wrap(
updatedState -> submitUnbatchedTask("update model deployment", new ClusterStateUpdateTask() {
ActionListener<TrainedModelAssignmentMetadata.Builder> updatedAssignmentListener = ActionListener.wrap(
updatedAssignment -> submitUnbatchedTask("update model deployment", new ClusterStateUpdateTask() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the fix: the new assignment state is passed here rather than the updated cluster state.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the TrainedModelAssignmentMetadata changed in the meantime?

That's a similar bug to the existing one, right? But much smaller, because the TrainedModelAssignmentMetadata changes less often than the ClusterState.

Should we protect against that? For example, only replace the TrainedModelAssignmentMetadata in the ClusterState if it's identical to the one we started the update computation with? And if it has changed, try this process again. More or less this paradigm:
https://github.com/elastic/elasticsearch/blob/main/test/framework/src/main/java/org/elasticsearch/common/util/MockBigArrays.java#L728-L742.

Maybe that's overkill and too complicated though. WDYT?

Copy link
Contributor

@jan-elastic jan-elastic Jul 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're deciding not to fix this, let's leave a comment about this small issue with this implementation and call it a day.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The areClusterStatesCompatibleForRebalance function also checks that the TrainedModelAssignmentMetadata has not changed:

&& TrainedModelAssignmentMetadata.fromState(source).equals(TrainedModelAssignmentMetadata.fromState(target));

By comparing the starting state with the latest state before applying the updated TrainedModelAssignmentMetadata (which is a lightweight operation and can be done in the ClusterStateUpdateTask), the code effectively implements the "compare and swap" paradigm linked above.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, cool! I missed that part of the code. I guess this all works as is, then.


private volatile boolean isUpdated;

@Override
public ClusterState execute(ClusterState currentState) {
if (areClusterStatesCompatibleForRebalance(clusterState, currentState)) {
isUpdated = true;
return updatedState;
return update(currentState, updatedAssignment);
}
logger.debug(() -> format("[%s] Retrying update as cluster state has been modified", deploymentId));
updateDeployment(currentState, deploymentId, numberOfAllocations, adaptiveAllocationsSettings, isInternal, listener);
Expand Down Expand Up @@ -910,7 +869,7 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState)
listener::onFailure
);

updateAssignment(clusterState, existingAssignment, numberOfAllocations, adaptiveAllocationsSettings, updatedStateListener);
updateAssignment(clusterState, existingAssignment, numberOfAllocations, adaptiveAllocationsSettings, updatedAssignmentListener);
}

static boolean hasUpdates(
Expand Down Expand Up @@ -944,7 +903,7 @@ private void updateAssignment(
TrainedModelAssignment assignment,
Integer numberOfAllocations,
AdaptiveAllocationsSettings adaptiveAllocationsSettings,
ActionListener<ClusterState> listener
ActionListener<TrainedModelAssignmentMetadata.Builder> listener
) {
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
if (numberOfAllocations == null || numberOfAllocations == assignment.getTaskParams().getNumberOfAllocations()) {
Expand All @@ -961,21 +920,21 @@ private void updateAndKeepNumberOfAllocations(
ClusterState clusterState,
TrainedModelAssignment assignment,
AdaptiveAllocationsSettings adaptiveAllocationsSettings,
ActionListener<ClusterState> listener
ActionListener<TrainedModelAssignmentMetadata.Builder> listener
) {
TrainedModelAssignment.Builder updatedAssignment = TrainedModelAssignment.Builder.fromAssignment(assignment)
.setAdaptiveAllocationsSettings(adaptiveAllocationsSettings);
TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(clusterState);
builder.updateAssignment(assignment.getDeploymentId(), updatedAssignment);
listener.onResponse(update(clusterState, builder));
listener.onResponse(builder);
}

private void increaseNumberOfAllocations(
ClusterState clusterState,
TrainedModelAssignment assignment,
int numberOfAllocations,
AdaptiveAllocationsSettings adaptiveAllocationsSettings,
ActionListener<ClusterState> listener
ActionListener<TrainedModelAssignmentMetadata.Builder> listener
) {
try {
TrainedModelAssignment.Builder updatedAssignment = TrainedModelAssignment.Builder.fromAssignment(assignment)
Expand All @@ -995,7 +954,7 @@ private void increaseNumberOfAllocations(
)
);
} else {
listener.onResponse(update(clusterState, rebalancedMetadata));
listener.onResponse(rebalancedMetadata);
}
} catch (Exception e) {
listener.onFailure(e);
Expand All @@ -1007,7 +966,7 @@ private void decreaseNumberOfAllocations(
TrainedModelAssignment assignment,
int numberOfAllocations,
AdaptiveAllocationsSettings adaptiveAllocationsSettings,
ActionListener<ClusterState> listener
ActionListener<TrainedModelAssignmentMetadata.Builder> listener
) {
TrainedModelAssignment.Builder updatedAssignment = numberOfAllocations < assignment.totalTargetAllocations()
? new AllocationReducer(assignment, nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(clusterState)).reduceTo(
Expand All @@ -1022,7 +981,7 @@ private void decreaseNumberOfAllocations(
}
TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(clusterState);
builder.updateAssignment(assignment.getDeploymentId(), updatedAssignment);
listener.onResponse(update(clusterState, builder));
listener.onResponse(builder);
}

static ClusterState setToStopping(ClusterState clusterState, String deploymentId, String reason) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,20 +375,17 @@ public void clusterChanged(ClusterChangedEvent event) {
final boolean isResetMode = MlMetadata.getMlMetadata(event.state()).isResetMode();
TrainedModelAssignmentMetadata modelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(event.state());
final String currentNode = event.state().nodes().getLocalNodeId();
final boolean isNewAllocationSupported = event.state()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another version check that is irrelevant for 9.x.

.getMinTransportVersion()
.onOrAfter(TrainedModelAssignmentClusterService.DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION);
final Set<String> shuttingDownNodes = Collections.unmodifiableSet(event.state().metadata().nodeShutdowns().getAllNodeIds());

if (isResetMode == false && isNewAllocationSupported) {
if (isResetMode == false) {
updateNumberOfAllocations(modelAssignmentMetadata);
}

for (TrainedModelAssignment trainedModelAssignment : modelAssignmentMetadata.allAssignments().values()) {
RoutingInfo routingInfo = trainedModelAssignment.getNodeRoutingTable().get(currentNode);
if (routingInfo != null) {
// Add new models to start loading if the assignment is not stopping
if (isNewAllocationSupported && trainedModelAssignment.getAssignmentState() != AssignmentState.STOPPING) {
if (trainedModelAssignment.getAssignmentState() != AssignmentState.STOPPING) {
if (shouldAssignmentBeRestarted(routingInfo, trainedModelAssignment.getDeploymentId())) {
prepareAssignmentForRestart(trainedModelAssignment);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ public void testClusterChanged_GivenNodesAdded_ThenLogMlNodeHeterogeneityCalled(
TrainedModelAssignmentClusterService serviceSpy = spy(createClusterService(randomInt(5)));
doNothing().when(serviceSpy).logMlNodeHeterogeneity();
doReturn(false).when(serviceSpy).eventStateHasGlobalBlockStateNotRecoveredBlock(any());
doReturn(false).when(serviceSpy).eventStateMinTransportVersionIsBeforeDistributedModelAllocationTransportVersion(any());

ClusterChangedEvent mockNodesAddedEvent = mock(ClusterChangedEvent.class);
ClusterState mockState = mock(ClusterState.class);
Expand Down