Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.ToXContent;
Expand All @@ -29,6 +31,7 @@
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

public class DeepSeekChatCompletionRequest implements Request {
private static final Logger logger = LogManager.getLogger(DeepSeekChatCompletionRequest.class);
private static final String MODEL_FIELD = "model";
private static final String MAX_TOKENS = "max_tokens";

Expand All @@ -47,7 +50,11 @@ public HttpRequest createHttpRequest() {
httpPost.setEntity(createEntity());

httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
model.apiKey()
.ifPresentOrElse(
apiKey -> httpPost.setHeader(createAuthBearerHeader(apiKey)),
() -> logger.debug("No auth token present in request, sending without auth...")
);

return new HttpRequest(httpPost, getInferenceEntityId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
Expand All @@ -30,6 +31,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
Expand Down Expand Up @@ -63,6 +65,7 @@ public class DeepSeekChatCompletionModel extends Model {

private static final URI DEFAULT_URI = URI.create("https://api.deepseek.com/chat/completions");
private final DeepSeekServiceSettings serviceSettings;
@Nullable
private final DefaultSecretSettings secretSettings;

public static List<NamedWriteableRegistry.Entry> namedWriteables() {
Expand Down Expand Up @@ -126,7 +129,7 @@ public static DeepSeekChatCompletionModel readFromStorage(

private DeepSeekChatCompletionModel(
DeepSeekServiceSettings serviceSettings,
DefaultSecretSettings secretSettings,
@Nullable DefaultSecretSettings secretSettings,
ModelConfigurations configurations,
ModelSecrets secrets
) {
Expand All @@ -135,8 +138,8 @@ private DeepSeekChatCompletionModel(
this.secretSettings = secretSettings;
}

public SecureString apiKey() {
return secretSettings.apiKey();
public Optional<SecureString> apiKey() {
return Optional.ofNullable(secretSettings).map(DefaultSecretSettings::apiKey);
}

public String model() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ public Model parsePersistedConfigWithSecrets(

/**
 * Builds a model from a persisted configuration that carries no secrets.
 *
 * <p>Only the {@code service_settings} entry is taken out of {@code config}; the secrets
 * argument handed to {@code readFromStorage} is {@code null}.
 *
 * @param modelId  the inference entity id
 * @param taskType the task type the model was stored with
 * @param config   the persisted configuration map; its {@code service_settings} entry is
 *                 extracted via {@code removeFromMapOrThrowIfNull}
 * @return the model produced by {@code DeepSeekChatCompletionModel.readFromStorage}
 */
@Override
public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
    var serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
    // No secrets exist on this path, so pass null rather than re-using config as the secrets map.
    return DeepSeekChatCompletionModel.readFromStorage(modelId, taskType, NAME, serviceSettingsMap, null);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.ExceptionsHelper.unwrapCause;
Expand Down Expand Up @@ -90,7 +91,7 @@ public void testParseRequestConfig() throws IOException, URISyntaxException {
}
""", webServer.getUri(null).toString()), assertNoFailureListener(model -> {
if (model instanceof DeepSeekChatCompletionModel deepSeekModel) {
assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray()));
assertThat(deepSeekModel.apiKey().get().getChars(), equalTo("12345".toCharArray()));
assertThat(deepSeekModel.model(), equalTo("some-cool-model"));
assertThat(deepSeekModel.uri(), equalTo(webServer.getUri(null)));
} else {
Expand Down Expand Up @@ -158,13 +159,10 @@ public void testParsePersistedConfig() throws IOException {
{
"service_settings": {
"model_id": "some-cool-model"
},
"secret_settings": {
"api_key": "12345"
}
}
""");
assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray()));
assertThat(deepSeekModel.apiKey(), equalTo(Optional.empty()));
assertThat(deepSeekModel.model(), equalTo("some-cool-model"));
}

Expand All @@ -174,33 +172,14 @@ public void testParsePersistedConfigWithUrl() throws IOException {
"service_settings": {
"model_id": "some-cool-model",
"url": "http://localhost:989"
},
"secret_settings": {
"api_key": "12345"
}
}
""");
assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray()));
assertThat(deepSeekModel.apiKey(), equalTo(Optional.empty()));
assertThat(deepSeekModel.model(), equalTo("some-cool-model"));
assertThat(deepSeekModel.uri(), equalTo(URI.create("http://localhost:989")));
}

/**
 * Parsing a persisted config whose {@code secret_settings} object is empty must throw a
 * {@link ValidationException}.
 */
public void testParsePersistedConfigWithoutApiKey() {
    final String assertionMessage = "Validation Failed: 1: [secret_settings] does not contain the required setting [api_key];";
    // Every JSON line shares the closing delimiter's indent, so the text block content is unindented.
    final String configWithEmptySecrets = """
        {
        "service_settings": {
        "model_id": "some-cool-model"
        },
        "secret_settings": {
        }
        }
        """;
    assertThrows(assertionMessage, ValidationException.class, () -> parsePersistedConfig(configWithEmptySecrets));
}

public void testParsePersistedConfigWithoutModel() {
assertThrows(
"Validation Failed: 1: [service_settings] does not contain the required setting [model];",
Expand Down Expand Up @@ -424,17 +403,20 @@ private InferenceEventsAssertion doUnifiedCompletionInfer() throws Exception {
}

private DeepSeekChatCompletionModel createModel(DeepSeekService service, TaskType taskType) throws URISyntaxException, IOException {
var model = service.parsePersistedConfig("inference-id", taskType, map(Strings.format("""
var model = service.parsePersistedConfigWithSecrets("inference-id", taskType, map(Strings.format("""
{
"service_settings": {
"model_id": "some-cool-model",
"url": "%s"
},
}
}
""", webServer.getUri(null).toString())), map("""
{
"secret_settings": {
"api_key": "12345"
}
}
""", webServer.getUri(null).toString())));
"""));
assertThat(model, isA(DeepSeekChatCompletionModel.class));
return (DeepSeekChatCompletionModel) model;
}
Expand Down