Skip to content

Commit 0b1cc36

Browse files
Merge branch 'main' into ia-fix-eis-endpoints
2 parents b23024f + d5a05ec commit 0b1cc36

File tree

15 files changed

+487
-268
lines changed

15 files changed

+487
-268
lines changed

docs/changelog/137598.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 137598
2+
summary: Improve SAML error handling by adding metadata
3+
area: Authentication
4+
type: enhancement
5+
issues:
6+
- 128179

docs/changelog/138876.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 138876
2+
summary: "Use doc values skipper for @timestamp in synthetic `_id` postings #138568"
3+
area: TSDB
4+
type: enhancement
5+
issues: []

server/src/main/java/module-info.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,10 +396,10 @@
396396
to
397397
org.elasticsearch.inference,
398398
org.elasticsearch.metering,
399-
org.elasticsearch.stateless,
400399
org.elasticsearch.settings.secure,
401400
org.elasticsearch.serverless.constants,
402401
org.elasticsearch.serverless.apifiltering,
402+
org.elasticsearch.serverless.stateless,
403403
org.elasticsearch.internal.security,
404404
org.elasticsearch.xpack.gpu;
405405

server/src/main/java/org/elasticsearch/index/codec/tsdb/TSDBSyntheticIdFieldsProducer.java

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,25 @@ private int findFirstDocWithTsIdOrdinalEqualTo(int tsIdOrd) throws IOException {
373373
return DocIdSetIterator.NO_MORE_DOCS;
374374
}
375375

