Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
Expand Down Expand Up @@ -80,9 +78,6 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene

private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentClusterService.class);

private static final TransportVersion RENAME_ALLOCATION_TO_ASSIGNMENT_TRANSPORT_VERSION = TransportVersions.V_8_3_0;
public static final TransportVersion DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION = TransportVersions.V_8_4_0;

private final ClusterService clusterService;
private final ThreadPool threadPool;
private final NodeLoadDetector nodeLoadDetector;
Expand Down Expand Up @@ -170,14 +165,6 @@ public void clusterChanged(ClusterChangedEvent event) {
return;
}

if (eventStateMinTransportVersionIsBeforeDistributedModelAllocationTransportVersion(event)) {
// we should not try to rebalance assignments while there may be nodes running on a version
// prior to introducing distributed model allocation.
// But we should remove routing to removed or shutting down nodes.
removeRoutingToRemovedOrShuttingDownNodes(event);
return;
}

if (event.nodesAdded()) {
logMlNodeHeterogeneity();
}
Expand All @@ -204,10 +191,6 @@ public void clusterChanged(ClusterChangedEvent event) {
}
}

boolean eventStateMinTransportVersionIsBeforeDistributedModelAllocationTransportVersion(ClusterChangedEvent event) {
return event.state().getMinTransportVersion().before(DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION);
}

