Skip to content

Commit 12360e4

Browse files
Fix smart retry to properly track consumed encoded bytes
The original implementation tracked received buffer sizes before decoding, which didn't account for pending bytes carried over from previous buffers. This caused retry offsets to be incorrect. The decoder now tracks only newly consumed bytes, after accounting for pending data.

Co-authored-by: gunjansingh-msft <[email protected]>
1 parent 3e37e24 commit 12360e4

File tree

2 files changed

+50
-56
lines changed

2 files changed

+50
-56
lines changed

sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/specialized/BlobAsyncClientBase.java

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,28 +1395,31 @@ Mono<BlobDownloadAsyncResponse> downloadStreamWithResponse(BlobRange range, Down
13951395
Context retryContext = firstRangeContext;
13961396
BlobRange retryRange;
13971397

1398-
// If structured message decoding is enabled, we need to restart from the beginning
1399-
// because the decoder must parse the complete structured message from the start
1398+
// If structured message decoding is enabled, we need to calculate the retry offset
1399+
// based on the encoded bytes processed, not the decoded bytes
14001400
if (contentValidationOptions != null
14011401
&& contentValidationOptions.isStructuredMessageValidationEnabled()) {
1402-
// Get the decoder state to determine how many decoded bytes were already emitted
1402+
// Get the decoder state to determine how many encoded bytes were processed
14031403
Object decoderStateObj
14041404
= firstRangeContext.getData(Constants.STRUCTURED_MESSAGE_DECODER_STATE_CONTEXT_KEY)
14051405
.orElse(null);
14061406

1407-
// For structured message validation, we must restart from the beginning
1408-
// because the message has headers and sequential segment numbers that must
1409-
// be parsed in order. We cannot resume parsing mid-stream.
1410-
retryRange = new BlobRange(initialOffset, finalCount);
1411-
1412-
// DO NOT preserve decoder state - create a fresh decoder for the retry
1413-
// The policy will track how many decoded bytes to skip
14141407
if (decoderStateObj instanceof StorageContentValidationDecoderPolicy.DecoderState) {
14151408
DecoderState decoderState = (DecoderState) decoderStateObj;
1416-
// Add the current decoded offset so the policy knows how many bytes to skip
1417-
retryContext = retryContext.addData(
1418-
Constants.STRUCTURED_MESSAGE_DECODED_BYTES_TO_SKIP_CONTEXT_KEY,
1419-
decoderState.getTotalBytesDecoded());
1409+
1410+
// Use totalEncodedBytesProcessed to request NEW bytes from the server
1411+
// The pending buffer already contains bytes we've received, so we request
1412+
// starting from the next byte after what we've already received
1413+
long encodedOffset = decoderState.getTotalEncodedBytesProcessed();
1414+
long remainingCount = finalCount - encodedOffset;
1415+
retryRange = new BlobRange(initialOffset + encodedOffset, remainingCount);
1416+
1417+
// Preserve the decoder state for the retry
1418+
retryContext = retryContext
1419+
.addData(Constants.STRUCTURED_MESSAGE_DECODER_STATE_CONTEXT_KEY, decoderState);
1420+
} else {
1421+
// No decoder state yet, use the normal retry logic
1422+
retryRange = new BlobRange(initialOffset + offset, newCount);
14201423
}
14211424
} else {
14221425
// For non-structured downloads, use smart retry from the interrupted offset

sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/policy/StorageContentValidationDecoderPolicy.java

Lines changed: 33 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,8 @@ public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineN
6060
Long contentLength = getContentLength(httpResponse.getHeaders());
6161

6262
if (contentLength != null && contentLength > 0 && validationOptions != null) {
63-
// Check if this is a retry - if so, get the number of decoded bytes to skip
64-
long bytesToSkip = context.getData(Constants.STRUCTURED_MESSAGE_DECODED_BYTES_TO_SKIP_CONTEXT_KEY)
65-
.filter(value -> value instanceof Long)
66-
.map(value -> (Long) value)
67-
.orElse(0L);
68-
69-
// Always create a fresh decoder for each request
70-
// This is necessary because structured messages must be parsed from the beginning
71-
DecoderState decoderState = new DecoderState(contentLength, bytesToSkip);
63+
// Get or create decoder with state tracking
64+
DecoderState decoderState = getOrCreateDecoderState(context, contentLength);
7265

7366
// Decode using the stateful decoder
7467
Flux<ByteBuffer> decodedStream = decodeStream(httpResponse.getBody(), decoderState);
@@ -92,13 +85,12 @@ public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineN
9285
*/
9386
private Flux<ByteBuffer> decodeStream(Flux<ByteBuffer> encodedFlux, DecoderState state) {
9487
return encodedFlux.concatMap(encodedBuffer -> {
88+
// Track how many bytes were pending before we process
89+
int previousPendingBytes = (state.pendingBuffer != null) ? state.pendingBuffer.remaining() : 0;
90+
9591
// Combine with pending data if any
9692
ByteBuffer dataToProcess = state.combineWithPending(encodedBuffer);
9793

98-
// Track encoded bytes
99-
int encodedBytesInBuffer = encodedBuffer.remaining();
100-
state.totalEncodedBytesProcessed.addAndGet(encodedBytesInBuffer);
101-
10294
try {
10395
// Try to decode what we have - decoder handles partial data
10496
// Create duplicate for decoder - it will advance the duplicate's position as it reads
@@ -113,6 +105,14 @@ private Flux<ByteBuffer> decodeStream(Flux<ByteBuffer> encodedFlux, DecoderState
113105
int bytesConsumed = duplicateForDecode.position() - initialPosition;
114106
int bytesRemaining = availableSize - bytesConsumed;
115107

108+
// Track the newly consumed encoded bytes (excluding previously pending bytes)
109+
// The consumed bytes include both old pending bytes and new bytes from this buffer
110+
// We only want to add the NEW bytes that were consumed
111+
int newBytesConsumed = bytesConsumed - previousPendingBytes;
112+
if (newBytesConsumed > 0) {
113+
state.totalEncodedBytesProcessed.addAndGet(newBytesConsumed);
114+
}
115+
116116
// Save only unconsumed portion to pending
117117
if (bytesRemaining > 0) {
118118
// Position the original buffer to skip consumed bytes, then slice to get unconsumed
@@ -124,33 +124,13 @@ private Flux<ByteBuffer> decodeStream(Flux<ByteBuffer> encodedFlux, DecoderState
124124
state.pendingBuffer = null;
125125
}
126126

127-
// Handle skipping bytes for retries and tracking decoded bytes
127+
// Track decoded bytes
128128
int decodedBytes = decodedData.remaining();
129+
state.totalBytesDecoded.addAndGet(decodedBytes);
130+
131+
// Return decoded data if any
129132
if (decodedBytes > 0) {
130-
// Track total decoded bytes
131-
long totalDecoded = state.totalBytesDecoded.addAndGet(decodedBytes);
132-
133-
// If we need to skip bytes (retry scenario), adjust the buffer
134-
if (state.bytesToSkip > 0) {
135-
long currentPosition = totalDecoded - decodedBytes; // Where we were before adding these bytes
136-
137-
if (currentPosition + decodedBytes <= state.bytesToSkip) {
138-
// All these bytes should be skipped
139-
return Flux.empty();
140-
} else if (currentPosition < state.bytesToSkip) {
141-
// Some bytes should be skipped
142-
int skipAmount = (int) (state.bytesToSkip - currentPosition);
143-
decodedData.position(decodedData.position() + skipAmount);
144-
}
145-
// else: no bytes need to be skipped, emit all
146-
}
147-
148-
// Return decoded data if any remains after skipping
149-
if (decodedData.hasRemaining()) {
150-
return Flux.just(decodedData);
151-
} else {
152-
return Flux.empty();
153-
}
133+
return Flux.just(decodedData);
154134
} else {
155135
return Flux.empty();
156136
}
@@ -226,6 +206,20 @@ private Long getContentLength(HttpHeaders headers) {
226206
return null;
227207
}
228208

209+
/**
210+
* Gets or creates a decoder state from context.
211+
*
212+
* @param context The pipeline call context.
213+
* @param contentLength The content length.
214+
* @return The decoder state.
215+
*/
216+
private DecoderState getOrCreateDecoderState(HttpPipelineCallContext context, long contentLength) {
217+
return context.getData(Constants.STRUCTURED_MESSAGE_DECODER_STATE_CONTEXT_KEY)
218+
.filter(value -> value instanceof DecoderState)
219+
.map(value -> (DecoderState) value)
220+
.orElseGet(() -> new DecoderState(contentLength));
221+
}
222+
229223
/**
230224
* Checks if the response is a download response.
231225
*
@@ -246,21 +240,18 @@ public static class DecoderState {
246240
private final long expectedContentLength;
247241
private final AtomicLong totalBytesDecoded;
248242
private final AtomicLong totalEncodedBytesProcessed;
249-
private final long bytesToSkip;
250243
private ByteBuffer pendingBuffer;
251244

252245
/**
253246
* Creates a new decoder state.
254247
*
255248
* @param expectedContentLength The expected length of the encoded content.
256-
* @param bytesToSkip The number of decoded bytes to skip (for retry scenarios).
257249
*/
258-
public DecoderState(long expectedContentLength, long bytesToSkip) {
250+
public DecoderState(long expectedContentLength) {
259251
this.expectedContentLength = expectedContentLength;
260252
this.decoder = new StructuredMessageDecoder(expectedContentLength);
261253
this.totalBytesDecoded = new AtomicLong(0);
262254
this.totalEncodedBytesProcessed = new AtomicLong(0);
263-
this.bytesToSkip = bytesToSkip;
264255
this.pendingBuffer = null;
265256
}
266257

0 commit comments

Comments
 (0)