Skip to content

Commit 66d8385

Browse files
authored
Storage Content Validation - Encoder Performance Improvements (#47531)
* adjusting encoder logic
* adjusting tests to work with encoder
* addressing copilot comments
* adding more documentation
1 parent 0671045 commit 66d8385

File tree

2 files changed: +92 additions, -95 deletions

sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/structuredmessage/StructuredMessageEncoder.java

Lines changed: 65 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -4,14 +4,16 @@
44
package com.azure.storage.common.implementation.structuredmessage;
55

66
import com.azure.core.util.logging.ClientLogger;
7+
import com.azure.storage.common.implementation.BufferStagingArea;
78
import com.azure.storage.common.implementation.StorageCrc64Calculator;
89
import com.azure.storage.common.implementation.StorageImplUtils;
10+
import reactor.core.publisher.Flux;
911

10-
import java.io.IOException;
1112
import java.nio.ByteBuffer;
12-
import java.io.ByteArrayOutputStream;
1313
import java.nio.ByteOrder;
14+
import java.util.ArrayList;
1415
import java.util.HashMap;
16+
import java.util.List;
1517
import java.util.Map;
1618

1719
import static com.azure.storage.common.implementation.structuredmessage.StructuredMessageConstants.CRC64_LENGTH;
@@ -35,7 +37,6 @@ public class StructuredMessageEncoder {
3537
private int currentContentOffset;
3638
private int currentSegmentNumber;
3739
private int currentSegmentOffset;
38-
private int currentMessageLength;
3940
private long messageCRC64;
4041
private final Map<Integer, Long> segmentCRC64s;
4142

@@ -66,7 +67,6 @@ public StructuredMessageEncoder(int contentLength, int segmentSize, StructuredMe
6667
this.currentSegmentOffset = 0;
6768
this.messageCRC64 = 0;
6869
this.segmentCRC64s = new HashMap<>();
69-
this.currentMessageLength = 0;
7070

7171
if (numSegments > Short.MAX_VALUE) {
7272
StorageImplUtils.assertInBounds("numSegments", numSegments, 1, Short.MAX_VALUE);
@@ -109,115 +109,102 @@ private byte[] generateMessageHeader() {
109109
return buffer.array();
110110
}
111111

112-
private byte[] generateSegmentHeader() {
113-
int segmentHeaderSize = Math.min(segmentSize, contentLength - currentContentOffset);
114-
// 2 byte number, 8 byte size
112+
private byte[] generateSegmentHeader(int segmentContentSize) {
115113
ByteBuffer buffer = ByteBuffer.allocate(getSegmentHeaderLength()).order(ByteOrder.LITTLE_ENDIAN);
116114
buffer.putShort((short) currentSegmentNumber);
117-
buffer.putLong(segmentHeaderSize);
115+
buffer.putLong(segmentContentSize);
118116

119117
return buffer.array();
120118
}
121119

122120
/**
123-
* Encodes the given buffer into a structured message format.
121+
* Encodes the given buffer into a structured message format as a stream of ByteBuffers.
122+
* The encoder maintains mutable state and is designed for single, sequential subscription only.
123+
* Callers should pre-chunk input buffers to appropriate sizes (e.g., using {@link BufferStagingArea}) to
124+
* control memory usage.
124125
*
125126
* @param unencodedBuffer The buffer to be encoded.
126-
* @return The encoded buffer.
127-
* @throws IOException If an error occurs while encoding the buffer.
127+
* @return A Flux of encoded ByteBuffers.
128128
* @throws IllegalArgumentException If the buffer length exceeds the content length, or the content has already been
129129
* encoded.
130130
*/
131-
public ByteBuffer encode(ByteBuffer unencodedBuffer) throws IOException {
131+
public Flux<ByteBuffer> encode(ByteBuffer unencodedBuffer) {
132132
StorageImplUtils.assertNotNull("unencodedBuffer", unencodedBuffer);
133133

134-
if (currentContentOffset == contentLength) {
135-
throw LOGGER.logExceptionAsError(new IllegalArgumentException("Content has already been encoded."));
136-
}
137-
138-
if ((unencodedBuffer.remaining() + currentContentOffset) > contentLength) {
139-
throw LOGGER.logExceptionAsError(new IllegalArgumentException("Buffer length exceeds content length."));
140-
}
141-
142-
if (!unencodedBuffer.hasRemaining()) {
143-
return ByteBuffer.allocate(0);
144-
}
145-
146-
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
147-
148-
// if we are at the beginning of the message, encode message header
149-
if (currentMessageLength == 0) {
150-
encodeMessageHeader(byteArrayOutputStream);
151-
}
152-
153-
while (unencodedBuffer.hasRemaining()) {
154-
// if we are at the beginning of a segment's content, encode segment header
155-
if (currentSegmentOffset == 0) {
156-
encodeSegmentHeader(byteArrayOutputStream);
134+
return Flux.defer(() -> {
135+
if (currentContentOffset == contentLength) {
136+
return Flux.error(
137+
LOGGER.logExceptionAsError(new IllegalArgumentException("Content has already been encoded.")));
157138
}
158139

159-
encodeSegmentContent(unencodedBuffer, byteArrayOutputStream);
160-
161-
// if we are at the end of a segment's content, encode segment footer
162-
if (currentSegmentOffset == getSegmentContentLength()) {
163-
encodeSegmentFooter(byteArrayOutputStream);
140+
if ((unencodedBuffer.remaining() + currentContentOffset) > contentLength) {
141+
return Flux.error(
142+
LOGGER.logExceptionAsError(new IllegalArgumentException("Buffer length exceeds content length.")));
164143
}
165-
}
166144

167-
// if all content has been encoded, encode message footer
168-
if (currentContentOffset == contentLength) {
169-
encodeMessageFooter(byteArrayOutputStream);
170-
}
145+
if (!unencodedBuffer.hasRemaining()) {
146+
return Flux.empty();
147+
}
171148

172-
return ByteBuffer.wrap(byteArrayOutputStream.toByteArray());
173-
}
149+
List<ByteBuffer> buffers = new ArrayList<>();
174150

175-
private void encodeMessageHeader(ByteArrayOutputStream output) {
176-
byte[] metadata = generateMessageHeader();
177-
output.write(metadata, 0, metadata.length);
151+
// if we are at the beginning of the message, encode message header
152+
if (currentContentOffset == 0) {
153+
buffers.add(ByteBuffer.wrap(generateMessageHeader()));
154+
}
178155

179-
currentMessageLength += metadata.length;
180-
}
156+
while (unencodedBuffer.hasRemaining()) {
157+
// if we are at the beginning of a segment's content, encode segment header
158+
if (currentSegmentOffset == 0) {
159+
incrementCurrentSegment();
160+
// Calculate actual segment size based on remaining content
161+
int actualSegmentSize = Math.min(segmentSize, contentLength - currentContentOffset);
162+
buffers.add(ByteBuffer.wrap(generateSegmentHeader(actualSegmentSize)));
163+
}
164+
165+
buffers.add(encodeSegmentContent(unencodedBuffer));
166+
167+
// if we are at the end of a segment's content, encode segment footer
168+
if (currentSegmentOffset == getSegmentContentLength()) {
169+
byte[] footer = generateSegmentFooter();
170+
if (footer.length > 0) {
171+
buffers.add(ByteBuffer.wrap(footer));
172+
}
173+
currentSegmentOffset = 0;
174+
}
175+
}
181176

182-
private void encodeSegmentHeader(ByteArrayOutputStream output) {
183-
incrementCurrentSegment();
184-
byte[] metadata = generateSegmentHeader();
185-
output.write(metadata, 0, metadata.length);
177+
// if all content has been encoded, encode message footer
178+
if (currentContentOffset == contentLength) {
179+
byte[] footer = generateMessageFooter();
180+
if (footer.length > 0) {
181+
buffers.add(ByteBuffer.wrap(footer));
182+
}
183+
}
186184

187-
currentMessageLength += metadata.length;
185+
return Flux.fromIterable(buffers);
186+
});
188187
}
189188

190-
private void encodeSegmentFooter(ByteArrayOutputStream output) {
191-
byte[] metadata;
189+
private byte[] generateSegmentFooter() {
192190
if (structuredMessageFlags == StructuredMessageFlags.STORAGE_CRC64) {
193-
metadata = ByteBuffer.allocate(CRC64_LENGTH)
191+
return ByteBuffer.allocate(CRC64_LENGTH)
194192
.order(ByteOrder.LITTLE_ENDIAN)
195193
.putLong(segmentCRC64s.get(currentSegmentNumber))
196194
.array();
197-
} else {
198-
metadata = new byte[0];
199195
}
200-
output.write(metadata, 0, metadata.length);
201-
202-
currentMessageLength += metadata.length;
203-
currentSegmentOffset = 0;
196+
return new byte[0];
204197
}
205198

206-
private void encodeMessageFooter(ByteArrayOutputStream output) {
207-
byte[] metadata;
199+
private byte[] generateMessageFooter() {
208200
if (structuredMessageFlags == StructuredMessageFlags.STORAGE_CRC64) {
209-
metadata = ByteBuffer.allocate(CRC64_LENGTH).order(ByteOrder.LITTLE_ENDIAN).putLong(messageCRC64).array();
210-
} else {
211-
metadata = new byte[0];
201+
return ByteBuffer.allocate(CRC64_LENGTH).order(ByteOrder.LITTLE_ENDIAN).putLong(messageCRC64).array();
212202
}
213-
214-
output.write(metadata, 0, metadata.length);
215-
currentMessageLength += metadata.length;
203+
return new byte[0];
216204
}
217205

218-
private void encodeSegmentContent(ByteBuffer unencodedBuffer, ByteArrayOutputStream output) {
206+
private ByteBuffer encodeSegmentContent(ByteBuffer unencodedBuffer) {
219207
int readSize = Math.min(unencodedBuffer.remaining(), getSegmentContentLength() - currentSegmentOffset);
220-
221208
byte[] content = new byte[readSize];
222209
unencodedBuffer.get(content, 0, readSize);
223210

@@ -230,8 +217,7 @@ private void encodeSegmentContent(ByteBuffer unencodedBuffer, ByteArrayOutputStr
230217
currentContentOffset += readSize;
231218
currentSegmentOffset += readSize;
232219

233-
output.write(content, 0, content.length);
234-
currentMessageLength += readSize;
220+
return ByteBuffer.wrap(content);
235221
}
236222

237223
private int calculateMessageLength() {
@@ -255,7 +241,7 @@ private void incrementCurrentSegment() {
255241
*
256242
* @return The length of the message.
257243
*/
258-
public int getMessageLength() {
244+
public long getEncodedMessageLength() {
259245
return messageLength;
260246
}
261247
}

sdk/storage/azure-storage-common/src/test/java/com/azure/storage/common/implementation/structuredmessage/MessageEncoderTests.java

Lines changed: 27 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -9,24 +9,27 @@
99
import org.junit.jupiter.params.ParameterizedTest;
1010
import org.junit.jupiter.params.provider.Arguments;
1111
import org.junit.jupiter.params.provider.MethodSource;
12+
import reactor.test.StepVerifier;
1213

1314
import java.io.ByteArrayOutputStream;
1415
import java.io.IOException;
1516
import java.nio.ByteBuffer;
1617
import java.nio.ByteOrder;
1718
import java.util.Arrays;
19+
import java.util.Objects;
1820
import java.util.concurrent.ThreadLocalRandom;
1921
import java.util.stream.Stream;
2022

23+
import static com.azure.core.util.FluxUtil.collectBytesInByteBufferStream;
24+
import static com.azure.storage.common.implementation.structuredmessage.StructuredMessageConstants.CRC64_LENGTH;
25+
import static com.azure.storage.common.implementation.structuredmessage.StructuredMessageConstants.V1_HEADER_LENGTH;
26+
import static com.azure.storage.common.implementation.structuredmessage.StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH;
2127
import static org.junit.jupiter.api.Assertions.assertEquals;
28+
import static org.junit.jupiter.api.Assertions.assertNotNull;
2229
import static org.junit.jupiter.api.Assertions.assertThrows;
2330

2431
public class MessageEncoderTests {
2532

26-
private static final int V1_HEADER_LENGTH = 13;
27-
private static final int V1_SEGMENT_HEADER_LENGTH = 10;
28-
private static final int CRC64_LENGTH = 8;
29-
3033
private static byte[] getRandomData(int size) {
3134
byte[] result = new byte[size];
3235
ThreadLocalRandom.current().nextBytes(result);
@@ -156,7 +159,7 @@ public void readAll(int size, int segmentSize, StructuredMessageFlags flags) thr
156159

157160
StructuredMessageEncoder structuredMessageEncoder = new StructuredMessageEncoder(size, segmentSize, flags);
158161

159-
byte[] actual = structuredMessageEncoder.encode(unencodedBuffer).array();
162+
byte[] actual = collectBytesInByteBufferStream(structuredMessageEncoder.encode(unencodedBuffer)).block();
160163
byte[] expected = buildStructuredMessage(unencodedBuffer, segmentSize, flags).array();
161164

162165
Assertions.assertArrayEquals(expected, actual);
@@ -191,33 +194,41 @@ public void readMultiple(int segmentSize, StructuredMessageFlags flags) throws I
191194
byte[] expected = buildStructuredMessage(allWrappedData, segmentSize, flags).array();
192195

193196
ByteArrayOutputStream allActualData = new ByteArrayOutputStream();
194-
allActualData.write(structuredMessageEncoder.encode(wrappedData1).array());
195-
allActualData.write(structuredMessageEncoder.encode(wrappedData2).array());
196-
allActualData.write(structuredMessageEncoder.encode(wrappedData3).array());
197+
allActualData.write(Objects
198+
.requireNonNull(collectBytesInByteBufferStream(structuredMessageEncoder.encode(wrappedData1)).block()));
199+
allActualData.write(Objects
200+
.requireNonNull(collectBytesInByteBufferStream(structuredMessageEncoder.encode(wrappedData2)).block()));
201+
allActualData.write(Objects
202+
.requireNonNull(collectBytesInByteBufferStream(structuredMessageEncoder.encode(wrappedData3)).block()));
197203

198204
Assertions.assertArrayEquals(expected, allActualData.toByteArray());
199205
}
200206

201207
@Test
202-
public void emptyBuffer() throws IOException {
208+
public void emptyBuffer() {
203209
StructuredMessageEncoder encoder = new StructuredMessageEncoder(10, 5, StructuredMessageFlags.NONE);
204210
ByteBuffer emptyBuffer = ByteBuffer.allocate(0);
205-
ByteBuffer result = encoder.encode(emptyBuffer);
211+
ByteBuffer result = ByteBuffer
212+
.wrap(Objects.requireNonNull(collectBytesInByteBufferStream(encoder.encode(emptyBuffer)).block()));
206213
assertEquals(0, result.remaining());
207214
}
208215

209216
@Test
210-
public void contentAlreadyEncoded() throws IOException {
217+
public void contentAlreadyEncoded() {
211218
StructuredMessageEncoder encoder = new StructuredMessageEncoder(4, 2, StructuredMessageFlags.NONE);
212-
encoder.encode(ByteBuffer.wrap(new byte[] { 1, 2, 3, 4 }));
213-
assertThrows(IllegalArgumentException.class, () -> encoder.encode(ByteBuffer.wrap(new byte[] { 1, 2 })));
219+
encoder.encode(ByteBuffer.wrap(new byte[] { 1, 2, 3, 4 })).blockLast();
220+
StepVerifier.create(encoder.encode(ByteBuffer.wrap(new byte[] { 1, 2 })))
221+
.expectError(IllegalArgumentException.class)
222+
.verify();
214223
}
215224

216225
@Test
217-
public void bufferLengthExceedsContentLength() throws IOException {
226+
public void bufferLengthExceedsContentLength() {
218227
StructuredMessageEncoder encoder = new StructuredMessageEncoder(4, 2, StructuredMessageFlags.NONE);
219-
encoder.encode(ByteBuffer.wrap(new byte[] { 1, 2, 3 }));
220-
assertThrows(IllegalArgumentException.class, () -> encoder.encode(ByteBuffer.wrap(new byte[] { 1, 2 })));
228+
encoder.encode(ByteBuffer.wrap(new byte[] { 1, 2, 3 })).blockLast();
229+
StepVerifier.create(encoder.encode(ByteBuffer.wrap(new byte[] { 1, 2 })))
230+
.expectError(IllegalArgumentException.class)
231+
.verify();
221232
}
222233

223234
@Test

0 commit comments

Comments
 (0)