Skip to content

Commit 3a8f433

Browse files
committed
cleanup tests, add repro case
1 parent a527131 commit 3a8f433

File tree

2 files changed

+108
-92
lines changed

2 files changed

+108
-92
lines changed

src/main/java/software/amazon/encryption/s3/internal/CipherSubscriber.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ public void onNext(ByteBuffer byteBuffer) {
108108
violates the Reactive Streams specification and can cause exceptions downstream.
109109
*/
110110
System.out.println("[CipherSubscriber] Checking content read threshold: contentRead=" + contentRead.get() + ", tagLength=" + tagLength + ", contentLength=" + contentLength);
111-
if (contentRead.get() + tagLength >= contentLength) {
111+
if (contentRead.get() + (isEncrypt ? tagLength : 0) >= contentLength) {
112112
// All content has been read; complete the stream.
113113
System.out.println("[CipherSubscriber] Content read threshold (" + contentRead.get() + ") reached, proceeding to finalBytes");
114114
finalBytes();

src/test/java/software/amazon/encryption/s3/internal/CipherSubscriberTest.java

Lines changed: 107 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@
2525

2626
class CipherSubscriberTest {
2727
// Helper classes for testing
28-
class MySubscriber implements Subscriber<ByteBuffer> {
28+
class SimpleSubscriber implements Subscriber<ByteBuffer> {
2929

3030
public static final long DEFAULT_REQUEST_SIZE = 1;
3131

3232
private final AtomicBoolean isSubscribed = new AtomicBoolean(false);
3333
private final AtomicLong requestedItems = new AtomicLong(0);
3434
private final AtomicLong lengthOfData = new AtomicLong(0);
35-
private LinkedList<ByteBuffer> buffersSeen = new LinkedList<>();
35+
private final LinkedList<ByteBuffer> buffersSeen = new LinkedList<>();
3636
private Subscription subscription;
3737

3838
@Override
@@ -41,14 +41,13 @@ public void onSubscribe(Subscription s) {
4141
this.subscription = s;
4242
requestMore(DEFAULT_REQUEST_SIZE);
4343
} else {
44-
s.cancel(); // Cancel the new subscription if we're already subscribed
44+
s.cancel();
4545
}
4646
}
4747

4848
@Override
4949
public void onNext(ByteBuffer item) {
5050
// Process the item here
51-
System.out.println("Received: " + item);
5251
lengthOfData.addAndGet(item.capacity());
5352
buffersSeen.add(item);
5453

@@ -63,7 +62,7 @@ public void onError(Throwable t) {
6362

6463
@Override
6564
public void onComplete() {
66-
System.out.println("Stream completed");
65+
// Do nothing.
6766
}
6867

6968
public void cancel() {
@@ -74,24 +73,18 @@ public void cancel() {
7473

7574
private void requestMore(long n) {
7675
if (subscription != null) {
77-
System.out.println("Requesting more...");
7876
requestedItems.addAndGet(n);
7977
subscription.request(n);
8078
}
8179
}
8280

83-
// Getter methods for testing
84-
public boolean isSubscribed() {
85-
return isSubscribed.get();
86-
}
87-
8881
public List<ByteBuffer> getBuffersSeen() {
8982
return buffersSeen;
9083
}
9184
}
9285

9386
class TestPublisher<T> {
94-
private List<Subscriber<T>> subscribers = new ArrayList<>();
87+
private final List<Subscriber<T>> subscribers = new ArrayList<>(1);
9588

9689
public void subscribe(Subscriber<T> subscriber) {
9790
subscribers.add(subscriber);
@@ -116,22 +109,27 @@ public int getSubscriberCount() {
116109
}
117110

118111
class TestSubscription implements Subscription {
119-
private long requestedItems = 0;
112+
private long requestCount = 0;
113+
private final AtomicBoolean canceled = new AtomicBoolean(false);
120114

121115
@Override
122116
public void request(long n) {
123-
System.out.println("received req for " + n);
124-
requestedItems += n;
125-
System.out.println("total req'd items is " + requestedItems);
117+
if (!canceled.get()) {
118+
requestCount += n;
119+
} else {
120+
// Maybe do something more useful/correct eventually,
121+
// for now just throw an exception
122+
throw new RuntimeException("Subscription has been canceled!");
123+
}
126124
}
127125

128126
@Override
129127
public void cancel() {
130-
// Implementation for testing cancel behavior
128+
canceled.set(true);
131129
}
132130

133-
public long getRequestedItems() {
134-
return requestedItems;
131+
public long getRequestCount() {
132+
return requestCount;
135133
}
136134
}
137135

@@ -160,23 +158,35 @@ private DecryptionMaterials getTestDecryptionMaterialsFromEncMats(EncryptionMate
160158
.build();
161159
}
162160

161+
private byte[] getByteArrayFromFixedLengthByteBuffers(List<ByteBuffer> byteBuffers, long expectedLength) {
162+
if (expectedLength > Integer.MAX_VALUE) {
163+
throw new RuntimeException("Use a smaller expected length.");
164+
}
165+
return getByteArrayFromFixedLengthByteBuffers(byteBuffers, (int) expectedLength);
166+
}
167+
168+
private byte[] getByteArrayFromFixedLengthByteBuffers(List<ByteBuffer> byteBuffers, int expectedLength) {
169+
byte[] bytes = new byte[expectedLength];
170+
int offset = 0;
171+
for (ByteBuffer bb : byteBuffers) {
172+
int remaining = bb.remaining();
173+
bb.get(bytes, offset, remaining);
174+
offset += remaining;
175+
}
176+
return bytes;
177+
}
178+
163179
@Test
164-
public void testSubscriberBehavior() throws InterruptedException {
180+
public void testSubscriberBehaviorOneChunk() {
181+
AlgorithmSuite algorithmSuite = AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF;
165182
String plaintext = "unit test of cipher subscriber";
166183
EncryptionMaterials materials = getTestEncryptMaterials(plaintext);
167184
byte[] iv = new byte[materials.algorithmSuite().iVLengthBytes()];
168185
// we reject 0-ized IVs, so just do something
169186
iv[0] = 1;
170-
MySubscriber wrappedSubscriber = new MySubscriber();
187+
SimpleSubscriber wrappedSubscriber = new SimpleSubscriber();
171188
CipherSubscriber subscriber = new CipherSubscriber(wrappedSubscriber, (long) plaintext.getBytes(StandardCharsets.UTF_8).length, materials, iv);
172189

173-
// Arrange
174-
// TODO: These need to be moved probably to the wrappedSubscriber,
175-
// so they are actually updated as the subscription is processed.
176-
// CountDownLatch completionLatch = new CountDownLatch(1);
177-
// AtomicInteger receivedItems = new AtomicInteger(0);
178-
// AtomicInteger errorCount = new AtomicInteger(0);
179-
180190
// Act
181191
TestPublisher<ByteBuffer> publisher = new TestPublisher<>();
182192
publisher.subscribe(subscriber);
@@ -185,35 +195,19 @@ public void testSubscriberBehavior() throws InterruptedException {
185195
assertTrue(publisher.isSubscribed());
186196
assertEquals(1, publisher.getSubscriberCount());
187197

188-
// Simulate publishing items
189-
// publisher.emit("item1");
190198
ByteBuffer ptBb = ByteBuffer.wrap(plaintext.getBytes(StandardCharsets.UTF_8));
191-
System.out.println("emitting...");
192199
publisher.emit(ptBb);
193-
System.out.println("emitted");
194200

195201
// Complete the stream
196-
System.out.println("completing...");
197202
publisher.complete();
198-
System.out.println("completed.");
199203

200-
// Assert
201-
// assertTrue(completionLatch.await(5, TimeUnit.SECONDS));
202-
// assertEquals(1, wrappedSubscriber.getRequestedItems());
203-
// assertEquals(0, errorCount.get());
204-
long expectedLength = plaintext.getBytes(StandardCharsets.UTF_8).length + AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF.cipherTagLengthBytes();
204+
long expectedLength = plaintext.getBytes(StandardCharsets.UTF_8).length + algorithmSuite.cipherTagLengthBytes();
205205
assertEquals(expectedLength, wrappedSubscriber.lengthOfData.get());
206-
byte[] ctBytes = new byte[(int) expectedLength];
207-
int offset = 0;
208-
for (ByteBuffer bb : wrappedSubscriber.getBuffersSeen()) {
209-
int remaining = bb.remaining();
210-
bb.get(ctBytes, offset, remaining);
211-
offset += remaining;
212-
}
206+
byte[] ctBytes = getByteArrayFromFixedLengthByteBuffers(wrappedSubscriber.getBuffersSeen(), expectedLength);
213207

214208
// Now decrypt.
215209
DecryptionMaterials decryptionMaterials = getTestDecryptionMaterialsFromEncMats(materials);
216-
MySubscriber wrappedDecryptSubscriber = new MySubscriber();
210+
SimpleSubscriber wrappedDecryptSubscriber = new SimpleSubscriber();
217211
CipherSubscriber decryptSubscriber = new CipherSubscriber(wrappedDecryptSubscriber, expectedLength, decryptionMaterials, iv);
218212
TestPublisher<ByteBuffer> decryptPublisher = new TestPublisher<>();
219213
decryptPublisher.subscribe(decryptSubscriber);
@@ -224,58 +218,80 @@ public void testSubscriberBehavior() throws InterruptedException {
224218

225219
// Simulate publishing items
226220
ByteBuffer ctBb = ByteBuffer.wrap(ctBytes);
227-
System.out.println("emitting...");
228221
decryptPublisher.emit(ctBb);
229-
System.out.println("emitted");
230222

231223
// Complete the stream
232-
System.out.println("completing...");
233224
decryptPublisher.complete();
234-
System.out.println("completed.");
235225

236-
// Assert
237226
long expectedLengthPt = plaintext.getBytes(StandardCharsets.UTF_8).length;
238227
assertEquals(expectedLengthPt, wrappedDecryptSubscriber.lengthOfData.get());
239-
byte[] ptBytes = new byte[(int) expectedLengthPt];
240-
int offsetPt = 0;
241-
for (ByteBuffer bb : wrappedDecryptSubscriber.getBuffersSeen()) {
242-
int remaining = bb.remaining();
243-
bb.get(ptBytes, offsetPt, remaining);
244-
offsetPt += remaining;
245-
}
246-
// Round trip encrypt/decrypt succeeds.
228+
byte[] ptBytes = getByteArrayFromFixedLengthByteBuffers(wrappedDecryptSubscriber.getBuffersSeen(), expectedLengthPt);
229+
// Assert round trip encrypt/decrypt succeeds.
247230
assertEquals(plaintext, new String(ptBytes, StandardCharsets.UTF_8));
248231
}
249232

250-
//// @Test
251-
// void testBackpressure() {
252-
// // Arrange
253-
// CipherSubscriber<ByteBuffer> subscriber = new CipherSubscriber(wrappedSubscriber, contentLength, materials, iv);
254-
// TestSubscription subscription = new TestSubscription();
255-
//
256-
// // Act
257-
// subscriber.onSubscribe(subscription);
258-
//
259-
// // Assert
260-
// assertEquals(TestSubscriber.DEFAULT_REQUEST_SIZE, subscription.getRequestedItems());
261-
// }
262-
//
263-
//// @Test
264-
// void testErrorHandling() {
265-
// // Arrange
266-
// AtomicInteger errorCount = new AtomicInteger(0);
267-
// MySubscriber<String> subscriber = new MySubscriber<>() {
268-
// @Override
269-
// public void onError(Throwable t) {
270-
// errorCount.incrementAndGet();
271-
// }
272-
// };
273-
//
274-
// // Act
275-
// subscriber.onError(new RuntimeException("Test error"));
276-
//
277-
// // Assert
278-
// assertEquals(1, errorCount.get());
279-
// }
280-
}
233+
@Test
234+
public void testSubscriberBehaviorTagLengthLastChunk() {
235+
AlgorithmSuite algorithmSuite = AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF;
236+
String plaintext = "unit test of cipher subscriber tag length last chunk";
237+
EncryptionMaterials materials = getTestEncryptMaterials(plaintext);
238+
byte[] iv = new byte[materials.algorithmSuite().iVLengthBytes()];
239+
// we reject 0-ized IVs, so just do something non-zero
240+
iv[0] = 1;
241+
SimpleSubscriber wrappedSubscriber = new SimpleSubscriber();
242+
CipherSubscriber subscriber = new CipherSubscriber(wrappedSubscriber, (long) plaintext.getBytes(StandardCharsets.UTF_8).length, materials, iv);
243+
244+
// Setup Publisher
245+
TestPublisher<ByteBuffer> publisher = new TestPublisher<>();
246+
publisher.subscribe(subscriber);
247+
248+
// Verify subscription behavior
249+
assertTrue(publisher.isSubscribed());
250+
assertEquals(1, publisher.getSubscriberCount());
281251

252+
// Send data to be encrypted
253+
ByteBuffer ptBb = ByteBuffer.wrap(plaintext.getBytes(StandardCharsets.UTF_8));
254+
publisher.emit(ptBb);
255+
publisher.complete();
256+
257+
// Convert to byte array for convenience
258+
long expectedLength = plaintext.getBytes(StandardCharsets.UTF_8).length + algorithmSuite.cipherTagLengthBytes();
259+
assertEquals(expectedLength, wrappedSubscriber.lengthOfData.get());
260+
byte[] ctBytes = getByteArrayFromFixedLengthByteBuffers(wrappedSubscriber.getBuffersSeen(), expectedLength);
261+
262+
// Now decrypt the ciphertext
263+
DecryptionMaterials decryptionMaterials = getTestDecryptionMaterialsFromEncMats(materials);
264+
SimpleSubscriber wrappedDecryptSubscriber = new SimpleSubscriber();
265+
CipherSubscriber decryptSubscriber = new CipherSubscriber(wrappedDecryptSubscriber, expectedLength, decryptionMaterials, iv);
266+
TestPublisher<ByteBuffer> decryptPublisher = new TestPublisher<>();
267+
decryptPublisher.subscribe(decryptSubscriber);
268+
269+
// Verify subscription behavior
270+
assertTrue(decryptPublisher.isSubscribed());
271+
assertEquals(1, decryptPublisher.getSubscriberCount());
272+
273+
int taglength = algorithmSuite.cipherTagLengthBytes();
274+
int ciphertextWithoutTagLength = ctBytes.length - taglength;
275+
276+
// Create the main ByteBuffer (all except last 16 bytes)
277+
ByteBuffer mainBuffer = ByteBuffer.allocate(ciphertextWithoutTagLength);
278+
mainBuffer.put(ctBytes, 0, ciphertextWithoutTagLength);
279+
mainBuffer.flip();
280+
281+
// Create the tag ByteBuffer (last 16 bytes)
282+
ByteBuffer tagBuffer = ByteBuffer.allocate(taglength);
283+
tagBuffer.put(ctBytes, ciphertextWithoutTagLength, taglength);
284+
tagBuffer.flip();
285+
286+
// Send the ciphertext, then the tag separately
287+
decryptPublisher.emit(mainBuffer);
288+
decryptPublisher.emit(tagBuffer);
289+
decryptPublisher.complete();
290+
291+
long expectedLengthPt = plaintext.getBytes(StandardCharsets.UTF_8).length;
292+
assertEquals(expectedLengthPt, wrappedDecryptSubscriber.lengthOfData.get());
293+
byte[] ptBytes = getByteArrayFromFixedLengthByteBuffers(wrappedDecryptSubscriber.getBuffersSeen(), expectedLengthPt);
294+
// Assert round trip encrypt/decrypt succeeds
295+
assertEquals(plaintext, new String(ptBytes, StandardCharsets.UTF_8));
296+
}
297+
}

0 commit comments

Comments
 (0)