1+ // Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
2+ // SPDX-License-Identifier: Apache-2.0
3+ package software .amazon .encryption .s3 .internal ;
4+
5+ import org .junit .jupiter .api .Test ;
6+ import org .reactivestreams .Subscriber ;
7+ import org .reactivestreams .Subscription ;
8+ import software .amazon .encryption .s3 .algorithms .AlgorithmSuite ;
9+ import software .amazon .encryption .s3 .materials .DecryptionMaterials ;
10+ import software .amazon .encryption .s3 .materials .EncryptionMaterials ;
11+
12+ import javax .crypto .KeyGenerator ;
13+ import javax .crypto .SecretKey ;
14+ import java .nio .ByteBuffer ;
15+ import java .nio .charset .StandardCharsets ;
16+ import java .security .NoSuchAlgorithmException ;
17+ import java .util .ArrayList ;
18+ import java .util .LinkedList ;
19+ import java .util .List ;
20+ import java .util .concurrent .atomic .AtomicBoolean ;
21+ import java .util .concurrent .atomic .AtomicLong ;
22+
23+ import static org .junit .jupiter .api .Assertions .assertEquals ;
24+ import static org .junit .jupiter .api .Assertions .assertTrue ;
25+
26+ class CipherSubscriberTest {
27+ // Helper classes for testing
28+ class SimpleSubscriber implements Subscriber <ByteBuffer > {
29+
30+ public static final long DEFAULT_REQUEST_SIZE = 1 ;
31+
32+ private final AtomicBoolean isSubscribed = new AtomicBoolean (false );
33+ private final AtomicLong requestedItems = new AtomicLong (0 );
34+ private final AtomicLong lengthOfData = new AtomicLong (0 );
35+ private final LinkedList <ByteBuffer > buffersSeen = new LinkedList <>();
36+ private Subscription subscription ;
37+
38+ @ Override
39+ public void onSubscribe (Subscription s ) {
40+ if (isSubscribed .compareAndSet (false , true )) {
41+ this .subscription = s ;
42+ requestMore (DEFAULT_REQUEST_SIZE );
43+ } else {
44+ s .cancel ();
45+ }
46+ }
47+
48+ @ Override
49+ public void onNext (ByteBuffer item ) {
50+ // Process the item here
51+ lengthOfData .addAndGet (item .capacity ());
52+ buffersSeen .add (item );
53+
54+ // Request the next item
55+ requestMore (1 );
56+ }
57+
58+ @ Override
59+ public void onError (Throwable t ) {
60+ System .err .println ("Error occurred: " + t .getMessage ());
61+ }
62+
63+ @ Override
64+ public void onComplete () {
65+ // Do nothing.
66+ }
67+
68+ public void cancel () {
69+ if (isSubscribed .getAndSet (false )) {
70+ subscription .cancel ();
71+ }
72+ }
73+
74+ private void requestMore (long n ) {
75+ if (subscription != null ) {
76+ requestedItems .addAndGet (n );
77+ subscription .request (n );
78+ }
79+ }
80+
81+ public List <ByteBuffer > getBuffersSeen () {
82+ return buffersSeen ;
83+ }
84+ }
85+
86+ class TestPublisher <T > {
87+ private final List <Subscriber <T >> subscribers = new ArrayList <>(1 );
88+
89+ public void subscribe (Subscriber <T > subscriber ) {
90+ subscribers .add (subscriber );
91+ subscriber .onSubscribe (new TestSubscription ());
92+ }
93+
94+ public void emit (T item ) {
95+ subscribers .forEach (s -> s .onNext (item ));
96+ }
97+
98+ public void complete () {
99+ subscribers .forEach (Subscriber ::onComplete );
100+ }
101+
102+ public boolean isSubscribed () {
103+ return !subscribers .isEmpty ();
104+ }
105+
106+ public int getSubscriberCount () {
107+ return subscribers .size ();
108+ }
109+ }
110+
111+ class TestSubscription implements Subscription {
112+ private long requestCount = 0 ;
113+ private final AtomicBoolean canceled = new AtomicBoolean (false );
114+
115+ @ Override
116+ public void request (long n ) {
117+ if (!canceled .get ()) {
118+ requestCount += n ;
119+ } else {
120+ // Maybe do something more useful/correct eventually,
121+ // for now just throw an exception
122+ throw new RuntimeException ("Subscription has been canceled!" );
123+ }
124+ }
125+
126+ @ Override
127+ public void cancel () {
128+ canceled .set (true );
129+ }
130+
131+ public long getRequestCount () {
132+ return requestCount ;
133+ }
134+ }
135+
136+ private EncryptionMaterials getTestEncryptMaterials (String plaintext ) {
137+ try {
138+ SecretKey AES_KEY ;
139+ KeyGenerator keyGen = KeyGenerator .getInstance ("AES" );
140+ keyGen .init (256 );
141+ AES_KEY = keyGen .generateKey ();
142+ return EncryptionMaterials .builder ()
143+ .plaintextDataKey (AES_KEY .getEncoded ())
144+ .algorithmSuite (AlgorithmSuite .ALG_AES_256_GCM_IV12_TAG16_NO_KDF )
145+ .plaintextLength (plaintext .getBytes (StandardCharsets .UTF_8 ).length )
146+ .build ();
147+ } catch (NoSuchAlgorithmException exception ) {
148+ // this should never happen
149+ throw new RuntimeException ("AES doesn't exist" );
150+ }
151+ }
152+
153+ private DecryptionMaterials getTestDecryptionMaterialsFromEncMats (EncryptionMaterials encMats ) {
154+ return DecryptionMaterials .builder ()
155+ .plaintextDataKey (encMats .plaintextDataKey ())
156+ .algorithmSuite (AlgorithmSuite .ALG_AES_256_GCM_IV12_TAG16_NO_KDF )
157+ .ciphertextLength (encMats .getCiphertextLength ())
158+ .build ();
159+ }
160+
161+ private byte [] getByteArrayFromFixedLengthByteBuffers (List <ByteBuffer > byteBuffers , long expectedLength ) {
162+ if (expectedLength > Integer .MAX_VALUE ) {
163+ throw new RuntimeException ("Use a smaller expected length." );
164+ }
165+ return getByteArrayFromFixedLengthByteBuffers (byteBuffers , (int ) expectedLength );
166+ }
167+
168+ private byte [] getByteArrayFromFixedLengthByteBuffers (List <ByteBuffer > byteBuffers , int expectedLength ) {
169+ byte [] bytes = new byte [expectedLength ];
170+ int offset = 0 ;
171+ for (ByteBuffer bb : byteBuffers ) {
172+ int remaining = bb .remaining ();
173+ bb .get (bytes , offset , remaining );
174+ offset += remaining ;
175+ }
176+ return bytes ;
177+ }
178+
179+ @ Test
180+ public void testSubscriberBehaviorOneChunk () {
181+ AlgorithmSuite algorithmSuite = AlgorithmSuite .ALG_AES_256_GCM_IV12_TAG16_NO_KDF ;
182+ String plaintext = "unit test of cipher subscriber" ;
183+ EncryptionMaterials materials = getTestEncryptMaterials (plaintext );
184+ byte [] iv = new byte [materials .algorithmSuite ().iVLengthBytes ()];
185+ // we reject 0-ized IVs, so just do something
186+ iv [0 ] = 1 ;
187+ SimpleSubscriber wrappedSubscriber = new SimpleSubscriber ();
188+ CipherSubscriber subscriber = new CipherSubscriber (wrappedSubscriber , materials .getCiphertextLength (), materials , iv );
189+
190+ // Act
191+ TestPublisher <ByteBuffer > publisher = new TestPublisher <>();
192+ publisher .subscribe (subscriber );
193+
194+ // Verify subscription behavior
195+ assertTrue (publisher .isSubscribed ());
196+ assertEquals (1 , publisher .getSubscriberCount ());
197+
198+ ByteBuffer ptBb = ByteBuffer .wrap (plaintext .getBytes (StandardCharsets .UTF_8 ));
199+ publisher .emit (ptBb );
200+
201+ // Complete the stream
202+ publisher .complete ();
203+
204+ long expectedLength = plaintext .getBytes (StandardCharsets .UTF_8 ).length + algorithmSuite .cipherTagLengthBytes ();
205+ assertEquals (expectedLength , wrappedSubscriber .lengthOfData .get ());
206+ byte [] ctBytes = getByteArrayFromFixedLengthByteBuffers (wrappedSubscriber .getBuffersSeen (), expectedLength );
207+
208+ // Now decrypt.
209+ DecryptionMaterials decryptionMaterials = getTestDecryptionMaterialsFromEncMats (materials );
210+ SimpleSubscriber wrappedDecryptSubscriber = new SimpleSubscriber ();
211+ CipherSubscriber decryptSubscriber = new CipherSubscriber (wrappedDecryptSubscriber , expectedLength , decryptionMaterials , iv );
212+ TestPublisher <ByteBuffer > decryptPublisher = new TestPublisher <>();
213+ decryptPublisher .subscribe (decryptSubscriber );
214+
215+ // Verify subscription behavior
216+ assertTrue (decryptPublisher .isSubscribed ());
217+ assertEquals (1 , decryptPublisher .getSubscriberCount ());
218+
219+ // Simulate publishing items
220+ ByteBuffer ctBb = ByteBuffer .wrap (ctBytes );
221+ decryptPublisher .emit (ctBb );
222+
223+ // Complete the stream
224+ decryptPublisher .complete ();
225+
226+ long expectedLengthPt = plaintext .getBytes (StandardCharsets .UTF_8 ).length ;
227+ assertEquals (expectedLengthPt , wrappedDecryptSubscriber .lengthOfData .get ());
228+ byte [] ptBytes = getByteArrayFromFixedLengthByteBuffers (wrappedDecryptSubscriber .getBuffersSeen (), expectedLengthPt );
229+ // Assert round trip encrypt/decrypt succeeds.
230+ assertEquals (plaintext , new String (ptBytes , StandardCharsets .UTF_8 ));
231+ }
232+
233+ @ Test
234+ public void testSubscriberBehaviorTagLengthLastChunk () {
235+ AlgorithmSuite algorithmSuite = AlgorithmSuite .ALG_AES_256_GCM_IV12_TAG16_NO_KDF ;
236+ String plaintext = "unit test of cipher subscriber tag length last chunk" ;
237+ EncryptionMaterials materials = getTestEncryptMaterials (plaintext );
238+ byte [] iv = new byte [materials .algorithmSuite ().iVLengthBytes ()];
239+ // we reject 0-ized IVs, so just do something non-zero
240+ iv [0 ] = 1 ;
241+ SimpleSubscriber wrappedSubscriber = new SimpleSubscriber ();
242+ CipherSubscriber subscriber = new CipherSubscriber (wrappedSubscriber , materials .getCiphertextLength (), materials , iv );
243+
244+ // Setup Publisher
245+ TestPublisher <ByteBuffer > publisher = new TestPublisher <>();
246+ publisher .subscribe (subscriber );
247+
248+ // Verify subscription behavior
249+ assertTrue (publisher .isSubscribed ());
250+ assertEquals (1 , publisher .getSubscriberCount ());
251+
252+ // Send data to be encrypted
253+ ByteBuffer ptBb = ByteBuffer .wrap (plaintext .getBytes (StandardCharsets .UTF_8 ));
254+ publisher .emit (ptBb );
255+ publisher .complete ();
256+
257+ // Convert to byte array for convenience
258+ long expectedLength = plaintext .getBytes (StandardCharsets .UTF_8 ).length + algorithmSuite .cipherTagLengthBytes ();
259+ assertEquals (expectedLength , wrappedSubscriber .lengthOfData .get ());
260+ byte [] ctBytes = getByteArrayFromFixedLengthByteBuffers (wrappedSubscriber .getBuffersSeen (), expectedLength );
261+
262+ // Now decrypt the ciphertext
263+ DecryptionMaterials decryptionMaterials = getTestDecryptionMaterialsFromEncMats (materials );
264+ SimpleSubscriber wrappedDecryptSubscriber = new SimpleSubscriber ();
265+ CipherSubscriber decryptSubscriber = new CipherSubscriber (wrappedDecryptSubscriber , expectedLength , decryptionMaterials , iv );
266+ TestPublisher <ByteBuffer > decryptPublisher = new TestPublisher <>();
267+ decryptPublisher .subscribe (decryptSubscriber );
268+
269+ // Verify subscription behavior
270+ assertTrue (decryptPublisher .isSubscribed ());
271+ assertEquals (1 , decryptPublisher .getSubscriberCount ());
272+
273+ int taglength = algorithmSuite .cipherTagLengthBytes ();
274+ int ciphertextWithoutTagLength = ctBytes .length - taglength ;
275+
276+ // Create the main ByteBuffer (all except last 16 bytes)
277+ ByteBuffer mainBuffer = ByteBuffer .allocate (ciphertextWithoutTagLength );
278+ mainBuffer .put (ctBytes , 0 , ciphertextWithoutTagLength );
279+ mainBuffer .flip ();
280+
281+ // Create the tag ByteBuffer (last 16 bytes)
282+ ByteBuffer tagBuffer = ByteBuffer .allocate (taglength );
283+ tagBuffer .put (ctBytes , ciphertextWithoutTagLength , taglength );
284+ tagBuffer .flip ();
285+
286+ // Send the ciphertext, then the tag separately
287+ decryptPublisher .emit (mainBuffer );
288+ decryptPublisher .emit (tagBuffer );
289+ decryptPublisher .complete ();
290+
291+ long expectedLengthPt = plaintext .getBytes (StandardCharsets .UTF_8 ).length ;
292+ assertEquals (expectedLengthPt , wrappedDecryptSubscriber .lengthOfData .get ());
293+ byte [] ptBytes = getByteArrayFromFixedLengthByteBuffers (wrappedDecryptSubscriber .getBuffersSeen (), expectedLengthPt );
294+ // Assert round trip encrypt/decrypt succeeds
295+ assertEquals (plaintext , new String (ptBytes , StandardCharsets .UTF_8 ));
296+ }
297+ }
0 commit comments