Skip to content

Commit 8e8f3e7

Browse files
committed
Fix race in S3HttpHandler overwrite protection
1 parent 3d54119 commit 8e8f3e7

File tree

2 files changed

+116
-37
lines changed

2 files changed

+116
-37
lines changed

test/fixtures/s3-fixture/src/main/java/fixture/s3/S3HttpHandler.java

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,16 @@ public void handle(final HttpExchange exchange) throws IOException {
208208
} else {
209209
final var blobContents = upload.complete(extractPartEtags(Streams.readFully(exchange.getRequestBody())));
210210

211-
if (isProtectOverwrite(exchange) && blobs.containsKey(request.path())) {
212-
preconditionFailed = true;
213-
responseBody = null;
211+
if (isProtectOverwrite(exchange)) {
212+
var previousValue = blobs.putIfAbsent(request.path(), blobContents);
213+
if (previousValue != null) {
214+
preconditionFailed = true;
215+
}
214216
} else {
215217
blobs.put(request.path(), blobContents);
218+
}
219+
220+
if (preconditionFailed == false) {
216221
responseBody = ("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n"
217222
+ "<CompleteMultipartUploadResult>\n"
218223
+ "<Bucket>"
@@ -222,6 +227,8 @@ public void handle(final HttpExchange exchange) throws IOException {
222227
+ request.path()
223228
+ "</Key>\n"
224229
+ "</CompleteMultipartUploadResult>").getBytes(StandardCharsets.UTF_8);
230+
} else {
231+
responseBody = null;
225232
}
226233
}
227234
}
@@ -241,27 +248,50 @@ public void handle(final HttpExchange exchange) throws IOException {
241248
} else if (request.isPutObjectRequest()) {
242249
// a copy request is a put request with an X-amz-copy-source header
243250
final var copySource = copySourceName(exchange);
244-
if (isProtectOverwrite(exchange) && blobs.containsKey(request.path())) {
245-
exchange.sendResponseHeaders(RestStatus.PRECONDITION_FAILED.getStatus(), -1);
246-
} else if (copySource != null) {
251+
if (copySource != null) {
247252
var sourceBlob = blobs.get(copySource);
248253
if (sourceBlob == null) {
249254
exchange.sendResponseHeaders(RestStatus.NOT_FOUND.getStatus(), -1);
250255
} else {
251-
blobs.put(request.path(), sourceBlob);
252-
253-
byte[] response = ("""
254-
<?xml version="1.0" encoding="UTF-8"?>
255-
<CopyObjectResult></CopyObjectResult>""").getBytes(StandardCharsets.UTF_8);
256-
exchange.getResponseHeaders().add("Content-Type", "application/xml");
257-
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), response.length);
258-
exchange.getResponseBody().write(response);
256+
boolean preconditionFailed = false;
257+
if (isProtectOverwrite(exchange)) {
258+
var previousValue = blobs.putIfAbsent(request.path(), sourceBlob);
259+
if (previousValue != null) {
260+
preconditionFailed = true;
261+
}
262+
} else {
263+
blobs.put(request.path(), sourceBlob);
264+
}
265+
266+
if (preconditionFailed) {
267+
exchange.sendResponseHeaders(RestStatus.PRECONDITION_FAILED.getStatus(), -1);
268+
} else {
269+
byte[] response = ("""
270+
<?xml version="1.0" encoding="UTF-8"?>
271+
<CopyObjectResult></CopyObjectResult>""").getBytes(StandardCharsets.UTF_8);
272+
exchange.getResponseHeaders().add("Content-Type", "application/xml");
273+
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), response.length);
274+
exchange.getResponseBody().write(response);
275+
}
259276
}
260277
} else {
261278
final Tuple<String, BytesReference> blob = parseRequestBody(exchange);
262-
blobs.put(request.path(), blob.v2());
263-
exchange.getResponseHeaders().add("ETag", blob.v1());
264-
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), -1);
279+
boolean preconditionFailed = false;
280+
if (isProtectOverwrite(exchange)) {
281+
var previousValue = blobs.putIfAbsent(request.path(), blob.v2());
282+
if (previousValue != null) {
283+
preconditionFailed = true;
284+
}
285+
} else {
286+
blobs.put(request.path(), blob.v2());
287+
}
288+
289+
if (preconditionFailed) {
290+
exchange.sendResponseHeaders(RestStatus.PRECONDITION_FAILED.getStatus(), -1);
291+
} else {
292+
exchange.getResponseHeaders().add("ETag", blob.v1());
293+
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), -1);
294+
}
265295
}
266296

