Skip to content

Commit 4a39b4c

Browse files
jonathan-buttnerelasticsearchmachine
andauthored
[ML] Adding headers support for OpenAI chat completion (#134504)
* Adding headers support for openai chat completion and completion * Update docs/changelog/134504.yaml * [CI] Auto commit changes from spotless * Adjusting test name * Updating the changelog * Adding headers to the configuration * [CI] Auto commit changes from spotless * [CI] Update transport version definitions * Fixing transport versions --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 534cdb6 commit 4a39b4c

File tree

22 files changed

+329
-62
lines changed

22 files changed

+329
-62
lines changed

docs/changelog/134504.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 134504
2+
summary: Adding headers support for OpenAI chat completion
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ static TransportVersion def(int id) {
324324
public static final TransportVersion INDEX_SOURCE = def(9_158_0_00);
325325
public static final TransportVersion MAX_HEAP_SIZE_PER_NODE_IN_CLUSTER_INFO = def(9_159_0_00);
326326
public static final TransportVersion TIMESERIES_DEFAULT_LIMIT = def(9_160_0_00);
327+
public static final TransportVersion INFERENCE_API_OPENAI_HEADERS = def(9_161_0_00);
327328

328329
/*
329330
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,19 @@ public <V> Map<String, V> readImmutableMap(Writeable.Reader<V> valueReader) thro
841841
}
842842

843843
/**
844-
* Read a {@link Map} using the given key and value readers. The return Map is immutable.
844+
* Read an optional {@link Map} using the given key and value readers. The returned Map is immutable.
845+
*
846+
* @param keyReader Method to read a key. Must not return null.
847+
* @param valueReader Method to read a value. Must not return null.
848+
* @return The immutable map or null if not present
849+
*/
850+
public <K, V> Map<K, V> readOptionalImmutableMap(Writeable.Reader<K> keyReader, Writeable.Reader<V> valueReader) throws IOException {
851+
final boolean present = readBoolean();
852+
return present ? readImmutableMap(keyReader, valueReader) : null;
853+
}
854+
855+
/**
856+
* Read a {@link Map} using the given key and value readers. The returned Map is immutable.
845857
*
846858
* @param keyReader Method to read a key. Must not return null.
847859
* @param valueReader Method to read a value. Must not return null.

server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,26 @@ public final <K extends Writeable, V extends Writeable> void writeMap(final Map<
643643
writeMap(map, StreamOutput::writeWriteable, StreamOutput::writeWriteable);
644644
}
645645

646+
/**
647+
* Write an optional {@link Map} of {@code K}-type keys to {@code V}-type.
648+
* <pre><code>
649+
* Map&lt;String, String&gt; map = ...;
650+
* out.writeMap(map, StreamOutput::writeString, StreamOutput::writeString);
651+
* </code></pre>
652+
*
653+
* @param keyWriter The key writer
654+
* @param valueWriter The value writer
655+
*/
656+
public final <K, V> void writeOptionalMap(final Map<K, V> map, final Writer<K> keyWriter, final Writer<V> valueWriter)
657+
throws IOException {
658+
if (map == null) {
659+
writeBoolean(false);
660+
} else {
661+
writeBoolean(true);
662+
writeMap(map, keyWriter, valueWriter);
663+
}
664+
}
665+
646666
/**
647667
* Write a {@link Map} of {@code K}-type keys to {@code V}-type.
648668
* <pre><code>

server/src/main/java/org/elasticsearch/inference/configuration/SettingsConfigurationFieldType.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ public enum SettingsConfigurationFieldType {
1313
STRING("str"),
1414
INTEGER("int"),
1515
LIST("list"),
16-
BOOLEAN("bool");
16+
BOOLEAN("bool"),
17+
MAP("map");
1718

1819
private final String value;
1920

server/src/test/java/org/elasticsearch/common/io/stream/AbstractStreamTests.java

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.common.util.set.Sets;
2626
import org.elasticsearch.core.CheckedConsumer;
2727
import org.elasticsearch.core.CheckedFunction;
28+
import org.elasticsearch.core.Nullable;
2829
import org.elasticsearch.core.Strings;
2930
import org.elasticsearch.core.Tuple;
3031
import org.elasticsearch.test.ESTestCase;
@@ -674,22 +675,63 @@ public void testObjectArrayIsWriteable() throws IOException {
674675
assertNotWriteable(new Object[] { new Unwriteable() }, Unwriteable.class);
675676
}
676677

677-
public void assertImmutableMapSerialization(Map<String, Integer> expected) throws IOException {
678-
final BytesStreamOutput output = new BytesStreamOutput();
679-
output.writeMap(expected, StreamOutput::writeString, StreamOutput::writeVInt);
680-
final BytesReference bytesReference = output.bytes();
678+
public void testImmutableMapSerialization() throws IOException {
679+
CheckedBiConsumer<BytesStreamOutput, Map<String, Integer>, IOException> writer = (out, map) -> out.writeMap(
680+
map,
681+
StreamOutput::writeString,
682+
StreamOutput::writeVInt
683+
);
684+
CheckedFunction<StreamInput, Map<String, Integer>, IOException> reader = in -> in.readImmutableMap(
685+
StreamInput::readString,
686+
StreamInput::readVInt
687+
);
681688

682-
final StreamInput input = getStreamInput(bytesReference);
683-
Map<String, Integer> got = input.readImmutableMap(StreamInput::readString, StreamInput::readVInt);
689+
assertOptionalImmutableMapSerialization(Map.of(), writer, reader);
690+
assertOptionalImmutableMapSerialization(Map.of("a", 1), writer, reader);
691+
assertOptionalImmutableMapSerialization(Map.of("a", 1, "b", 2), writer, reader);
692+
}
693+
694+
public void testOptionalImmutableMapSerialization() throws IOException {
695+
CheckedBiConsumer<BytesStreamOutput, Map<String, Integer>, IOException> writer = (out, map) -> out.writeOptionalMap(
696+
map,
697+
StreamOutput::writeString,
698+
StreamOutput::writeVInt
699+
);
700+
CheckedFunction<StreamInput, Map<String, Integer>, IOException> reader = in -> in.readOptionalImmutableMap(
701+
StreamInput::readString,
702+
StreamInput::readVInt
703+
);
704+
705+
assertOptionalImmutableMapSerialization(null, writer, reader);
706+
assertOptionalImmutableMapSerialization(Map.of(), writer, reader);
707+
assertOptionalImmutableMapSerialization(Map.of("a", 1), writer, reader);
708+
assertOptionalImmutableMapSerialization(Map.of("a", 1, "b", 2), writer, reader);
709+
}
710+
711+
public void assertOptionalImmutableMapSerialization(
712+
@Nullable Map<String, Integer> expected,
713+
CheckedBiConsumer<BytesStreamOutput, Map<String, Integer>, IOException> writer,
714+
CheckedFunction<StreamInput, Map<String, Integer>, IOException> reader
715+
) throws IOException {
716+
var got = writeThenReadImmutableMap(expected, writer, reader);
684717
assertThat(got, equalTo(expected));
685718

686-
expectThrows(UnsupportedOperationException.class, () -> got.put("blah", 1));
719+
if (got != null) {
720+
expectThrows(UnsupportedOperationException.class, () -> got.put("blah", 1));
721+
}
687722
}
688723

689-
public void testImmutableMapSerialization() throws IOException {
690-
assertImmutableMapSerialization(Map.of());
691-
assertImmutableMapSerialization(Map.of("a", 1));
692-
assertImmutableMapSerialization(Map.of("a", 1, "b", 2));
724+
private <K, V, E extends IOException> Map<K, V> writeThenReadImmutableMap(
725+
@Nullable Map<K, V> expected,
726+
CheckedBiConsumer<BytesStreamOutput, Map<K, V>, E> writer,
727+
CheckedFunction<StreamInput, Map<K, V>, E> reader
728+
) throws IOException {
729+
final BytesStreamOutput output = new BytesStreamOutput();
730+
writer.accept(output, expected);
731+
final BytesReference bytesReference = output.bytes();
732+
733+
final StreamInput input = getStreamInput(bytesReference);
734+
return reader.apply(input);
693735
}
694736

695737
public <T> void assertImmutableListSerialization(List<T> expected, Writeable.Reader<T> reader, Writeable.Writer<T> writer)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,6 @@ public static Map<String, Object> extractRequiredMap(
532532
public static Map<String, Object> extractOptionalMap(
533533
Map<String, Object> map,
534534
String settingName,
535-
String scope,
536535
ValidationException validationException
537536
) {
538537
int initialValidationErrorCount = validationException.validationErrors().size();
@@ -545,6 +544,14 @@ public static Map<String, Object> extractOptionalMap(
545544
return optionalField;
546545
}
547546

547+
public static Map<String, Object> extractOptionalMapRemoveNulls(
548+
Map<String, Object> map,
549+
String settingName,
550+
ValidationException validationException
551+
) {
552+
return removeNullValues(extractOptionalMap(map, settingName, validationException));
553+
}
554+
548555
public static List<Tuple<String, String>> extractOptionalListOfStringTuples(
549556
Map<String, Object> map,
550557
String settingName,
@@ -626,6 +633,20 @@ private static void validateString(
626633
}
627634
}
628635

636+
public static Map<String, String> validateMapStringValues(
637+
Map<String, ?> map,
638+
String settingName,
639+
ValidationException validationException,
640+
boolean censorValue,
641+
@Nullable Map<String, String> defaultValue
642+
) {
643+
if (map == null) {
644+
return defaultValue;
645+
}
646+
647+
return validateMapStringValues(map, settingName, validationException, censorValue);
648+
}
649+
629650
/**
630651
* Validates that each value in the map is a {@link String} and returns a new map of {@code Map<String, String>}.
631652
*/

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323
import java.util.Objects;
2424

2525
import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertMapStringsToSecureString;
26-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap;
27-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues;
26+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls;
2827

2928
public class CustomSecretSettings implements SecretSettings {
3029
public static final String NAME = "custom_secret_settings";
@@ -37,8 +36,7 @@ public static CustomSecretSettings fromMap(@Nullable Map<String, Object> map) {
3736

3837
ValidationException validationException = new ValidationException();
3938

40-
Map<String, Object> requestSecretParamsMap = extractOptionalMap(map, SECRET_PARAMETERS, NAME, validationException);
41-
removeNullValues(requestSecretParamsMap);
39+
Map<String, Object> requestSecretParamsMap = extractOptionalMapRemoveNulls(map, SECRET_PARAMETERS, validationException);
4240
var secureStringMap = convertMapStringsToSecureString(requestSecretParamsMap, SECRET_PARAMETERS, validationException);
4341

4442
if (validationException.validationErrors().isEmpty() == false) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,12 @@
4242
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
4343
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
4444
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
45-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap;
45+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls;
4646
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
4747
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap;
4848
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
4949
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
5050
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
51-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues;
5251
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
5352
import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues;
5453

@@ -75,8 +74,7 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
7574

7675
var queryParams = QueryParameters.fromMap(map, validationException);
7776

78-
Map<String, Object> headers = extractOptionalMap(map, HEADERS, ModelConfigurations.SERVICE_SETTINGS, validationException);
79-
removeNullValues(headers);
77+
Map<String, Object> headers = extractOptionalMapRemoveNulls(map, HEADERS, validationException);
8078
var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false);
8179

8280
String requestContentString = extractRequiredString(map, REQUEST, ModelConfigurations.SERVICE_SETTINGS, validationException);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import org.elasticsearch.common.ValidationException;
1313
import org.elasticsearch.common.io.stream.StreamInput;
1414
import org.elasticsearch.common.io.stream.StreamOutput;
15-
import org.elasticsearch.inference.ModelConfigurations;
1615
import org.elasticsearch.inference.TaskSettings;
1716
import org.elasticsearch.xcontent.XContentBuilder;
1817

@@ -22,8 +21,7 @@
2221
import java.util.Map;
2322
import java.util.Objects;
2423

25-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap;
26-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues;
24+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls;
2725
import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapValues;
2826

2927
public class CustomTaskSettings implements TaskSettings {
@@ -39,8 +37,7 @@ public static CustomTaskSettings fromMap(Map<String, Object> map) {
3937
return EMPTY_SETTINGS;
4038
}
4139

42-
Map<String, Object> parameters = extractOptionalMap(map, PARAMETERS, ModelConfigurations.TASK_SETTINGS, validationException);
43-
removeNullValues(parameters);
40+
Map<String, Object> parameters = extractOptionalMapRemoveNulls(map, PARAMETERS, validationException);
4441
validateMapValues(
4542
parameters,
4643
List.of(String.class, Integer.class, Double.class, Float.class, Boolean.class),

0 commit comments

Comments
 (0)