Skip to content

Commit a527131

Browse files
committed
add a simple, passing unit test
1 parent 0729aff commit a527131

File tree

1 file changed

+281
-0
lines changed

1 file changed

+281
-0
lines changed
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package software.amazon.encryption.s3.internal;
4+
5+
import org.junit.jupiter.api.Test;
6+
import org.reactivestreams.Subscriber;
7+
import org.reactivestreams.Subscription;
8+
import software.amazon.encryption.s3.algorithms.AlgorithmSuite;
9+
import software.amazon.encryption.s3.materials.DecryptionMaterials;
10+
import software.amazon.encryption.s3.materials.EncryptionMaterials;
11+
12+
import javax.crypto.KeyGenerator;
13+
import javax.crypto.SecretKey;
14+
import java.nio.ByteBuffer;
15+
import java.nio.charset.StandardCharsets;
16+
import java.security.NoSuchAlgorithmException;
17+
import java.util.ArrayList;
18+
import java.util.LinkedList;
19+
import java.util.List;
20+
import java.util.concurrent.atomic.AtomicBoolean;
21+
import java.util.concurrent.atomic.AtomicLong;
22+
23+
import static org.junit.jupiter.api.Assertions.assertEquals;
24+
import static org.junit.jupiter.api.Assertions.assertTrue;
25+
26+
class CipherSubscriberTest {
27+
// Helper classes for testing
28+
class MySubscriber implements Subscriber<ByteBuffer> {
29+
30+
public static final long DEFAULT_REQUEST_SIZE = 1;
31+
32+
private final AtomicBoolean isSubscribed = new AtomicBoolean(false);
33+
private final AtomicLong requestedItems = new AtomicLong(0);
34+
private final AtomicLong lengthOfData = new AtomicLong(0);
35+
private LinkedList<ByteBuffer> buffersSeen = new LinkedList<>();
36+
private Subscription subscription;
37+
38+
@Override
39+
public void onSubscribe(Subscription s) {
40+
if (isSubscribed.compareAndSet(false, true)) {
41+
this.subscription = s;
42+
requestMore(DEFAULT_REQUEST_SIZE);
43+
} else {
44+
s.cancel(); // Cancel the new subscription if we're already subscribed
45+
}
46+
}
47+
48+
@Override
49+
public void onNext(ByteBuffer item) {
50+
// Process the item here
51+
System.out.println("Received: " + item);
52+
lengthOfData.addAndGet(item.capacity());
53+
buffersSeen.add(item);
54+
55+
// Request the next item
56+
requestMore(1);
57+
}
58+
59+
@Override
60+
public void onError(Throwable t) {
61+
System.err.println("Error occurred: " + t.getMessage());
62+
}
63+
64+
@Override
65+
public void onComplete() {
66+
System.out.println("Stream completed");
67+
}
68+
69+
public void cancel() {
70+
if (isSubscribed.getAndSet(false)) {
71+
subscription.cancel();
72+
}
73+
}
74+
75+
private void requestMore(long n) {
76+
if (subscription != null) {
77+
System.out.println("Requesting more...");
78+
requestedItems.addAndGet(n);
79+
subscription.request(n);
80+
}
81+
}
82+
83+
// Getter methods for testing
84+
public boolean isSubscribed() {
85+
return isSubscribed.get();
86+
}
87+
88+
public List<ByteBuffer> getBuffersSeen() {
89+
return buffersSeen;
90+
}
91+
}
92+
93+
class TestPublisher<T> {
94+
private List<Subscriber<T>> subscribers = new ArrayList<>();
95+
96+
public void subscribe(Subscriber<T> subscriber) {
97+
subscribers.add(subscriber);
98+
subscriber.onSubscribe(new TestSubscription());
99+
}
100+
101+
public void emit(T item) {
102+
subscribers.forEach(s -> s.onNext(item));
103+
}
104+
105+
public void complete() {
106+
subscribers.forEach(Subscriber::onComplete);
107+
}
108+
109+
public boolean isSubscribed() {
110+
return !subscribers.isEmpty();
111+
}
112+
113+
public int getSubscriberCount() {
114+
return subscribers.size();
115+
}
116+
}
117+
118+
class TestSubscription implements Subscription {
119+
private long requestedItems = 0;
120+
121+
@Override
122+
public void request(long n) {
123+
System.out.println("received req for " + n);
124+
requestedItems += n;
125+
System.out.println("total req'd items is " + requestedItems);
126+
}
127+
128+
@Override
129+
public void cancel() {
130+
// Implementation for testing cancel behavior
131+
}
132+
133+
public long getRequestedItems() {
134+
return requestedItems;
135+
}
136+
}
137+
138+
private EncryptionMaterials getTestEncryptMaterials(String plaintext) {
139+
try {
140+
SecretKey AES_KEY;
141+
KeyGenerator keyGen = KeyGenerator.getInstance("AES");
142+
keyGen.init(256);
143+
AES_KEY = keyGen.generateKey();
144+
return EncryptionMaterials.builder()
145+
.plaintextDataKey(AES_KEY.getEncoded())
146+
.algorithmSuite(AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF)
147+
.plaintextLength(plaintext.getBytes(StandardCharsets.UTF_8).length)
148+
.build();
149+
} catch (NoSuchAlgorithmException exception) {
150+
// this should never happen
151+
throw new RuntimeException("AES doesn't exist");
152+
}
153+
}
154+
155+
private DecryptionMaterials getTestDecryptionMaterialsFromEncMats(EncryptionMaterials encMats) {
156+
return DecryptionMaterials.builder()
157+
.plaintextDataKey(encMats.plaintextDataKey())
158+
.algorithmSuite(AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF)
159+
.ciphertextLength(encMats.getCiphertextLength())
160+
.build();
161+
}
162+
163+
@Test
164+
public void testSubscriberBehavior() throws InterruptedException {
165+
String plaintext = "unit test of cipher subscriber";
166+
EncryptionMaterials materials = getTestEncryptMaterials(plaintext);
167+
byte[] iv = new byte[materials.algorithmSuite().iVLengthBytes()];
168+
// we reject 0-ized IVs, so just do something
169+
iv[0] = 1;
170+
MySubscriber wrappedSubscriber = new MySubscriber();
171+
CipherSubscriber subscriber = new CipherSubscriber(wrappedSubscriber, (long) plaintext.getBytes(StandardCharsets.UTF_8).length, materials, iv);
172+
173+
// Arrange
174+
// TODO: These need to be moved probably to the wrappedSubscriber,
175+
// so they are actually updated as the subscription is processed.
176+
// CountDownLatch completionLatch = new CountDownLatch(1);
177+
// AtomicInteger receivedItems = new AtomicInteger(0);
178+
// AtomicInteger errorCount = new AtomicInteger(0);
179+
180+
// Act
181+
TestPublisher<ByteBuffer> publisher = new TestPublisher<>();
182+
publisher.subscribe(subscriber);
183+
184+
// Verify subscription behavior
185+
assertTrue(publisher.isSubscribed());
186+
assertEquals(1, publisher.getSubscriberCount());
187+
188+
// Simulate publishing items
189+
// publisher.emit("item1");
190+
ByteBuffer ptBb = ByteBuffer.wrap(plaintext.getBytes(StandardCharsets.UTF_8));
191+
System.out.println("emitting...");
192+
publisher.emit(ptBb);
193+
System.out.println("emitted");
194+
195+
// Complete the stream
196+
System.out.println("completing...");
197+
publisher.complete();
198+
System.out.println("completed.");
199+
200+
// Assert
201+
// assertTrue(completionLatch.await(5, TimeUnit.SECONDS));
202+
// assertEquals(1, wrappedSubscriber.getRequestedItems());
203+
// assertEquals(0, errorCount.get());
204+
long expectedLength = plaintext.getBytes(StandardCharsets.UTF_8).length + AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF.cipherTagLengthBytes();
205+
assertEquals(expectedLength, wrappedSubscriber.lengthOfData.get());
206+
byte[] ctBytes = new byte[(int) expectedLength];
207+
int offset = 0;
208+
for (ByteBuffer bb : wrappedSubscriber.getBuffersSeen()) {
209+
int remaining = bb.remaining();
210+
bb.get(ctBytes, offset, remaining);
211+
offset += remaining;
212+
}
213+
214+
// Now decrypt.
215+
DecryptionMaterials decryptionMaterials = getTestDecryptionMaterialsFromEncMats(materials);
216+
MySubscriber wrappedDecryptSubscriber = new MySubscriber();
217+
CipherSubscriber decryptSubscriber = new CipherSubscriber(wrappedDecryptSubscriber, expectedLength, decryptionMaterials, iv);
218+
TestPublisher<ByteBuffer> decryptPublisher = new TestPublisher<>();
219+
decryptPublisher.subscribe(decryptSubscriber);
220+
221+
// Verify subscription behavior
222+
assertTrue(decryptPublisher.isSubscribed());
223+
assertEquals(1, decryptPublisher.getSubscriberCount());
224+
225+
// Simulate publishing items
226+
ByteBuffer ctBb = ByteBuffer.wrap(ctBytes);
227+
System.out.println("emitting...");
228+
decryptPublisher.emit(ctBb);
229+
System.out.println("emitted");
230+
231+
// Complete the stream
232+
System.out.println("completing...");
233+
decryptPublisher.complete();
234+
System.out.println("completed.");
235+
236+
// Assert
237+
long expectedLengthPt = plaintext.getBytes(StandardCharsets.UTF_8).length;
238+
assertEquals(expectedLengthPt, wrappedDecryptSubscriber.lengthOfData.get());
239+
byte[] ptBytes = new byte[(int) expectedLengthPt];
240+
int offsetPt = 0;
241+
for (ByteBuffer bb : wrappedDecryptSubscriber.getBuffersSeen()) {
242+
int remaining = bb.remaining();
243+
bb.get(ptBytes, offsetPt, remaining);
244+
offsetPt += remaining;
245+
}
246+
// Round trip encrypt/decrypt succeeds.
247+
assertEquals(plaintext, new String(ptBytes, StandardCharsets.UTF_8));
248+
}
249+
250+
//// @Test
251+
// void testBackpressure() {
252+
// // Arrange
253+
// CipherSubscriber<ByteBuffer> subscriber = new CipherSubscriber(wrappedSubscriber, contentLength, materials, iv);
254+
// TestSubscription subscription = new TestSubscription();
255+
//
256+
// // Act
257+
// subscriber.onSubscribe(subscription);
258+
//
259+
// // Assert
260+
// assertEquals(TestSubscriber.DEFAULT_REQUEST_SIZE, subscription.getRequestedItems());
261+
// }
262+
//
263+
//// @Test
264+
// void testErrorHandling() {
265+
// // Arrange
266+
// AtomicInteger errorCount = new AtomicInteger(0);
267+
// MySubscriber<String> subscriber = new MySubscriber<>() {
268+
// @Override
269+
// public void onError(Throwable t) {
270+
// errorCount.incrementAndGet();
271+
// }
272+
// };
273+
//
274+
// // Act
275+
// subscriber.onError(new RuntimeException("Test error"));
276+
//
277+
// // Assert
278+
// assertEquals(1, errorCount.get());
279+
// }
280+
}
281+

0 commit comments

Comments
 (0)