File: InferenceServiceNodeLocalRateLimitCalculator.java
@@ -21,16 +21,17 @@
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;

/**
@@ -60,13 +61,6 @@ public class InferenceServiceNodeLocalRateLimitCalculator implements InferenceSe
* - Which task types support request re-routing and "node-local" rate limit calculation
* - How many nodes should handle requests for each task type, based on cluster size (dynamically calculated or statically provided)
**/
static final Map<String, Collection<NodeLocalRateLimitConfig>> SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS = Map.of(
ElasticInferenceService.NAME,
// TODO: should probably be a map/set
List.of(new NodeLocalRateLimitConfig(TaskType.SPARSE_EMBEDDING, (numNodesInCluster) -> DEFAULT_MAX_NODES_PER_GROUPING))
);

record NodeLocalRateLimitConfig(TaskType taskType, MaxNodesPerGroupingStrategy maxNodesPerGroupingStrategy) {}

@FunctionalInterface
private interface MaxNodesPerGroupingStrategy {
@@ -81,11 +75,31 @@ private interface MaxNodesPerGroupingStrategy {

private final ConcurrentHashMap<String, Map<TaskType, RateLimitAssignment>> serviceAssignments;

private final SortedMap<String, SortedMap<TaskType, MaxNodesPerGroupingStrategy>> serviceNodeLocalRateLimitConfigs;

@Inject
public InferenceServiceNodeLocalRateLimitCalculator(ClusterService clusterService, InferenceServiceRegistry serviceRegistry) {
clusterService.addListener(this);
this.serviceRegistry = serviceRegistry;
this.serviceAssignments = new ConcurrentHashMap<>();
this.serviceNodeLocalRateLimitConfigs = createServiceNodeLocalRateLimitConfigs();
}

private SortedMap<String, SortedMap<TaskType, MaxNodesPerGroupingStrategy>> createServiceNodeLocalRateLimitConfigs() {
TreeMap<String, TreeMap<TaskType, MaxNodesPerGroupingStrategy>> serviceNodeLocalRateLimitConfigs = new TreeMap<>();

MaxNodesPerGroupingStrategy defaultStrategy = (numNodesInCluster) -> DEFAULT_MAX_NODES_PER_GROUPING;

for (var service : serviceRegistry.getServices().values()) {
TreeMap<TaskType, MaxNodesPerGroupingStrategy> serviceConfigs = new TreeMap<>();
var taskTypes = service.supportedTaskTypes();
for (TaskType taskType : taskTypes) {
serviceConfigs.put(taskType, defaultStrategy);
}
serviceNodeLocalRateLimitConfigs.put(service.name(), serviceConfigs);
}

return Collections.unmodifiableSortedMap(serviceNodeLocalRateLimitConfigs);
}
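
As a rough illustration of the shape this method produces (hypothetical service name and task type, with `String` standing in for `TaskType` so the sketch compiles standalone):

```java
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.function.IntUnaryOperator;

public class ConfigShapeSketch {
    public static void main(String[] args) {
        // Stand-in for MaxNodesPerGroupingStrategy: cluster size in, max nodes out
        IntUnaryOperator defaultStrategy = numNodesInCluster -> 3;

        SortedMap<String, SortedMap<String, IntUnaryOperator>> configs = new TreeMap<>();
        SortedMap<String, IntUnaryOperator> perTaskType = new TreeMap<>();
        perTaskType.put("sparse_embedding", defaultStrategy);
        configs.put("my-inference-service", perTaskType);

        // TreeMaps iterate in sorted key order, so every node in the cluster
        // walks services and task types in the same deterministic order.
        System.out.println(configs); // {my-inference-service={sparse_embedding=...}}
    }
}
```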

@Override
@@ -100,9 +114,8 @@ public void clusterChanged(ClusterChangedEvent event) {
}

public boolean isTaskTypeReroutingSupported(String serviceName, TaskType taskType) {
return SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.getOrDefault(serviceName, Collections.emptyList())
.stream()
.anyMatch(rateLimitConfig -> taskType.equals(rateLimitConfig.taskType));
Map<TaskType, MaxNodesPerGroupingStrategy> serviceConfigs = serviceNodeLocalRateLimitConfigs.get(serviceName);
return serviceConfigs != null && serviceConfigs.containsKey(taskType);
}

public RateLimitAssignment getRateLimitAssignment(String service, TaskType taskType) {
@@ -126,27 +139,28 @@ private void updateAssignments(ClusterChangedEvent event) {

// Sort nodes by id (every node lands on the same result)
var sortedNodes = nodes.stream().sorted(Comparator.comparing(DiscoveryNode::getId)).toList();
PriorityQueue<Map.Entry<Integer, DiscoveryNode>> nodeAssignmentCounts = createNodeAssignmentPriorityQueue(sortedNodes);

// Sort inference services by name (every node lands on the same result)
var sortedServices = new ArrayList<>(serviceRegistry.getServices().values());
sortedServices.sort(Comparator.comparing(InferenceService::name));

for (String serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) {
for (String serviceName : serviceNodeLocalRateLimitConfigs.keySet()) {
Optional<InferenceService> service = serviceRegistry.getService(serviceName);

if (service.isPresent()) {
var inferenceService = service.get();

for (NodeLocalRateLimitConfig rateLimitConfig : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName)) {
// Created once per service (not per task type) so assignments for all task types accumulate
Map<TaskType, RateLimitAssignment> perTaskTypeAssignments = new HashMap<>();
for (TaskType taskType : serviceNodeLocalRateLimitConfigs.get(serviceName).keySet()) {
TaskType taskType = rateLimitConfig.taskType();
var maxNodesPerGroupingStrategy = serviceNodeLocalRateLimitConfigs.get(serviceName).get(taskType);

// Calculate node assignments needed for re-routing
var assignedNodes = calculateServiceAssignment(rateLimitConfig.maxNodesPerGroupingStrategy(), sortedNodes);
var assignedNodes = calculateServiceTaskTypeAssignment(maxNodesPerGroupingStrategy, nodeAssignmentCounts);

// Update rate limits to be "node-local"
var numAssignedNodes = assignedNodes.size();
updateRateLimits(inferenceService, numAssignedNodes);
updateRateLimits(inferenceService, taskType, numAssignedNodes);

perTaskTypeAssignments.put(taskType, new RateLimitAssignment(assignedNodes));
serviceAssignments.put(serviceName, perTaskTypeAssignments);
@@ -160,38 +174,61 @@ private void updateAssignments(ClusterChangedEvent event) {
}
}

private List<DiscoveryNode> calculateServiceAssignment(
private PriorityQueue<Map.Entry<Integer, DiscoveryNode>> createNodeAssignmentPriorityQueue(List<DiscoveryNode> sortedNodes) {
PriorityQueue<Map.Entry<Integer, DiscoveryNode>> nodeAssignmentCounts = new PriorityQueue<>(
Comparator.comparingInt((Map.Entry<Integer, DiscoveryNode> o) -> o.getKey()).thenComparing(o -> o.getValue().getId())
);

for (DiscoveryNode node : sortedNodes) {
nodeAssignmentCounts.add(Map.entry(0, node));
}
return nodeAssignmentCounts;
}
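
A standalone sketch (made-up node ids) of the heap ordering used above: entries compare by assignment count first, then by node id, so polling is both load-aware and deterministic across every node that runs this calculation:

```java
import java.util.Comparator;
import java.util.Map;
import java.util.PriorityQueue;

public class HeapOrderingSketch {
    public static void main(String[] args) {
        PriorityQueue<Map.Entry<Integer, String>> heap = new PriorityQueue<>(
            Comparator.comparingInt((Map.Entry<Integer, String> e) -> e.getKey())
                .thenComparing(Map.Entry::getValue)
        );
        heap.add(Map.entry(1, "node-a"));
        heap.add(Map.entry(0, "node-c"));
        heap.add(Map.entry(0, "node-b"));

        // Lowest assignment count wins; ties are broken by node id
        System.out.println(heap.poll()); // 0=node-b
        System.out.println(heap.poll()); // 0=node-c
        System.out.println(heap.poll()); // 1=node-a
    }
}
```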

private List<DiscoveryNode> calculateServiceTaskTypeAssignment(
MaxNodesPerGroupingStrategy maxNodesPerGroupingStrategy,
List<DiscoveryNode> sortedNodes
PriorityQueue<Map.Entry<Integer, DiscoveryNode>> nodeAssignmentCounts
) {
int numberOfNodes = sortedNodes.size();
// Use a priority queue to prioritize nodes with the fewest assignments
int numberOfNodes = nodeAssignmentCounts.size();
int nodesPerGrouping = Math.min(numberOfNodes, maxNodesPerGroupingStrategy.calculate(numberOfNodes));

List<DiscoveryNode> assignedNodes = new ArrayList<>();

// TODO: here we can probably be smarter: if |num nodes in cluster| > |num nodes per task types|
// -> make sure a service provider is not assigned the same nodes for all task types; only relevant as soon as we support more task
// types
// Assign nodes by repeatedly picking the one with the fewest assignments
for (int j = 0; j < nodesPerGrouping; j++) {
var assignedNode = sortedNodes.get(j % numberOfNodes);
assignedNodes.add(assignedNode);
Map.Entry<Integer, DiscoveryNode> fewestAssignments = nodeAssignmentCounts.poll();

if (fewestAssignments == null) {
logger.warn("Node assignment queue is empty. Stopping node local rate limiting assignment.");
break;
}

DiscoveryNode nodeToAssign = fewestAssignments.getValue();
int currentAssignmentCount = fewestAssignments.getKey();

assignedNodes.add(nodeToAssign);
nodeAssignmentCounts.add(Map.entry(currentAssignmentCount + 1, nodeToAssign));
}

return assignedNodes;
}
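
A worked example under assumed numbers (3 nodes, two consecutive groupings of 2, `String` ids standing in for `DiscoveryNode`): the second grouping starts from the node the first grouping left least loaded, which is the balancing the queue buys over the previous modulo-based round-robin:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

public class AssignmentWalkthrough {
    // Mirrors the assignment loop above
    static List<String> assign(PriorityQueue<Map.Entry<Integer, String>> heap, int nodesPerGrouping) {
        List<String> assigned = new ArrayList<>();
        for (int j = 0; j < nodesPerGrouping; j++) {
            Map.Entry<Integer, String> fewest = heap.poll();
            if (fewest == null) {
                break;
            }
            assigned.add(fewest.getValue());
            heap.add(Map.entry(fewest.getKey() + 1, fewest.getValue()));
        }
        return assigned;
    }

    public static void main(String[] args) {
        PriorityQueue<Map.Entry<Integer, String>> heap = new PriorityQueue<>(
            Comparator.comparingInt((Map.Entry<Integer, String> e) -> e.getKey())
                .thenComparing(Map.Entry::getValue)
        );
        for (String id : List.of("node-a", "node-b", "node-c")) {
            heap.add(Map.entry(0, id));
        }
        System.out.println(assign(heap, 2)); // [node-a, node-b]
        System.out.println(assign(heap, 2)); // [node-c, node-a] (node-c had the fewest assignments)
    }
}
```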

private void updateRateLimits(InferenceService service, int responsibleNodes) {
private void updateRateLimits(InferenceService service, TaskType taskType, int responsibleNodes) {
if ((service instanceof SenderService) == false) {
return;
}

SenderService senderService = (SenderService) service;
Sender sender = senderService.getSender();
// TODO: this needs to take in service and task type as soon as multiple services/task types are supported
sender.updateRateLimitDivisor(responsibleNodes);
sender.updateRateLimitDivisor(service.name(), taskType, responsibleNodes);
}
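
The intended arithmetic, as I read it (illustrative numbers; the divisor bookkeeping itself lives inside the Sender): splitting an endpoint-level limit evenly across the nodes assigned to a (service, task type) pair keeps the cluster-wide rate at the configured value:

```java
public class DivisorArithmeticSketch {
    public static void main(String[] args) {
        long configuredRequestsPerMinute = 1_000; // hypothetical endpoint-level limit
        int responsibleNodes = 4;                 // nodes assigned to this (service, task type)

        // Each node enforces only its share locally, so the effective
        // cluster-wide rate stays at the configured limit.
        long nodeLocalLimit = configuredRequestsPerMinute / responsibleNodes;
        System.out.println(nodeLocalLimit); // 250
    }
}
```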

InferenceServiceRegistry serviceRegistry() {
return serviceRegistry;
}

SortedMap<String, SortedMap<TaskType, MaxNodesPerGroupingStrategy>> serviceNodeLocalRateLimitConfigs() {
return serviceNodeLocalRateLimitConfigs;
}
}
File: RequestExecutor.java
@@ -11,6 +11,7 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.RequestManager;

@@ -21,7 +22,7 @@ public interface RequestExecutor {

void shutdown();

void updateRateLimitDivisor(int newDivisor);
void updateRateLimitDivisor(String serviceName, TaskType taskType, int newDivisor);

boolean isShutdown();

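The widened signature suggests the executor now keys divisors by (service, task type) rather than keeping a single global one. A hypothetical sketch of that bookkeeping (`String` in place of `TaskType`; not the actual executor implementation):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class DivisorBookkeepingSketch {
    private final Map<String, Integer> divisors = new ConcurrentHashMap<>();

    public void updateRateLimitDivisor(String serviceName, String taskType, int newDivisor) {
        divisors.put(serviceName + "/" + taskType, newDivisor);
    }

    public int divisorFor(String serviceName, String taskType) {
        return divisors.getOrDefault(serviceName + "/" + taskType, 1); // default: undivided
    }
}
```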
File: AlibabaCloudSearchRequestManager.java
@@ -15,7 +15,14 @@
abstract class AlibabaCloudSearchRequestManager extends BaseRequestManager {

protected AlibabaCloudSearchRequestManager(ThreadPool threadPool, AlibabaCloudSearchModel model) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
super(
threadPool,
model.getInferenceEntityId(),
AlibabaCloudSearchRequestManager.RateLimitGrouping.of(model),
model.rateLimitServiceSettings().rateLimitSettings(),
model.getConfigurations().getService(),
model.getConfigurations().getTaskType()
);
}

record RateLimitGrouping(int apiKeyHash) {
File: AmazonBedrockRequestManager.java
@@ -15,13 +15,21 @@

import java.util.Objects;

public abstract class AmazonBedrockRequestManager implements RequestManager {
public abstract class AmazonBedrockRequestManager extends BaseRequestManager {

protected final ThreadPool threadPool;
protected final TimeValue timeout;
private final AmazonBedrockModel baseModel;

protected AmazonBedrockRequestManager(AmazonBedrockModel baseModel, ThreadPool threadPool, @Nullable TimeValue timeout) {
super(
threadPool,
baseModel.getInferenceEntityId(),
AmazonBedrockRequestManager.RateLimitGrouping.of(baseModel),
baseModel.rateLimitSettings(),
baseModel.getConfigurations().getService(),
baseModel.getConfigurations().getTaskType()
);
this.baseModel = Objects.requireNonNull(baseModel);
this.threadPool = Objects.requireNonNull(threadPool);
this.timeout = timeout;
File: AnthropicRequestManager.java
@@ -16,7 +16,14 @@
abstract class AnthropicRequestManager extends BaseRequestManager {

protected AnthropicRequestManager(ThreadPool threadPool, AnthropicModel model) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
super(
threadPool,
model.getInferenceEntityId(),
AnthropicRequestManager.RateLimitGrouping.of(model),
model.rateLimitServiceSettings().rateLimitSettings(),
model.getConfigurations().getService(),
model.getConfigurations().getTaskType()
);
}

record RateLimitGrouping(int accountHash, int modelIdHash) {
File: AzureAiStudioRequestManager.java
@@ -15,7 +15,14 @@
public abstract class AzureAiStudioRequestManager extends BaseRequestManager {

protected AzureAiStudioRequestManager(ThreadPool threadPool, AzureAiStudioModel model) {
super(threadPool, model.getInferenceEntityId(), AzureAiStudioRequestManager.RateLimitGrouping.of(model), model.rateLimitSettings());
super(
threadPool,
model.getInferenceEntityId(),
AzureAiStudioRequestManager.RateLimitGrouping.of(model),
model.rateLimitSettings(),
model.getConfigurations().getService(),
model.getConfigurations().getTaskType()
);
}

record RateLimitGrouping(int targetHashcode) {
File: AzureOpenAiRequestManager.java
@@ -14,7 +14,14 @@

public abstract class AzureOpenAiRequestManager extends BaseRequestManager {
protected AzureOpenAiRequestManager(ThreadPool threadPool, AzureOpenAiModel model) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
super(
threadPool,
model.getInferenceEntityId(),
RateLimitGrouping.of(model),
model.rateLimitServiceSettings().rateLimitSettings(),
model.getConfigurations().getService(),
model.getConfigurations().getTaskType()
);
}

record RateLimitGrouping(int resourceNameHash, int deploymentIdHash) {
File: BaseRequestManager.java
@@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -23,17 +24,28 @@ abstract class BaseRequestManager implements RequestManager {
// the rate and the other inference endpoint's rate will be ignored
private final EndpointGrouping endpointGrouping;
private final RateLimitSettings rateLimitSettings;
private final String service;
private final TaskType taskType;

BaseRequestManager(ThreadPool threadPool, String inferenceEntityId, Object rateLimitGroup, RateLimitSettings rateLimitSettings) {
BaseRequestManager(
ThreadPool threadPool,
String inferenceEntityId,
Object rateLimitGroup,
RateLimitSettings rateLimitSettings,
String service,
TaskType taskType
) {
this.threadPool = Objects.requireNonNull(threadPool);
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);

Objects.requireNonNull(rateLimitSettings);
this.endpointGrouping = new EndpointGrouping(Objects.requireNonNull(rateLimitGroup).hashCode(), rateLimitSettings);
this.rateLimitSettings = rateLimitSettings;
this.service = service;
this.taskType = taskType;
}

BaseRequestManager(ThreadPool threadPool, RateLimitGroupingModel rateLimitGroupingModel) {
BaseRequestManager(ThreadPool threadPool, RateLimitGroupingModel rateLimitGroupingModel, String service, TaskType taskType) {
this.threadPool = Objects.requireNonNull(threadPool);
Objects.requireNonNull(rateLimitGroupingModel);

@@ -43,6 +55,8 @@ abstract class BaseRequestManager implements RequestManager {
rateLimitGroupingModel.rateLimitSettings()
);
this.rateLimitSettings = rateLimitGroupingModel.rateLimitSettings();
this.service = service;
this.taskType = taskType;
}

protected void execute(Runnable runnable) {
@@ -64,5 +78,15 @@ public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
}

@Override
public String service() {
return this.service;
}

@Override
public TaskType taskType() {
return this.taskType;
}

private record EndpointGrouping(int group, RateLimitSettings settings) {}
}
File: CohereRequestManager.java
@@ -15,7 +15,14 @@
abstract class CohereRequestManager extends BaseRequestManager {

protected CohereRequestManager(ThreadPool threadPool, CohereModel model) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
super(
threadPool,
model.getInferenceEntityId(),
CohereRequestManager.RateLimitGrouping.of(model),
model.rateLimitServiceSettings().rateLimitSettings(),
model.getConfigurations().getService(),
model.getConfigurations().getTaskType()
);
}

record RateLimitGrouping(int apiKeyHash) {
File: DeepSeekRequestManager.java
@@ -35,10 +35,25 @@ public class DeepSeekRequestManager extends BaseRequestManager {
private final DeepSeekChatCompletionModel model;

public DeepSeekRequestManager(DeepSeekChatCompletionModel model, ThreadPool threadPool) {
super(threadPool, model.getInferenceEntityId(), model.rateLimitGroup(), model.rateLimitSettings());
super(
threadPool,
model.getInferenceEntityId(),
DeepSeekRequestManager.RateLimitGrouping.of(model),
model.rateLimitSettings(),
model.getConfigurations().getService(),
model.getConfigurations().getTaskType()
);
this.model = Objects.requireNonNull(model);
}

record RateLimitGrouping(int apiKeyHash) {
public static DeepSeekRequestManager.RateLimitGrouping of(DeepSeekChatCompletionModel model) {
Objects.requireNonNull(model);

return new DeepSeekRequestManager.RateLimitGrouping(model.apiKey().hashCode());

[Review comment, Contributor] I believe it was intentional to limit the rate limit to the max allowed, so I think we should revert the changes around the API key here. cc: @prwhelan

[Review comment, Member reply] This is correct - there is effectively no rate limiting for DeepSeek.

}
}

@Override
public void execute(
InferenceInputs inferenceInputs,