diff --git a/sdk/storage/azure-storage-blob/assets.json b/sdk/storage/azure-storage-blob/assets.json index 29bcfdae5bf9..bf8353895683 100644 --- a/sdk/storage/azure-storage-blob/assets.json +++ b/sdk/storage/azure-storage-blob/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "java", "TagPrefix": "java/storage/azure-storage-blob", - "Tag": "java/storage/azure-storage-blob_80c07fe827" + "Tag": "java/storage/azure-storage-blob_c976afa88e" } diff --git a/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/implementation/util/BuilderHelper.java b/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/implementation/util/BuilderHelper.java index 7c56941c7014..b203e3c123de 100644 --- a/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/implementation/util/BuilderHelper.java +++ b/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/implementation/util/BuilderHelper.java @@ -39,6 +39,7 @@ import com.azure.storage.common.policy.ResponseValidationPolicyBuilder; import com.azure.storage.common.policy.ScrubEtagPolicy; import com.azure.storage.common.policy.StorageBearerTokenChallengeAuthorizationPolicy; +import com.azure.storage.common.policy.StorageContentValidationDecoderPolicy; import com.azure.storage.common.policy.StorageSharedKeyCredentialPolicy; import java.net.MalformedURLException; @@ -140,6 +141,8 @@ public static HttpPipeline buildPipeline(StorageSharedKeyCredential storageShare HttpPolicyProviders.addAfterRetryPolicies(policies); + policies.add(new StorageContentValidationDecoderPolicy()); + policies.add(getResponseValidationPolicy()); policies.add(new HttpLoggingPolicy(logOptions)); diff --git a/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/specialized/BlobAsyncClientBase.java b/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/specialized/BlobAsyncClientBase.java index 812fabc80214..e093fd85cf8b 100644 --- a/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/specialized/BlobAsyncClientBase.java +++ b/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/specialized/BlobAsyncClientBase.java @@ -79,8 +79,10 @@ import com.azure.storage.blob.options.BlobSetAccessTierOptions; import com.azure.storage.blob.options.BlobSetTagsOptions; import com.azure.storage.blob.sas.BlobServiceSasSignatureValues; +import com.azure.storage.common.DownloadContentValidationOptions; import com.azure.storage.common.StorageSharedKeyCredential; import com.azure.storage.common.Utility; +import com.azure.storage.common.implementation.Constants; import com.azure.storage.common.implementation.SasImplUtils; import com.azure.storage.common.implementation.StorageImplUtils; import reactor.core.publisher.Flux; @@ -1173,6 +1175,52 @@ public Mono downloadStreamWithResponse(BlobRange rang } } + /** + * Reads a range of bytes from a blob with content validation options. Uploading data must be done from the {@link BlockBlobClient}, {@link + * PageBlobClient}, or {@link AppendBlobClient}. + * + *

Code Samples

+ * + *
{@code
+     * BlobRange range = new BlobRange(1024, 2048L);
+     * DownloadRetryOptions options = new DownloadRetryOptions().setMaxRetryRequests(5);
+     * DownloadContentValidationOptions validationOptions = new DownloadContentValidationOptions()
+     *     .setStructuredMessageValidationEnabled(true);
+     *
+     * client.downloadStreamWithResponse(range, options, null, false, validationOptions).subscribe(response -> {
+     *     ByteArrayOutputStream downloadData = new ByteArrayOutputStream();
+     *     response.getValue().subscribe(piece -> {
+     *         try {
+     *             downloadData.write(piece.array());
+     *         } catch (IOException ex) {
+     *             throw new UncheckedIOException(ex);
+     *         }
+     *     });
+     * });
+     * }
+ * + *

For more information, see the + * Azure Docs

+ * + * @param range {@link BlobRange} + * @param options {@link DownloadRetryOptions} + * @param requestConditions {@link BlobRequestConditions} + * @param getRangeContentMd5 Whether the contentMD5 for the specified blob range should be returned. + * @param contentValidationOptions {@link DownloadContentValidationOptions} options for content validation + * @return A reactive response containing the blob data. + */ + @ServiceMethod(returns = ReturnType.SINGLE) + public Mono downloadStreamWithResponse(BlobRange range, DownloadRetryOptions options, + BlobRequestConditions requestConditions, boolean getRangeContentMd5, + DownloadContentValidationOptions contentValidationOptions) { + try { + return withContext(context -> downloadStreamWithResponse(range, options, requestConditions, + getRangeContentMd5, contentValidationOptions, context)); + } catch (RuntimeException ex) { + return monoError(LOGGER, ex); + } + } + /** * Reads a range of bytes from a blob. Uploading data must be done from the {@link BlockBlobClient}, {@link * PageBlobClient}, or {@link AppendBlobClient}. @@ -1215,19 +1263,41 @@ public Mono downloadContentWithResponse(Downlo } Mono downloadStreamWithResponse(BlobRange range, DownloadRetryOptions options, - BlobRequestConditions requestConditions, boolean getRangeContentMd5, Context context) { + BlobRequestConditions requestConditions, boolean getRangeContentMd5, + DownloadContentValidationOptions contentValidationOptions, Context context) { BlobRange finalRange = range == null ? new BlobRange(0) : range; - Boolean getMD5 = getRangeContentMd5 ? getRangeContentMd5 : null; + + // Determine MD5 validation: properly consider both getRangeContentMd5 parameter and validation options + // MD5 validation is enabled if: + // 1. getRangeContentMd5 is explicitly true, OR + // 2. contentValidationOptions.isMd5ValidationEnabled() is true + final Boolean finalGetMD5; + if (getRangeContentMd5 + || (contentValidationOptions != null && contentValidationOptions.isMd5ValidationEnabled())) { + finalGetMD5 = true; + } else { + finalGetMD5 = null; + } + BlobRequestConditions finalRequestConditions = requestConditions == null ? new BlobRequestConditions() : requestConditions; DownloadRetryOptions finalOptions = (options == null) ? new DownloadRetryOptions() : options; // The first range should eagerly convert headers as they'll be used to create response types. - Context firstRangeContext = context == null + Context initialContext = context == null ? new Context("azure-eagerly-convert-headers", true) : context.addData("azure-eagerly-convert-headers", true); - return downloadRange(finalRange, finalRequestConditions, finalRequestConditions.getIfMatch(), getMD5, + // Add structured message decoding context if enabled + final Context firstRangeContext; + if (contentValidationOptions != null && contentValidationOptions.isStructuredMessageValidationEnabled()) { + firstRangeContext = initialContext.addData(Constants.STRUCTURED_MESSAGE_DECODING_CONTEXT_KEY, true) + .addData(Constants.STRUCTURED_MESSAGE_VALIDATION_OPTIONS_CONTEXT_KEY, contentValidationOptions); + } else { + firstRangeContext = initialContext; + } + + return downloadRange(finalRange, finalRequestConditions, finalRequestConditions.getIfMatch(), finalGetMD5, firstRangeContext).map(response -> { BlobsDownloadHeaders blobsDownloadHeaders = new BlobsDownloadHeaders(response.getHeaders()); String eTag = blobsDownloadHeaders.getETag(); @@ -1271,16 +1341,22 @@ Mono downloadStreamWithResponse(BlobRange range, Down try { return downloadRange(new BlobRange(initialOffset + offset, newCount), finalRequestConditions, - eTag, getMD5, context); + eTag, finalGetMD5, firstRangeContext); } catch (Exception e) { return Mono.error(e); } }; + // Structured message decoding is now handled by StructuredMessageDecoderPolicy return BlobDownloadAsyncResponseConstructorProxy.create(response, onDownloadErrorResume, finalOptions); }); } + Mono downloadStreamWithResponse(BlobRange range, DownloadRetryOptions options, + BlobRequestConditions requestConditions, boolean getRangeContentMd5, Context context) { + return downloadStreamWithResponse(range, options, requestConditions, getRangeContentMd5, null, context); + } + private Mono downloadRange(BlobRange range, BlobRequestConditions requestConditions, String eTag, Boolean getMD5, Context context) { return azureBlobStorage.getBlobs() diff --git a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobMessageDecoderDownloadTests.java b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobMessageDecoderDownloadTests.java new file mode 100644 index 000000000000..5508ddc30831 --- /dev/null +++ b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobMessageDecoderDownloadTests.java @@ -0,0 +1,206 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.storage.blob; + +import com.azure.core.test.utils.TestUtils; +import com.azure.core.util.FluxUtil; +import com.azure.storage.blob.models.BlobRange; +import com.azure.storage.common.DownloadContentValidationOptions; +import com.azure.storage.common.implementation.Constants; +import com.azure.storage.common.implementation.structuredmessage.StructuredMessageEncoder; +import com.azure.storage.common.implementation.structuredmessage.StructuredMessageFlags; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Tests for structured message decoding during blob downloads using StorageContentValidationDecoderPolicy. + * These tests verify that the pipeline policy correctly decodes structured messages when content validation is enabled. + */ +public class BlobMessageDecoderDownloadTests extends BlobTestBase { + + private BlobAsyncClient bc; + + @BeforeEach + public void setup() { + String blobName = generateBlobName(); + bc = ccAsync.getBlobAsyncClient(blobName); + bc.upload(Flux.just(ByteBuffer.wrap(new byte[0])), null).block(); + } + + @Test + public void downloadStreamWithResponseContentValidation() throws IOException { + byte[] randomData = getRandomByteArray(Constants.KB); + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(randomData.length, 512, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = encoder.encode(ByteBuffer.wrap(randomData)); + + Flux input = Flux.just(encodedData); + + DownloadContentValidationOptions validationOptions + = new DownloadContentValidationOptions().setStructuredMessageValidationEnabled(true); + + StepVerifier + .create(bc.upload(input, null, true) + .then(bc.downloadStreamWithResponse(null, null, null, false, validationOptions)) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) + .assertNext(r -> TestUtils.assertArraysEqual(r, randomData)) + .verifyComplete(); + } + + @Test + public void downloadStreamWithResponseContentValidationRange() throws IOException { + byte[] randomData = getRandomByteArray(Constants.KB); + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(randomData.length, 512, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = encoder.encode(ByteBuffer.wrap(randomData)); + + Flux input = Flux.just(encodedData); + + DownloadContentValidationOptions validationOptions + = new DownloadContentValidationOptions().setStructuredMessageValidationEnabled(true); + + BlobRange range = new BlobRange(0, 512L); + + StepVerifier.create(bc.upload(input, null, true) + .then(bc.downloadStreamWithResponse(range, null, null, false, validationOptions)) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))).assertNext(r -> { + assertNotNull(r); + assertTrue(r.length > 0); + }).verifyComplete(); + } + + @Test + public void downloadStreamWithResponseContentValidationLargeBlob() throws IOException { + // Test with larger data to verify chunking works correctly + byte[] randomData = getRandomByteArray(5 * Constants.KB); + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(randomData.length, 1024, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = encoder.encode(ByteBuffer.wrap(randomData)); + + Flux input = Flux.just(encodedData); + + DownloadContentValidationOptions validationOptions + = new DownloadContentValidationOptions().setStructuredMessageValidationEnabled(true); + + StepVerifier + .create(bc.upload(input, null, true) + .then(bc.downloadStreamWithResponse(null, null, null, false, validationOptions)) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) + .assertNext(r -> TestUtils.assertArraysEqual(r, randomData)) + .verifyComplete(); + } + + @Test + public void downloadStreamWithResponseContentValidationMultipleSegments() throws IOException { + // Test with multiple segments to ensure all segments are decoded correctly + byte[] randomData = getRandomByteArray(2 * Constants.KB); + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(randomData.length, 512, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = encoder.encode(ByteBuffer.wrap(randomData)); + + Flux input = Flux.just(encodedData); + + DownloadContentValidationOptions validationOptions + = new DownloadContentValidationOptions().setStructuredMessageValidationEnabled(true); + + StepVerifier + .create(bc.upload(input, null, true) + .then(bc.downloadStreamWithResponse(null, null, null, false, validationOptions)) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) + .assertNext(r -> TestUtils.assertArraysEqual(r, randomData)) + .verifyComplete(); + } + + @Test + public void downloadStreamWithResponseNoValidation() throws IOException { + // Test that download works normally when validation is not enabled + byte[] randomData = getRandomByteArray(Constants.KB); + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(randomData.length, 512, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = encoder.encode(ByteBuffer.wrap(randomData)); + + Flux input = Flux.just(encodedData); + + // No validation options - should download encoded data as-is + StepVerifier.create(bc.upload(input, null, true) + .then(bc.downloadStreamWithResponse(null, null, null, false)) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))).assertNext(r -> { + assertNotNull(r); + // Should get encoded data, not decoded + assertTrue(r.length > randomData.length); // Encoded data is larger + }).verifyComplete(); + } + + @Test + public void downloadStreamWithResponseValidationDisabled() throws IOException { + // Test with validation options but validation disabled + byte[] randomData = getRandomByteArray(Constants.KB); + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(randomData.length, 512, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = encoder.encode(ByteBuffer.wrap(randomData)); + + Flux input = Flux.just(encodedData); + + DownloadContentValidationOptions validationOptions + = new DownloadContentValidationOptions().setStructuredMessageValidationEnabled(false); + + StepVerifier.create(bc.upload(input, null, true) + .then(bc.downloadStreamWithResponse(null, null, null, false, validationOptions)) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))).assertNext(r -> { + assertNotNull(r); + // Should get encoded data, not decoded + assertTrue(r.length > randomData.length); // Encoded data is larger + }).verifyComplete(); + } + + @Test + public void downloadStreamWithResponseContentValidationSmallSegment() throws IOException { + // Test with small segment size to ensure boundary conditions are handled + byte[] randomData = getRandomByteArray(256); + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(randomData.length, 128, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = encoder.encode(ByteBuffer.wrap(randomData)); + + Flux input = Flux.just(encodedData); + + DownloadContentValidationOptions validationOptions + = new DownloadContentValidationOptions().setStructuredMessageValidationEnabled(true); + + StepVerifier + .create(bc.upload(input, null, true) + .then(bc.downloadStreamWithResponse(null, null, null, false, validationOptions)) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) + .assertNext(r -> TestUtils.assertArraysEqual(r, randomData)) + .verifyComplete(); + } + + @Test + public void downloadStreamWithResponseContentValidationVeryLargeBlob() throws IOException { + // Test with very large data to verify chunking and policy work correctly with large blobs + byte[] randomData = getRandomByteArray(10 * Constants.KB); + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(randomData.length, 2048, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = encoder.encode(ByteBuffer.wrap(randomData)); + + Flux input = Flux.just(encodedData); + + DownloadContentValidationOptions validationOptions + = new DownloadContentValidationOptions().setStructuredMessageValidationEnabled(true); + + StepVerifier + .create(bc.upload(input, null, true) + .then(bc.downloadStreamWithResponse(null, null, null, false, validationOptions)) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) + .assertNext(r -> TestUtils.assertArraysEqual(r, randomData)) + .verifyComplete(); + } +} diff --git a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/DownloadContentValidationOptions.java b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/DownloadContentValidationOptions.java new file mode 100644 index 000000000000..2b663494bfe9 --- /dev/null +++ b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/DownloadContentValidationOptions.java @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.storage.common; + +import com.azure.core.annotation.Fluent; + +/** + * Options for content validation during download operations. + */ +@Fluent +public final class DownloadContentValidationOptions { + private boolean enableStructuredMessageValidation; + private boolean enableMd5Validation; + + /** + * Creates a new instance of DownloadContentValidationOptions. + */ + public DownloadContentValidationOptions() { + this.enableStructuredMessageValidation = false; + this.enableMd5Validation = false; + } + + /** + * Gets whether structured message validation is enabled. + * + * @return true if structured message validation is enabled, false otherwise. + */ + public boolean isStructuredMessageValidationEnabled() { + return enableStructuredMessageValidation; + } + + /** + * Sets whether structured message validation is enabled. + * When enabled, downloads will use CRC64 checksums embedded in structured messages for content validation. + * + * @param enableStructuredMessageValidation true to enable structured message validation, false to disable. + * @return The updated DownloadContentValidationOptions object. + */ + public DownloadContentValidationOptions + setStructuredMessageValidationEnabled(boolean enableStructuredMessageValidation) { + this.enableStructuredMessageValidation = enableStructuredMessageValidation; + return this; + } + + /** + * Gets whether MD5 validation is enabled. + * + * @return true if MD5 validation is enabled, false otherwise. + */ + public boolean isMd5ValidationEnabled() { + return enableMd5Validation; + } + + /** + * Sets whether MD5 validation is enabled. + * When enabled, downloads will use MD5 checksums for content validation. + * + * @param enableMd5Validation true to enable MD5 validation, false to disable. + * @return The updated DownloadContentValidationOptions object. + */ + public DownloadContentValidationOptions setMd5ValidationEnabled(boolean enableMd5Validation) { + this.enableMd5Validation = enableMd5Validation; + return this; + } +} diff --git a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/Constants.java b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/Constants.java index 34110d163145..5f6c36f85d4b 100644 --- a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/Constants.java +++ b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/Constants.java @@ -94,6 +94,17 @@ public final class Constants { public static final String SKIP_ECHO_VALIDATION_KEY = "skipEchoValidation"; + /** + * Context key used to signal that structured message decoding should be applied. + */ + public static final String STRUCTURED_MESSAGE_DECODING_CONTEXT_KEY = "azure-storage-structured-message-decoding"; + + /** + * Context key used to pass DownloadContentValidationOptions to the policy. + */ + public static final String STRUCTURED_MESSAGE_VALIDATION_OPTIONS_CONTEXT_KEY + = "azure-storage-structured-message-validation-options"; + private Constants() { } diff --git a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/structuredmessage/StructuredMessageDecoder.java b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/structuredmessage/StructuredMessageDecoder.java new file mode 100644 index 000000000000..6117a7765541 --- /dev/null +++ b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/structuredmessage/StructuredMessageDecoder.java @@ -0,0 +1,267 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.storage.common.implementation.structuredmessage; + +import com.azure.core.util.logging.ClientLogger; +import com.azure.storage.common.implementation.StorageCrc64Calculator; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.HashMap; +import java.util.Map; + +import static com.azure.storage.common.implementation.structuredmessage.StructuredMessageConstants.CRC64_LENGTH; +import static com.azure.storage.common.implementation.structuredmessage.StructuredMessageConstants.DEFAULT_MESSAGE_VERSION; +import static com.azure.storage.common.implementation.structuredmessage.StructuredMessageConstants.V1_HEADER_LENGTH; +import static com.azure.storage.common.implementation.structuredmessage.StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH; + +/** + * Decoder for structured messages with support for segmenting and CRC64 checksums. + */ +public class StructuredMessageDecoder { + private static final ClientLogger LOGGER = new ClientLogger(StructuredMessageDecoder.class); + private long messageLength; + private StructuredMessageFlags flags; + private int numSegments; + private final long expectedContentLength; + + private int messageOffset = 0; + private int currentSegmentNumber = 0; + private int currentSegmentContentLength = 0; + private int currentSegmentContentOffset = 0; + + private long messageCrc64 = 0; + private long segmentCrc64 = 0; + private final Map segmentCrcs = new HashMap<>(); + + /** + * Constructs a new StructuredMessageDecoder. + * + * @param expectedContentLength The expected length of the content to be decoded. + */ + public StructuredMessageDecoder(long expectedContentLength) { + this.expectedContentLength = expectedContentLength; + } + + /** + * Reads the message header from the given buffer. + * + * @param buffer The buffer containing the message header. + * @throws IllegalArgumentException if the buffer does not contain a valid message header. + */ + private void readMessageHeader(ByteBuffer buffer) { + if (buffer.remaining() < V1_HEADER_LENGTH) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Content not long enough to contain a valid " + "message header.")); + } + + int messageVersion = Byte.toUnsignedInt(buffer.get()); + if (messageVersion != DEFAULT_MESSAGE_VERSION) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Unsupported structured message version: " + messageVersion)); + } + + messageLength = (int) buffer.getLong(); + if (messageLength < V1_HEADER_LENGTH) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Content not long enough to contain a valid " + "message header.")); + } + if (messageLength != expectedContentLength) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException("Structured message length " + messageLength + + " did not match content length " + expectedContentLength)); + } + + flags = StructuredMessageFlags.fromValue(Short.toUnsignedInt(buffer.getShort())); + numSegments = Short.toUnsignedInt(buffer.getShort()); + + messageOffset += V1_HEADER_LENGTH; + } + + /** + * Reads the segment header from the given buffer. + * + * @param buffer The buffer containing the segment header. + * @throws IllegalArgumentException if the buffer does not contain a valid segment header. + */ + private void readSegmentHeader(ByteBuffer buffer) { + if (buffer.remaining() < V1_SEGMENT_HEADER_LENGTH) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException("Segment header is incomplete.")); + } + + int segmentNum = Short.toUnsignedInt(buffer.getShort()); + int segmentSize = (int) buffer.getLong(); + + if (segmentSize < 0 || segmentSize > buffer.remaining()) { + throw LOGGER + .logExceptionAsError(new IllegalArgumentException("Invalid segment size detected: " + segmentSize)); + } + + if (segmentNum != currentSegmentNumber + 1) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException("Unexpected segment number.")); + } + + currentSegmentNumber = segmentNum; + currentSegmentContentLength = segmentSize; + currentSegmentContentOffset = 0; + + if (segmentSize == 0) { + readSegmentFooter(buffer); + } + + if (flags == StructuredMessageFlags.STORAGE_CRC64) { + segmentCrc64 = 0; + } + + messageOffset += V1_SEGMENT_HEADER_LENGTH; + } + + /** + * Reads the segment content from the given buffer and writes it to the output stream. + * + * @param buffer The buffer containing the segment content. + * @param output The output stream to write the segment content to. + * @param size The maximum number of bytes to read. + * @throws IllegalArgumentException if there is a segment size mismatch. + */ + private void readSegmentContent(ByteBuffer buffer, ByteArrayOutputStream output, int size) { + int toRead = Math.min(buffer.remaining(), currentSegmentContentLength - currentSegmentContentOffset); + toRead = Math.min(toRead, size); + + if (toRead == 0) { + return; + } + + byte[] content = new byte[toRead]; + buffer.get(content); + output.write(content, 0, toRead); + + if (flags == StructuredMessageFlags.STORAGE_CRC64) { + segmentCrc64 = StorageCrc64Calculator.compute(content, segmentCrc64); + messageCrc64 = StorageCrc64Calculator.compute(content, messageCrc64); + } + + messageOffset += toRead; + currentSegmentContentOffset += toRead; + + if (currentSegmentContentOffset > currentSegmentContentLength) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Segment size mismatch detected in segment " + currentSegmentNumber)); + } + + if (currentSegmentContentOffset == currentSegmentContentLength) { + readSegmentFooter(buffer); + } + } + + /** + * Reads the segment footer from the given buffer. + * + * @param buffer The buffer containing the segment footer. + * @throws IllegalArgumentException if the buffer does not contain a valid segment footer. + */ + private void readSegmentFooter(ByteBuffer buffer) { + if (currentSegmentContentOffset != currentSegmentContentLength) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Segment content length mismatch in segment " + currentSegmentNumber + + ". Expected: " + currentSegmentContentLength + ", Read: " + currentSegmentContentOffset)); + } + + if (flags == StructuredMessageFlags.STORAGE_CRC64) { + if (buffer.remaining() < CRC64_LENGTH) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException("Segment footer is incomplete.")); + } + + long reportedCrc64 = buffer.getLong(); + if (segmentCrc64 != reportedCrc64) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("CRC64 mismatch detected in segment " + currentSegmentNumber)); + } + segmentCrcs.put(currentSegmentNumber, segmentCrc64); + messageOffset += CRC64_LENGTH; + } + + if (currentSegmentNumber == numSegments) { + readMessageFooter(buffer); + } else { + readSegmentHeader(buffer); + } + } + + /** + * Reads the segment footer from the given buffer. + * + * @param buffer The buffer containing the segment footer. + * @throws IllegalArgumentException if the buffer does not contain a valid segment footer. + */ + private void readMessageFooter(ByteBuffer buffer) { + if (flags == StructuredMessageFlags.STORAGE_CRC64) { + if (buffer.remaining() < CRC64_LENGTH) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException("Message footer is incomplete.")); + } + + long reportedCrc = buffer.getLong(); + if (messageCrc64 != reportedCrc) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("CRC64 mismatch detected in message " + "footer.")); + } + messageOffset += CRC64_LENGTH; + } + + if (messageOffset != messageLength) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Decoded message length does not match " + "expected length.")); + } + } + + /** + * Decodes the structured message from the given buffer up to the specified size. + * + * @param buffer The buffer containing the structured message. + * @param size The maximum number of bytes to decode. + * @return A ByteBuffer containing the decoded message content. + * @throws IllegalArgumentException if the buffer does not contain a valid structured message. + */ + public ByteBuffer decode(ByteBuffer buffer, int size) { + buffer.order(ByteOrder.LITTLE_ENDIAN); + ByteArrayOutputStream decodedContent = new ByteArrayOutputStream(); + + if (messageOffset == 0) { + readMessageHeader(buffer); + } + + while (buffer.hasRemaining() && decodedContent.size() < size) { + if (currentSegmentContentOffset == currentSegmentContentLength) { + readSegmentHeader(buffer); + } + + readSegmentContent(buffer, decodedContent, size - decodedContent.size()); + } + + return ByteBuffer.wrap(decodedContent.toByteArray()); + } + + /** + * Decodes the entire structured message from the given buffer. + * + * @param buffer The buffer containing the structured message. + * @return A ByteBuffer containing the decoded message content. + * @throws IllegalArgumentException if the buffer does not contain a valid structured message. + */ + public ByteBuffer decode(ByteBuffer buffer) { + return decode(buffer, buffer.remaining()); + } + + /** + * Finalizes the decoding process and validates that the entire message has been decoded. + * + * @throws IllegalArgumentException if the decoded message length does not match the expected length. + */ + public void finalizeDecoding() { + if (messageOffset != messageLength) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException("Decoded message length does not match " + + "expected length. Expected: " + messageLength + ", but was: " + messageOffset)); + } + } +} diff --git a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/structuredmessage/StructuredMessageDecodingStream.java b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/structuredmessage/StructuredMessageDecodingStream.java new file mode 100644 index 000000000000..5fec64e0c18a --- /dev/null +++ b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/structuredmessage/StructuredMessageDecodingStream.java @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.storage.common.implementation.structuredmessage; + +import com.azure.core.util.logging.ClientLogger; +import com.azure.storage.common.DownloadContentValidationOptions; +import reactor.core.publisher.Flux; + +import java.nio.ByteBuffer; + +/** + * A utility class for applying structured message decoding to download streams. + */ +public final class StructuredMessageDecodingStream { + private static final ClientLogger LOGGER = new ClientLogger(StructuredMessageDecodingStream.class); + + private StructuredMessageDecodingStream() { + // utility class + } + + /** + * Wraps a download stream with structured message decoding if content validation is enabled. + * + * @param originalStream The original download stream. + * @param contentLength The expected content length. + * @param validationOptions The content validation options. + * @return A Flux that decodes structured messages if validation is enabled, otherwise returns the original stream. + */ + public static Flux wrapStreamIfNeeded(Flux originalStream, Long contentLength, + DownloadContentValidationOptions validationOptions) { + + if (validationOptions == null || !validationOptions.isStructuredMessageValidationEnabled()) { + return originalStream; + } + + if (contentLength == null || contentLength <= 0) { + LOGGER.warning("Cannot apply structured message validation without valid content length."); + return originalStream; + } + + return applyStructuredMessageDecoding(originalStream, contentLength); + } + + /** + * Applies structured message decoding to the stream. + * + * @param stream The stream to decode. + * @param expectedContentLength The expected content length. + * @return A Flux that decodes the structured message. + */ + private static Flux applyStructuredMessageDecoding(Flux stream, + long expectedContentLength) { + return stream + .collect(() -> new StructuredMessageDecodingCollector(expectedContentLength), + StructuredMessageDecodingCollector::addBuffer) + .flatMapMany(collector -> collector.getDecodedData()); + } + + /** + * Helper class to collect and decode structured message data. + */ + private static class StructuredMessageDecodingCollector { + private final StructuredMessageDecoder decoder; + private ByteBuffer accumulatedBuffer; + private boolean completed = false; + + StructuredMessageDecodingCollector(long expectedContentLength) { + this.decoder = new StructuredMessageDecoder(expectedContentLength); + this.accumulatedBuffer = ByteBuffer.allocate(0); + } + + void addBuffer(ByteBuffer buffer) { + if (completed) { + return; + } + + // Accumulate the buffer + ByteBuffer newBuffer = ByteBuffer.allocate(accumulatedBuffer.remaining() + buffer.remaining()); + newBuffer.put(accumulatedBuffer); + newBuffer.put(buffer); + newBuffer.flip(); + accumulatedBuffer = newBuffer; + } + + Flux getDecodedData() { + try { + if (accumulatedBuffer.remaining() == 0) { + return Flux.empty(); + } + + ByteBuffer decodedData = decoder.decode(accumulatedBuffer); + decoder.finalizeDecoding(); + completed = true; + + return Flux.just(decodedData); + } catch (Exception e) { + LOGGER.error("Failed to decode structured message: " + e.getMessage(), e); + return Flux.error(e); + } + } + } +} diff --git a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/policy/StorageContentValidationDecoderPolicy.java b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/policy/StorageContentValidationDecoderPolicy.java new file mode 100644 index 000000000000..7652bb846e82 --- /dev/null +++ b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/policy/StorageContentValidationDecoderPolicy.java @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.storage.common.policy; + +import com.azure.core.http.HttpHeaderName; +import com.azure.core.http.HttpHeaders; +import com.azure.core.http.HttpMethod; +import com.azure.core.http.HttpPipelineCallContext; +import com.azure.core.http.HttpPipelineNextPolicy; +import com.azure.core.http.HttpResponse; +import com.azure.core.http.policy.HttpPipelinePolicy; +import com.azure.core.util.FluxUtil; +import com.azure.core.util.logging.ClientLogger; +import com.azure.storage.common.DownloadContentValidationOptions; +import com.azure.storage.common.implementation.Constants; +import com.azure.storage.common.implementation.structuredmessage.StructuredMessageDecodingStream; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.nio.ByteBuffer; +import java.nio.charset.Charset; + +/** + * This is a decoding policy in an {@link com.azure.core.http.HttpPipeline} to decode structured messages in + * storage download requests. The policy checks for a context value to determine when to apply structured message decoding. + */ +public class StorageContentValidationDecoderPolicy implements HttpPipelinePolicy { + private static final ClientLogger LOGGER = new ClientLogger(StorageContentValidationDecoderPolicy.class); + + /** + * Creates a new instance of {@link StorageContentValidationDecoderPolicy}. + */ + public StorageContentValidationDecoderPolicy() { + } + + @Override + public Mono process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) { + // Check if structured message decoding is enabled for this request + if (!shouldApplyDecoding(context)) { + return next.process(); + } + + return next.process().map(httpResponse -> { + // Only apply decoding to download responses (GET requests with body) + if (!isDownloadResponse(httpResponse)) { + return httpResponse; + } + + DownloadContentValidationOptions validationOptions = getValidationOptions(context); + Long contentLength = getContentLength(httpResponse.getHeaders()); + + if (contentLength != null && contentLength > 0 && validationOptions != null) { + Flux decodedStream = StructuredMessageDecodingStream + .wrapStreamIfNeeded(httpResponse.getBody(), contentLength, validationOptions); + return new DecodedResponse(httpResponse, decodedStream); + } + + return httpResponse; + }); + } + + /** + * Checks if structured message decoding should be applied based on context. + * + * @param context The pipeline call context. + * @return true if decoding should be applied, false otherwise. + */ + private boolean shouldApplyDecoding(HttpPipelineCallContext context) { + return context.getData(Constants.STRUCTURED_MESSAGE_DECODING_CONTEXT_KEY) + .map(value -> value instanceof Boolean && (Boolean) value) + .orElse(false); + } + + /** + * Gets the validation options from context. + * + * @param context The pipeline call context. + * @return The validation options or null if not present. + */ + private DownloadContentValidationOptions getValidationOptions(HttpPipelineCallContext context) { + return context.getData(Constants.STRUCTURED_MESSAGE_VALIDATION_OPTIONS_CONTEXT_KEY) + .filter(value -> value instanceof DownloadContentValidationOptions) + .map(value -> (DownloadContentValidationOptions) value) + .orElse(null); + } + + /** + * Gets the content length from response headers. + * + * @param headers The response headers. + * @return The content length or null if not present. + */ + private Long getContentLength(HttpHeaders headers) { + String contentLengthStr = headers.getValue(HttpHeaderName.CONTENT_LENGTH); + if (contentLengthStr != null) { + try { + return Long.parseLong(contentLengthStr); + } catch (NumberFormatException e) { + LOGGER.warning("Invalid content length in response headers: " + contentLengthStr); + } + } + return null; + } + + /** + * Checks if the response is a download response (GET request with body). + * + * @param httpResponse The HTTP response. + * @return true if it's a download response, false otherwise. + */ + private boolean isDownloadResponse(HttpResponse httpResponse) { + return httpResponse.getRequest().getHttpMethod() == HttpMethod.GET && httpResponse.getBody() != null; + } + + /** + * HTTP response wrapper that provides a decoded response body. + */ + static class DecodedResponse extends HttpResponse { + private final Flux decodedBody; + private final HttpResponse originalResponse; + + DecodedResponse(HttpResponse httpResponse, Flux decodedBody) { + super(httpResponse.getRequest()); + this.originalResponse = httpResponse; + this.decodedBody = decodedBody; + } + + @Override + public int getStatusCode() { + return originalResponse.getStatusCode(); + } + + @Override + public String getHeaderValue(String name) { + return originalResponse.getHeaderValue(name); + } + + @Override + public HttpHeaders getHeaders() { + return originalResponse.getHeaders(); + } + + @Override + public Flux getBody() { + return decodedBody; + } + + @Override + public Mono getBodyAsByteArray() { + return FluxUtil.collectBytesInByteBufferStream(decodedBody); + } + + @Override + public Mono getBodyAsString() { + return getBodyAsByteArray().map(String::new); + } + + @Override + public Mono getBodyAsString(Charset charset) { + return getBodyAsByteArray().map(bytes -> new String(bytes, charset)); + } + } +}