267297
} else if (request.isListObjectsRequest()) {

test/fixtures/s3-fixture/src/test/java/fixture/s3/S3HttpHandlerTests.java

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,14 @@
3232
import java.util.ArrayList;
3333
import java.util.List;
3434
import java.util.Objects;
35+
import java.util.concurrent.Executors;
36+
import java.util.concurrent.TimeUnit;
37+
import java.util.function.Consumer;
3538

3639
import static org.hamcrest.Matchers.allOf;
3740
import static org.hamcrest.Matchers.containsString;
3841
import static org.hamcrest.Matchers.greaterThan;
42+
import static org.hamcrest.Matchers.hasSize;
3943

4044
public class S3HttpHandlerTests extends ESTestCase {
4145

@@ -383,35 +387,80 @@ public void testExtractPartEtags() {
383387

384388
}
385389

386-
public void testPreventObjectOverwrite() {
390+
public void testPreventObjectOverwrite() throws InterruptedException {
387391
final var handler = new S3HttpHandler("bucket", "path");
388392

389-
final var body = randomBytesReference(50);
390-
assertEquals(RestStatus.OK, handleRequest(handler, "PUT", "/bucket/path/blob", body, ifNoneMatchHeader()).status());
391-
assertEquals(
392-
RestStatus.PRECONDITION_FAILED,
393-
handleRequest(handler, "PUT", "/bucket/path/blob", body, ifNoneMatchHeader()).status()
394-
);
395-
396-
// multipart upload
397-
final var createUploadResponse = handleRequest(handler, "POST", "/bucket/path/blob?uploads");
398-
final var uploadId = getUploadId(createUploadResponse.body());
399-
400-
final var part1 = randomAlphaOfLength(50);
401-
final var uploadPart1Response = handleRequest(handler, "PUT", "/bucket/path/blob?uploadId=" + uploadId + "&partNumber=1", part1);
402-
final var part1Etag = Objects.requireNonNull(uploadPart1Response.etag());
403-
404-
assertEquals(
405-
RestStatus.PRECONDITION_FAILED,
406-
handleRequest(handler, "POST", "/bucket/path/blob?uploadId=" + uploadId, new BytesArray(Strings.format("""
393+
Consumer<TestWriteTask> putObjectConsumer = (task) -> task.status = handleRequest(
394+
handler,
395+
"PUT",
396+
"/bucket/path/blob",
397+
task.body,
398+
ifNoneMatchHeader()
399+
).status();
400+
401+
Consumer<TestWriteTask> prepareMultipartUploadConsumer = (task) -> {
402+
final var createUploadResponse = handleRequest(handler, "POST", "/bucket/path/blob?uploads");
403+
task.uploadId = getUploadId(createUploadResponse.body());
404+
405+
final var uploadPart1Response = handleRequest(
406+
handler,
407+
"PUT",
408+
"/bucket/path/blob?uploadId=" + task.uploadId + "&partNumber=1",
409+
task.body
410+
);
411+
task.etag = Objects.requireNonNull(uploadPart1Response.etag());
412+
};
413+
414+
Consumer<TestWriteTask> completeMultipartUploadConsumer = (task) -> {
415+
task.status = handleRequest(handler, "POST", "/bucket/path/blob?uploadId=" + task.uploadId, new BytesArray(Strings.format("""
407416
<?xml version="1.0" encoding="UTF-8"?>
408417
<CompleteMultipartUpload xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
409418
<Part>
410419
<ETag>%s</ETag>
411420
<PartNumber>1</PartNumber>
412421
</Part>
413-
</CompleteMultipartUpload>""", part1Etag)), ifNoneMatchHeader()).status()
422+
</CompleteMultipartUpload>""", task.etag)), ifNoneMatchHeader()).status();
423+
};
424+
425+
var tasks = List.of(
426+
new TestWriteTask(putObjectConsumer),
427+
new TestWriteTask(putObjectConsumer),
428+
new TestWriteTask(completeMultipartUploadConsumer, prepareMultipartUploadConsumer),
429+
new TestWriteTask(completeMultipartUploadConsumer, prepareMultipartUploadConsumer)
414430
);
431+
432+
try (var executor = Executors.newVirtualThreadPerTaskExecutor()) {
433+
tasks.forEach(task -> executor.submit(task.consumer));
434+
executor.shutdown();
435+
var done = executor.awaitTermination(1, TimeUnit.SECONDS);
436+
assertTrue(done);
437+
}
438+
439+
List<TestWriteTask> successfulTasks = tasks.stream().filter(task -> task.status == RestStatus.OK).toList();
440+
assertThat(successfulTasks, hasSize(1));
441+
442+
assertEquals(
443+
new TestHttpResponse(RestStatus.OK, successfulTasks.getFirst().body, TestHttpExchange.EMPTY_HEADERS),
444+
handleRequest(handler, "GET", "/bucket/path/blob")
445+
);
446+
}
447+
448+
private static class TestWriteTask {
449+
final BytesReference body;
450+
final Runnable consumer;
451+
String uploadId;
452+
String etag;
453+
RestStatus status;
454+
455+
TestWriteTask(Consumer<TestWriteTask> consumer, Consumer<TestWriteTask> prepare) {
456+
this(consumer);
457+
prepare.accept(this);
458+
}
459+
460+
TestWriteTask(Consumer<TestWriteTask> consumer) {
461+
this.body = randomBytesReference(50);
462+
this.consumer = () -> consumer.accept(this);
463+
}
415464
}
416465

417466
private void runExtractPartETagsTest(String body, String... expectedTags) {

0 commit comments

Comments
 (0)