boolean eventStateHasGlobalBlockStateNotRecoveredBlock(ClusterChangedEvent event) {
return event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK);
}
Expand Down Expand Up @@ -401,18 +384,6 @@ public void createNewModelAssignment(
CreateTrainedModelAssignmentAction.Request request,
ActionListener<TrainedModelAssignment> listener
) {
if (clusterService.state().getMinTransportVersion().before(DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION)) {
listener.onFailure(
new ElasticsearchStatusException(
"cannot create new assignment [{}] for model [{}] while cluster upgrade is in progress",
RestStatus.CONFLICT,
request.getTaskParams().getDeploymentId(),
request.getTaskParams().getModelId()
)
);
return;
}

if (MlMetadata.getMlMetadata(clusterService.state()).isResetMode()) {
listener.onFailure(
new ElasticsearchStatusException(
Expand Down Expand Up @@ -524,12 +495,8 @@ private static ClusterState update(ClusterState currentState, TrainedModelAssign
private static ClusterState forceUpdate(ClusterState currentState, TrainedModelAssignmentMetadata.Builder modelAssignments) {
logger.debug(() -> format("updated assignments: %s", modelAssignments.build()));
Metadata.Builder metadata = Metadata.builder(currentState.metadata());
if (currentState.getMinTransportVersion().onOrAfter(RENAME_ALLOCATION_TO_ASSIGNMENT_TRANSPORT_VERSION)) {
metadata.putCustom(TrainedModelAssignmentMetadata.NAME, modelAssignments.build())
.removeCustom(TrainedModelAssignmentMetadata.DEPRECATED_NAME);
} else {
metadata.putCustom(TrainedModelAssignmentMetadata.DEPRECATED_NAME, modelAssignments.buildOld());
}
metadata.putCustom(TrainedModelAssignmentMetadata.NAME, modelAssignments.build())
.removeCustom(TrainedModelAssignmentMetadata.DEPRECATED_NAME);
return ClusterState.builder(currentState).metadata(metadata).build();
}

Expand Down Expand Up @@ -847,7 +814,7 @@ private void updateDeployment(
}
boolean hasUpdates = hasUpdates(numberOfAllocations, adaptiveAllocationsSettingsUpdates, existingAssignment);
if (hasUpdates == false) {
logger.info("no updates");
logger.debug("no updates to be made for deployment [{}]", deploymentId);
listener.onResponse(existingAssignment);
return;
}
Expand All @@ -861,27 +828,17 @@ private void updateDeployment(
);
return;
}
if (clusterState.getMinTransportVersion().before(DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION)) {
listener.onFailure(
new ElasticsearchStatusException(
"cannot update deployment with model id [{}] while cluster upgrade is in progress.",
RestStatus.CONFLICT,
deploymentId
)
);
return;
}

ActionListener<ClusterState> updatedStateListener = ActionListener.wrap(
updatedState -> submitUnbatchedTask("update model deployment", new ClusterStateUpdateTask() {
ActionListener<TrainedModelAssignmentMetadata.Builder> updatedAssignmentListener = ActionListener.wrap(
updatedAssignment -> submitUnbatchedTask("update model deployment", new ClusterStateUpdateTask() {

private volatile boolean isUpdated;

@Override
public ClusterState execute(ClusterState currentState) {
if (areClusterStatesCompatibleForRebalance(clusterState, currentState)) {
isUpdated = true;
return updatedState;
return update(currentState, updatedAssignment);
}
logger.debug(() -> format("[%s] Retrying update as cluster state has been modified", deploymentId));
updateDeployment(currentState, deploymentId, numberOfAllocations, adaptiveAllocationsSettings, isInternal, listener);
Expand Down Expand Up @@ -913,7 +870,7 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState)
listener::onFailure
);

updateAssignment(clusterState, existingAssignment, numberOfAllocations, adaptiveAllocationsSettings, updatedStateListener);
updateAssignment(clusterState, existingAssignment, numberOfAllocations, adaptiveAllocationsSettings, updatedAssignmentListener);
}

static boolean hasUpdates(
Expand Down Expand Up @@ -947,7 +904,7 @@ private void updateAssignment(
TrainedModelAssignment assignment,
Integer numberOfAllocations,
AdaptiveAllocationsSettings adaptiveAllocationsSettings,
ActionListener<ClusterState> listener
ActionListener<TrainedModelAssignmentMetadata.Builder> listener
) {
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
if (numberOfAllocations == null || numberOfAllocations == assignment.getTaskParams().getNumberOfAllocations()) {
Expand All @@ -964,21 +921,21 @@ private void updateAndKeepNumberOfAllocations(
ClusterState clusterState,
TrainedModelAssignment assignment,
AdaptiveAllocationsSettings adaptiveAllocationsSettings,
ActionListener<ClusterState> listener
ActionListener<TrainedModelAssignmentMetadata.Builder> listener
) {
TrainedModelAssignment.Builder updatedAssignment = TrainedModelAssignment.Builder.fromAssignment(assignment)
.setAdaptiveAllocationsSettings(adaptiveAllocationsSettings);
TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(clusterState);
builder.updateAssignment(assignment.getDeploymentId(), updatedAssignment);
listener.onResponse(update(clusterState, builder));
listener.onResponse(builder);
}

private void increaseNumberOfAllocations(
ClusterState clusterState,
TrainedModelAssignment assignment,
int numberOfAllocations,
AdaptiveAllocationsSettings adaptiveAllocationsSettings,
ActionListener<ClusterState> listener
ActionListener<TrainedModelAssignmentMetadata.Builder> listener
) {
try {
TrainedModelAssignment.Builder updatedAssignment = TrainedModelAssignment.Builder.fromAssignment(assignment)
Expand All @@ -998,7 +955,7 @@ private void increaseNumberOfAllocations(
)
);
} else {
listener.onResponse(update(clusterState, rebalancedMetadata));
listener.onResponse(rebalancedMetadata);
}
} catch (Exception e) {
listener.onFailure(e);
Expand All @@ -1010,7 +967,7 @@ private void decreaseNumberOfAllocations(
TrainedModelAssignment assignment,
int numberOfAllocations,
AdaptiveAllocationsSettings adaptiveAllocationsSettings,
ActionListener<ClusterState> listener
ActionListener<TrainedModelAssignmentMetadata.Builder> listener
) {
TrainedModelAssignment.Builder updatedAssignment = numberOfAllocations < assignment.totalTargetAllocations()
? new AllocationReducer(assignment, nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(clusterState)).reduceTo(
Expand All @@ -1025,7 +982,7 @@ private void decreaseNumberOfAllocations(
}
TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(clusterState);
builder.updateAssignment(assignment.getDeploymentId(), updatedAssignment);
listener.onResponse(update(clusterState, builder));
listener.onResponse(builder);
}

static ClusterState setToStopping(ClusterState clusterState, String deploymentId, String reason) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,20 +375,17 @@ public void clusterChanged(ClusterChangedEvent event) {
final boolean isResetMode = MlMetadata.getMlMetadata(event.state()).isResetMode();
TrainedModelAssignmentMetadata modelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(event.state());
final String currentNode = event.state().nodes().getLocalNodeId();
final boolean isNewAllocationSupported = event.state()
.getMinTransportVersion()
.onOrAfter(TrainedModelAssignmentClusterService.DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION);
final Set<String> shuttingDownNodes = Collections.unmodifiableSet(event.state().metadata().nodeShutdowns().getAllNodeIds());

if (isResetMode == false && isNewAllocationSupported) {
if (isResetMode == false) {
updateNumberOfAllocations(modelAssignmentMetadata);
}

for (TrainedModelAssignment trainedModelAssignment : modelAssignmentMetadata.allAssignments().values()) {
RoutingInfo routingInfo = trainedModelAssignment.getNodeRoutingTable().get(currentNode);
if (routingInfo != null) {
// Add new models to start loading if the assignment is not stopping
if (isNewAllocationSupported && trainedModelAssignment.getAssignmentState() != AssignmentState.STOPPING) {
if (trainedModelAssignment.getAssignmentState() != AssignmentState.STOPPING) {
if (shouldAssignmentBeRestarted(routingInfo, trainedModelAssignment.getDeploymentId())) {
prepareAssignmentForRestart(trainedModelAssignment);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ public void testClusterChanged_GivenNodesAdded_ThenLogMlNodeHeterogeneityCalled(
TrainedModelAssignmentClusterService serviceSpy = spy(createClusterService(randomInt(5)));
doNothing().when(serviceSpy).logMlNodeHeterogeneity();
doReturn(false).when(serviceSpy).eventStateHasGlobalBlockStateNotRecoveredBlock(any());
doReturn(false).when(serviceSpy).eventStateMinTransportVersionIsBeforeDistributedModelAllocationTransportVersion(any());

ClusterChangedEvent mockNodesAddedEvent = mock(ClusterChangedEvent.class);
ClusterState mockState = mock(ClusterState.class);
Expand Down