376+
/**
377+
* Skip as many documents as possible after a given document ID to find the first document ID matching the timestamp.
378+
*
379+
* @param timestamp the timestamp to match
380+
* @param minDocID the min. document ID
381+
* @return a docID to start scanning documents from in order to find the first document ID matching the provided timestamp
382+
* @throws IOException if any I/O exception occurs
383+
*/
384+
private int skipDocIDForTimestamp(long timestamp, int minDocID) throws IOException {
385+
var skipper = docValuesProducer.getSkipper(timestampFieldInfo);
386+
assert skipper != null;
387+
if (skipper.minValue() > timestamp || timestamp > skipper.maxValue()) {
388+
return DocIdSetIterator.NO_MORE_DOCS;
389+
}
390+
skipper.advance(minDocID);
391+
skipper.advance(timestamp, Long.MAX_VALUE);
392+
return Math.max(minDocID, skipper.minDocID(0));
393+
}
394+
376395
private int getTsIdValueCount() throws IOException {
377396
if (tsIdDocValues == null) {
378397
tsIdDocValues = docValuesProducer.getSorted(tsIdFieldInfo);
@@ -483,32 +502,37 @@ public SeekStatus seekCeil(BytesRef id) throws IOException {
483502
return SeekStatus.END;
484503
}
485504

486-
// _tsid found, extract the timestamp
487-
final long timestamp = TsidExtractingIdFieldMapper.extractTimestampFromSyntheticId(id);
488-
489-
// Find the first document matching the _tsid
490-
final int startDocID = docValues.findFirstDocWithTsIdOrdinalEqualTo(tsIdOrd);
505+
// Find the first document ID matching the _tsid
506+
int startDocID = docValues.findFirstDocWithTsIdOrdinalEqualTo(tsIdOrd);
491507
assert startDocID >= 0 : startDocID;
492508

493-
int docID = startDocID;
494-
int docTsIdOrd = tsIdOrd;
495-
long docTimestamp;
496-
497-
// Iterate over documents to find the first one matching the timestamp
498-
for (; docID < maxDocs; docID++) {
499-
docTimestamp = docValues.docTimestamp(docID);
500-
if (startDocID < docID) {
501-
// After the first doc, we need to check again if _tsid matches
502-
docTsIdOrd = docValues.docTsIdOrdinal(docID);
503-
}
504-
if (docTsIdOrd == tsIdOrd && docTimestamp == timestamp) {
505-
// It's a match!
506-
current = new SyntheticTerm(docID, tsIdOrd, tsId, docTimestamp, docValues.docRoutingHash(docID));
507-
return SeekStatus.FOUND;
508-
}
509-
// Remaining docs don't match, stop here
510-
if (tsIdOrd < docTsIdOrd || docTimestamp < timestamp) {
511-
break;
509+
if (startDocID != DocIdSetIterator.NO_MORE_DOCS) {
510+
// _tsid found, extract the timestamp
511+
final long timestamp = TsidExtractingIdFieldMapper.extractTimestampFromSyntheticId(id);
512+
513+
startDocID = docValues.skipDocIDForTimestamp(timestamp, startDocID);
514+
if (startDocID != DocIdSetIterator.NO_MORE_DOCS) {
515+
int docID = startDocID;
516+
int docTsIdOrd = tsIdOrd;
517+
long docTimestamp;
518+
519+
// Iterate over documents to find the first one matching the timestamp
520+
for (; docID < maxDocs; docID++) {
521+
docTimestamp = docValues.docTimestamp(docID);
522+
if (startDocID < docID) {
523+
// After the first doc, we need to check again if _tsid matches
524+
docTsIdOrd = docValues.docTsIdOrdinal(docID);
525+
}
526+
if (docTsIdOrd == tsIdOrd && docTimestamp == timestamp) {
527+
// It's a match!
528+
current = new SyntheticTerm(docID, tsIdOrd, tsId, docTimestamp, docValues.docRoutingHash(docID));
529+
return SeekStatus.FOUND;
530+
}
531+
// Remaining docs don't match, stop here
532+
if (tsIdOrd < docTsIdOrd || docTimestamp < timestamp) {
533+
break;
534+
}
535+
}
512536
}
513537
}
514538
current = NO_MORE_DOCS;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import java.io.IOException;
2929
import java.util.Objects;
3030
import java.util.concurrent.CancellationException;
31-
import java.util.concurrent.atomic.AtomicReference;
3231

3332
import static org.elasticsearch.core.Strings.format;
3433
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_RESPONSE_THREAD_POOL_NAME;
@@ -39,14 +38,7 @@
3938
public class HttpClient implements Closeable {
4039
private static final Logger logger = LogManager.getLogger(HttpClient.class);
4140

42-
enum Status {
43-
CREATED,
44-
STARTED,
45-
STOPPED
46-
}
47-
4841
private final CloseableHttpAsyncClient client;
49-
private final AtomicReference<Status> status = new AtomicReference<>(Status.CREATED);
5042
private final ThreadPool threadPool;
5143
private final HttpSettings settings;
5244
private final ThrottlerManager throttlerManager;
@@ -127,15 +119,10 @@ private static CloseableHttpAsyncClient createAsyncClient(
127119
}
128120

129121
public void start() {
130-
if (status.compareAndSet(Status.CREATED, Status.STARTED)) {
131-
client.start();
132-
}
122+
client.start();
133123
}
134124

135125
public void send(HttpRequest request, HttpClientContext context, ActionListener<HttpResult> listener) throws IOException {
136-
// The caller must call start() first before attempting to send a request
137-
assert status.get() == Status.STARTED : "call start() before attempting to send a request";
138-
139126
SocketAccess.doPrivileged(() -> client.execute(request.httpRequestBase(), context, new FutureCallback<>() {
140127
@Override
141128
public void completed(HttpResponse response) {
@@ -145,7 +132,7 @@ public void completed(HttpResponse response) {
145132
@Override
146133
public void failed(Exception ex) {
147134
throttlerManager.warn(logger, format("Request from inference entity id [%s] failed", request.inferenceEntityId()), ex);
148-
failUsingResponseThread(ex, listener);
135+
failUsingResponseThread(getException(ex), listener);
149136
}
150137

151138
@Override
@@ -179,10 +166,22 @@ private void failUsingResponseThread(Exception exception, ActionListener<?> list
179166
threadPool.executor(INFERENCE_RESPONSE_THREAD_POOL_NAME).execute(() -> listener.onFailure(exception));
180167
}
181168

182-
public void stream(HttpRequest request, HttpContext context, ActionListener<StreamingHttpResult> listener) throws IOException {
183-
// The caller must call start() first before attempting to send a request
184-
assert status.get() == Status.STARTED : "call start() before attempting to send a request";
169+
private static Exception getException(Exception e) {
170+
if (e instanceof CancellationException cancellationException) {
171+
return createNotRunningException(cancellationException);
172+
}
185173

174+
return e;
175+
}
176+
177+
private static IllegalStateException createNotRunningException(Exception exception) {
178+
// If the http client isn't running, it is either not started yet, in which case we have a bug somewhere because
179+
// it should always be started as part of the inference plugin startup, or it is stopped meaning the node is shutting down.
180+
// If we're shutting down, the user should retry the request, and hopefully it'll hit a node that isn't shutting down.
181+
return new IllegalStateException("Http client is not running, please retry the request", exception);
182+
}
183+
184+
public void stream(HttpRequest request, HttpContext context, ActionListener<StreamingHttpResult> listener) throws IOException {
186185
var streamingProcessor = new StreamingHttpResultPublisher(threadPool, settings, listener);
187186

188187
SocketAccess.doPrivileged(() -> client.execute(request.requestProducer(), streamingProcessor, context, new FutureCallback<>() {
@@ -193,7 +192,7 @@ public void completed(Void response) {
193192

194193
@Override
195194
public void failed(Exception ex) {
196-
threadPool.executor(INFERENCE_RESPONSE_THREAD_POOL_NAME).execute(() -> streamingProcessor.failed(ex));
195+
threadPool.executor(INFERENCE_RESPONSE_THREAD_POOL_NAME).execute(() -> streamingProcessor.failed(getException(ex)));
197196
}
198197

199198
@Override
@@ -212,7 +211,6 @@ public void cancelled() {
212211

213212
@Override
214213
public void close() throws IOException {
215-
status.set(Status.STOPPED);
216214
client.close();
217215
}
218216
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientTests.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.http.nio.reactor.IOReactorException;
2323
import org.elasticsearch.ElasticsearchException;
2424
import org.elasticsearch.action.support.PlainActionFuture;
25+
import org.elasticsearch.action.support.TestPlainActionFuture;
2526
import org.elasticsearch.common.Strings;
2627
import org.elasticsearch.common.settings.Settings;
2728
import org.elasticsearch.common.unit.ByteSizeValue;
@@ -33,7 +34,6 @@
3334
import org.elasticsearch.threadpool.ThreadPool;
3435
import org.elasticsearch.xcontent.XContentType;
3536
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
36-
import org.elasticsearch.xpack.inference.external.request.HttpRequestTests;
3737
import org.junit.After;
3838
import org.junit.Before;
3939

@@ -47,6 +47,7 @@
4747
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
4848
import static org.elasticsearch.xpack.inference.Utils.mockClusterService;
4949
import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager;
50+
import static org.hamcrest.Matchers.containsString;
5051
import static org.hamcrest.Matchers.equalTo;
5152
import static org.hamcrest.Matchers.hasSize;
5253
import static org.hamcrest.Matchers.is;
@@ -102,13 +103,12 @@ public void testSend_MockServerReceivesRequest() throws Exception {
102103

103104
public void testSend_ThrowsErrorIfCalledBeforeStart() throws Exception {
104105
try (var httpClient = HttpClient.create(emptyHttpSettings(), threadPool, createConnectionManager(), mockThrottlerManager())) {
105-
PlainActionFuture<HttpResult> listener = new PlainActionFuture<>();
106-
var thrownException = expectThrows(
107-
AssertionError.class,
108-
() -> httpClient.send(HttpRequestTests.createMock("inferenceEntityId"), HttpClientContext.create(), listener)
109-
);
106+
var listener = new TestPlainActionFuture<HttpResult>();
107+
var httpPost = createHttpPost(webServer.getPort(), "key", "value");
108+
httpClient.send(httpPost, HttpClientContext.create(), listener);
109+
var thrownException = expectThrows(IllegalStateException.class, () -> listener.actionGet(TimeValue.THIRTY_SECONDS));
110110

111-
assertThat(thrownException.getMessage(), is("call start() before attempting to send a request"));
111+
assertThat(thrownException.getMessage(), containsString("Http client is not running, please retry the request"));
112112
}
113113
}
114114

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.security.authc.saml;
9+
10+
import org.apache.http.util.EntityUtils;
11+
import org.elasticsearch.client.Request;
12+
import org.elasticsearch.client.ResponseException;
13+
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.test.rest.ObjectPath;
15+
import org.elasticsearch.xcontent.json.JsonXContent;
16+
17+
import java.net.URL;
18+
import java.nio.charset.StandardCharsets;
19+
import java.util.Base64;
20+
import java.util.HashMap;
21+
import java.util.Map;
22+
23+
import static org.hamcrest.Matchers.containsString;
24+
import static org.hamcrest.Matchers.equalTo;
25+
import static org.hamcrest.Matchers.hasEntry;
26+
import static org.hamcrest.Matchers.hasKey;
27+
import static org.hamcrest.Matchers.is;
28+
29+
public class SamlInResponseToIT extends SamlRestTestCase {
30+
31+
private static final int REALM_NUMBER = 1;
32+
33+
public void testInResponseTo_matchingValues() throws Exception {
34+
final String username = randomAlphaOfLengthBetween(4, 12);
35+
String requestId = generateRandomRequestId();
36+
var response = authUser(username, requestId, requestId);
37+
assertThat(response, hasKey("access_token"));
38+
}
39+
40+
public void testInResponseTo_requestAndTokenHaveDifferentValues() throws Exception {
41+
final String username = randomAlphaOfLengthBetween(4, 12);
42+
String requestIdFromRequest = generateRandomRequestId();
43+
String requestIdFromToken = generateDifferentRandomRequestId(requestIdFromRequest);
44+
45+
var exception = expectThrows(ResponseException.class, () -> authUser(username, requestIdFromRequest, requestIdFromToken));
46+
assertThat(exception.getResponse().getStatusLine().getStatusCode(), is(401));
47+
String errorEntity = EntityUtils.toString(exception.getResponse().getEntity());
48+
assertThat(errorEntity, containsString("\"security.saml.unsolicited_in_response_to\":\"" + requestIdFromToken + "\""));
49+
}
50+
51+
public void testInResponseTo_requestNullTokenNotNull() throws Exception {
52+
final String username = randomAlphaOfLengthBetween(4, 12);
53+
String requestIdFromToken = generateRandomRequestId();
54+
55+
var exception = expectThrows(ResponseException.class, () -> authUser(username, null, requestIdFromToken));
56+
assertThat(exception.getResponse().getStatusLine().getStatusCode(), is(401));
57+
String errorEntity = EntityUtils.toString(exception.getResponse().getEntity());
58+
assertThat(errorEntity, containsString("\"security.saml.unsolicited_in_response_to\":\"" + requestIdFromToken + "\""));
59+
}
60+
61+
public void testInResponseTo_requestNotNullTokenNull() throws Exception {
62+
final String username = randomAlphaOfLengthBetween(4, 12);
63+
String requestIdFromRequest = generateRandomRequestId();
64+
65+
var response = authUser(username, requestIdFromRequest, null);
66+
assertThat(response, hasKey("access_token"));
67+
}
68+
69+
public void testInResponseTo_requestNullTokenNull() throws Exception {
70+
final String username = randomAlphaOfLengthBetween(4, 12);
71+
var response = authUser(username, null, null);
72+
assertThat(response, hasKey("access_token"));
73+
}
74+
75+
private String generateRandomRequestId() {
76+
return randomAlphaOfLength(1) + randomAlphanumericOfLength(random().nextInt(10));
77+
}
78+
79+
private String generateDifferentRandomRequestId(String existingId) {
80+
String newId;
81+
do {
82+
newId = generateRandomRequestId();
83+
} while (newId.equals(existingId));
84+
return newId;
85+
}
86+
87+
private Map<String, Object> authUser(String username, String inResponseToFromHeader, String inResponseToInSamlToken) throws Exception {
88+
89+
var httpsAddress = getAcsHttpsAddress();
90+
var message = new SamlResponseBuilder().spEntityId("https://sp" + REALM_NUMBER + ".example.org/")
91+
.idpEntityId(getIdpEntityId(REALM_NUMBER))
92+
.acs(new URL("https://" + httpsAddress.getHostName() + ":" + httpsAddress.getPort() + "/acs/" + REALM_NUMBER))
93+
.attribute("urn:oid:2.5.4.3", username)
94+
.sign(getDataPath(SAML_SIGNING_CRT), getDataPath(SAML_SIGNING_KEY), new char[0])
95+
.inResponseTo(inResponseToInSamlToken)
96+
.asString();
97+
98+
final Map<String, Object> body = new HashMap<>();
99+
body.put("content", Base64.getEncoder().encodeToString(message.getBytes(StandardCharsets.UTF_8)));
100+
body.put("realm", getSamlRealmName(REALM_NUMBER));
101+
if (inResponseToFromHeader != null) {
102+
body.put("ids", inResponseToFromHeader);
103+
}
104+
105+
var req = new Request("POST", "_security/saml/authenticate");
106+
req.setJsonEntity(Strings.toString(JsonXContent.contentBuilder().map(body)));
107+
var resp = entityAsMap(client().performRequest(req));
108+
assertThat(resp, hasEntry("username", username));
109+
assertThat(ObjectPath.evaluate(resp, "authentication.authentication_realm.name"), equalTo(getSamlRealmName(REALM_NUMBER)));
110+
return resp;
111+
}
112+
}

x-pack/plugin/security/qa/saml-rest-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/saml/SamlResponseBuilder.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,11 @@ public SamlResponseBuilder attribute(String name, String value) {
122122
return this;
123123
}
124124

125+
public SamlResponseBuilder inResponseTo(String requestId) {
126+
this.inResponseTo = requestId;
127+
return this;
128+
}
129+
125130
public SamlResponseBuilder sign(Path certPath, Path keyPath, char[] keyPassword) throws GeneralSecurityException, IOException {
126131
var privateKey = PemUtils.readPrivateKey(keyPath, () -> keyPassword);
127132
var certificates = PemUtils.readCertificates(List.of(certPath));

0 commit comments

Comments
 (0)