|
| 1 | +/* |
| 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one |
| 3 | + * or more contributor license agreements. Licensed under the Elastic License |
| 4 | + * 2.0; you may not use this file except in compliance with the Elastic License |
| 5 | + * 2.0. |
| 6 | + */ |
| 7 | + |
| 8 | +package org.elasticsearch.xpack.inference; |
| 9 | + |
| 10 | +import org.elasticsearch.inference.TaskType; |
| 11 | +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; |
| 12 | + |
| 13 | +import java.io.IOException; |
| 14 | +import java.util.List; |
| 15 | +import java.util.Map; |
| 16 | +import java.util.Optional; |
| 17 | + |
| 18 | +public class InferenceUpdateElasticsearchInternalServiceModelIT extends CustomElandModelIT { |
| 19 | + private final List<AdaptiveAllocationsSettings> ADAPTIVE_ALLOCATIONS_SETTINGS = List.of( |
| 20 | + new AdaptiveAllocationsSettings(randomBoolean(), null, null), |
| 21 | + new AdaptiveAllocationsSettings(null, randomIntBetween(1, 10), null), |
| 22 | + new AdaptiveAllocationsSettings(null, null, randomIntBetween(1, 10)), |
| 23 | + new AdaptiveAllocationsSettings(randomBoolean(), randomIntBetween(1, 10), randomIntBetween(11, 20)) |
| 24 | + ); |
| 25 | + |
| 26 | + public void testUpdateNumThreads() throws IOException { |
| 27 | + testUpdateElasticsearchInternalServiceEndpoint( |
| 28 | + Optional.of(randomIntBetween(2, 10)), |
| 29 | + Optional.empty(), |
| 30 | + Optional.empty(), |
| 31 | + Optional.empty() |
| 32 | + ); |
| 33 | + } |
| 34 | + |
| 35 | + public void testUpdateAdaptiveAllocationsSettings() throws IOException { |
| 36 | + for (AdaptiveAllocationsSettings settings : ADAPTIVE_ALLOCATIONS_SETTINGS) { |
| 37 | + testUpdateElasticsearchInternalServiceEndpoint( |
| 38 | + Optional.empty(), |
| 39 | + Optional.ofNullable(settings.getEnabled()), |
| 40 | + Optional.ofNullable(settings.getMinNumberOfAllocations()), |
| 41 | + Optional.ofNullable(settings.getMaxNumberOfAllocations()) |
| 42 | + ); |
| 43 | + } |
| 44 | + } |
| 45 | + |
| 46 | + public void testUpdateNumAllocationsAndAdaptiveAllocationsSettings() throws IOException { |
| 47 | + testUpdateElasticsearchInternalServiceEndpoint( |
| 48 | + Optional.of(randomIntBetween(2, 10)), |
| 49 | + Optional.of(randomBoolean()), |
| 50 | + Optional.of(randomIntBetween(1, 10)), |
| 51 | + Optional.of(randomIntBetween(11, 20)) |
| 52 | + ); |
| 53 | + } |
| 54 | + |
| 55 | + private void testUpdateElasticsearchInternalServiceEndpoint( |
| 56 | + Optional<Integer> updatedNumAllocations, |
| 57 | + Optional<Boolean> updatedAdaptiveAllocationsEnabled, |
| 58 | + Optional<Integer> updatedMinNumberOfAllocations, |
| 59 | + Optional<Integer> updatedMaxNumberOfAllocations |
| 60 | + ) throws IOException { |
| 61 | + var inferenceId = "update-adaptive-allocations-inference"; |
| 62 | + var originalEndpoint = setupInferenceEndpoint(inferenceId); |
| 63 | + verifyEndpointConfig(originalEndpoint, 1, Optional.empty(), Optional.empty(), Optional.empty()); |
| 64 | + |
| 65 | + var updateConfig = generateUpdateConfig( |
| 66 | + updatedNumAllocations, |
| 67 | + updatedAdaptiveAllocationsEnabled, |
| 68 | + updatedMinNumberOfAllocations, |
| 69 | + updatedMaxNumberOfAllocations |
| 70 | + ); |
| 71 | + var updatedEndpoint = updateEndpoint(inferenceId, updateConfig, TaskType.SPARSE_EMBEDDING); |
| 72 | + verifyEndpointConfig( |
| 73 | + updatedEndpoint, |
| 74 | + updatedNumAllocations.orElse(1), |
| 75 | + updatedAdaptiveAllocationsEnabled, |
| 76 | + updatedMinNumberOfAllocations, |
| 77 | + updatedMaxNumberOfAllocations |
| 78 | + ); |
| 79 | + } |
| 80 | + |
| 81 | + private Map<String, Object> setupInferenceEndpoint(String inferenceId) throws IOException { |
| 82 | + String modelId = "custom-text-expansion-model"; |
| 83 | + createMlNodeTextExpansionModel(modelId, client()); |
| 84 | + |
| 85 | + var inferenceConfig = """ |
| 86 | + { |
| 87 | + "service": "elasticsearch", |
| 88 | + "service_settings": { |
| 89 | + "model_id": "custom-text-expansion-model", |
| 90 | + "num_allocations": 1, |
| 91 | + "num_threads": 1 |
| 92 | + } |
| 93 | + } |
| 94 | + """; |
| 95 | + |
| 96 | + return putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING); |
| 97 | + } |
| 98 | + |
| 99 | + public static String generateUpdateConfig( |
| 100 | + Optional<Integer> numAllocations, |
| 101 | + Optional<Boolean> adaptiveAllocationsEnabled, |
| 102 | + Optional<Integer> minNumberOfAllocations, |
| 103 | + Optional<Integer> maxNumberOfAllocations |
| 104 | + ) { |
| 105 | + StringBuilder requestBodyBuilder = new StringBuilder(); |
| 106 | + requestBodyBuilder.append("{ \"service_settings\": {"); |
| 107 | + |
| 108 | + numAllocations.ifPresent(value -> requestBodyBuilder.append("\"num_allocations\": ").append(value).append(",")); |
| 109 | + |
| 110 | + if (adaptiveAllocationsEnabled.isPresent() || minNumberOfAllocations.isPresent() || maxNumberOfAllocations.isPresent()) { |
| 111 | + requestBodyBuilder.append("\"adaptive_allocations\": {"); |
| 112 | + adaptiveAllocationsEnabled.ifPresent(value -> requestBodyBuilder.append("\"enabled\": ").append(value).append(",")); |
| 113 | + minNumberOfAllocations.ifPresent( |
| 114 | + value -> requestBodyBuilder.append("\"min_number_of_allocations\": ").append(value).append(",") |
| 115 | + ); |
| 116 | + maxNumberOfAllocations.ifPresent( |
| 117 | + value -> requestBodyBuilder.append("\"max_number_of_allocations\": ").append(value).append(",") |
| 118 | + ); |
| 119 | + |
| 120 | + if (requestBodyBuilder.charAt(requestBodyBuilder.length() - 1) == ',') { |
| 121 | + requestBodyBuilder.deleteCharAt(requestBodyBuilder.length() - 1); |
| 122 | + } |
| 123 | + requestBodyBuilder.append("},"); |
| 124 | + } |
| 125 | + |
| 126 | + if (requestBodyBuilder.charAt(requestBodyBuilder.length() - 1) == ',') { |
| 127 | + requestBodyBuilder.deleteCharAt(requestBodyBuilder.length() - 1); |
| 128 | + } |
| 129 | + |
| 130 | + requestBodyBuilder.append("} }"); |
| 131 | + return requestBodyBuilder.toString(); |
| 132 | + } |
| 133 | + |
| 134 | + @SuppressWarnings("unchecked") |
| 135 | + private void verifyEndpointConfig( |
| 136 | + Map<String, Object> endpointConfig, |
| 137 | + int expectedNumAllocations, |
| 138 | + Optional<Boolean> adaptiveAllocationsEnabled, |
| 139 | + Optional<Integer> minNumberOfAllocations, |
| 140 | + Optional<Integer> maxNumberOfAllocations |
| 141 | + ) { |
| 142 | + var serviceSettings = (Map<String, Object>) endpointConfig.get("service_settings"); |
| 143 | + |
| 144 | + assertEquals(expectedNumAllocations, serviceSettings.get("num_allocations")); |
| 145 | + if (adaptiveAllocationsEnabled.isPresent() || minNumberOfAllocations.isPresent() || maxNumberOfAllocations.isPresent()) { |
| 146 | + var adaptiveAllocations = (Map<String, Object>) serviceSettings.get("adaptive_allocations"); |
| 147 | + adaptiveAllocationsEnabled.ifPresent(enabled -> assertEquals(enabled, adaptiveAllocations.get("enabled"))); |
| 148 | + minNumberOfAllocations.ifPresent(min -> assertEquals(min, adaptiveAllocations.get("min_number_of_allocations"))); |
| 149 | + maxNumberOfAllocations.ifPresent(max -> assertEquals(max, adaptiveAllocations.get("max_number_of_allocations"))); |
| 150 | + } |
| 151 | + } |
| 152 | +} |
0 commit comments