1616import java .util .Objects ;
1717import java .util .Set ;
1818
19+ import org .slf4j .Logger ;
20+ import org .slf4j .LoggerFactory ;
21+
1922import com .fasterxml .jackson .core .JsonProcessingException ;
2023import com .fasterxml .jackson .core .type .TypeReference ;
2124import com .fasterxml .jackson .databind .ObjectMapper ;
3841import static java .nio .charset .StandardCharsets .UTF_8 ;
3942
4043public abstract class AbstractVaultTestKmsFacade implements TestKmsFacade <Config , String , VaultEdek > {
44+
45+ private static final Logger LOGGER = LoggerFactory .getLogger (AbstractVaultTestKmsFacade .class );
46+
4147 private static final TypeReference <CreateTokenResponse > VAULT_RESPONSE_CREATE_TOKEN_RESPONSE_TYPEREF = new TypeReference <>() {
4248 };
4349 private static final TypeReference <VaultResponse <VaultResponse .ReadKeyData >> VAULT_RESPONSE_READ_KEY_DATA_TYPEREF = new TypeReference <>() {
@@ -61,7 +67,7 @@ public final void start() {
6167 startVault ();
6268
6369 // enable transit engine
64- enableTransit ( );
70+ withRetry ( this :: enableTransit , 3 );
6571
6672 // create policy
6773 var policyName = "kroxylicious_encryption_filter_policy" ;
@@ -77,12 +83,46 @@ public final void start() {
7783 kmsVaultToken = createOrphanToken ("kroxylicious_encryption_filter" , true , Set .of (policyName ));
7884 }
7985
80- protected void enableTransit () {
81- var engine = new EnableEngineRequest ("transit" );
82- var body = encodeJson (engine );
83- var request = createVaultPost (getVaultUrl ().resolve ("v1/sys/mounts/transit" ), HttpRequest .BodyPublishers .ofString (body ));
86+ private void withRetry (Runnable r , int maxAttempts ) {
87+ var done = false ;
88+ int remaining = maxAttempts ;
89+ RuntimeException last = null ;
90+ do {
91+ try {
92+ r .run ();
93+ done = true ;
94+ }
95+ catch (RuntimeException e ) {
96+ remaining --;
97+ last = e ;
98+ LOGGER .warn ("Failed to execute command (remaining %d/%d attempts)" .formatted (remaining , maxAttempts ), e );
99+ try {
100+ Thread .sleep (5000 );
101+ }
102+ catch (InterruptedException ie ) {
103+ Thread .currentThread ().interrupt ();
104+ throw new IllegalStateException ("Interrupted whilst retrying command" , ie );
105+ }
106+ }
107+ } while (!done && remaining > 0 );
84108
85- sendRequestExpectingNoContentResponse (request );
109+ if (!done ) {
110+ throw new IllegalStateException ("Task failed after " + maxAttempts + " attempts." , last );
111+ }
112+ }
113+
114+ protected void enableTransit () {
115+ String expectedMountPath = "transit" ;
116+ if (!sendRequestExpectingOk (createVaultGet (getVaultUrl ().resolve ("/v1/sys/mounts/" + expectedMountPath + "/" )))) {
117+ LOGGER .atInfo ().addArgument (expectedMountPath ).log ("Transit engine not found at: {}. Attempting to enable it." );
118+ var engine = new EnableEngineRequest ("transit" );
119+ var body = encodeJson (engine );
120+ var request = createVaultPost (getVaultUrl ().resolve ("v1/sys/mounts/" + expectedMountPath ), HttpRequest .BodyPublishers .ofString (body ));
121+ sendRequestExpectingNoContentResponse (request );
122+ }
123+ else {
124+ LOGGER .atInfo ().addArgument (expectedMountPath ).log ("Transit engine found at: {}. Continuing." );
125+ }
86126 }
87127
88128 protected void createPolicy (String policyName , InputStream policyStream ) {
@@ -99,7 +139,7 @@ protected String createOrphanToken(String description, boolean noDefaultPolicy,
99139 String body = encodeJson (token );
100140 var request = createVaultPost (getVaultUrl ().resolve ("v1/auth/token/create-orphan" ), HttpRequest .BodyPublishers .ofString (body ));
101141
102- return sendRequest ("dummy" , request , VAULT_RESPONSE_CREATE_TOKEN_RESPONSE_TYPEREF ).auth ().clientToken ();
142+ return sendRequestForKey ("dummy" , request , VAULT_RESPONSE_CREATE_TOKEN_RESPONSE_TYPEREF ).auth ().clientToken ();
103143 }
104144
105145 protected abstract URI getVaultUrl ();
@@ -132,14 +172,14 @@ class VaultKmsTestKekManager implements TestKekManager {
132172 @ Override
133173 public void generateKek (String keyId ) {
134174 var request = createVaultPost (getVaultUrl ().resolve (KEYS_PATH .formatted (encode (keyId , UTF_8 ))), HttpRequest .BodyPublishers .noBody ());
135- sendRequest (keyId , request , VAULT_RESPONSE_READ_KEY_DATA_TYPEREF );
175+ sendRequestForKey (keyId , request , VAULT_RESPONSE_READ_KEY_DATA_TYPEREF );
136176 }
137177
138178 @ Override
139179 public void deleteKek (String keyId ) {
140180 var update = createVaultPost (getVaultUrl ().resolve ((KEYS_PATH + "/config" ).formatted (encode (keyId , UTF_8 ))),
141181 HttpRequest .BodyPublishers .ofString (encodeJson (new UpdateKeyConfigRequest (true ))));
142- sendRequest (keyId , update , VAULT_RESPONSE_READ_KEY_DATA_TYPEREF );
182+ sendRequestForKey (keyId , update , VAULT_RESPONSE_READ_KEY_DATA_TYPEREF );
143183
144184 var delete = createVaultDelete (getVaultUrl ().resolve (KEYS_PATH .formatted (encode (keyId , UTF_8 ))));
145185 sendRequestExpectingNoContentResponse (delete );
@@ -148,13 +188,13 @@ public void deleteKek(String keyId) {
148188 @ Override
149189 public VaultResponse .ReadKeyData read (String keyId ) {
150190 var request = createVaultGet (getVaultUrl ().resolve (KEYS_PATH .formatted (encode (keyId , UTF_8 ))));
151- return sendRequest (keyId , request , VAULT_RESPONSE_READ_KEY_DATA_TYPEREF ).data ();
191+ return sendRequestForKey (keyId , request , VAULT_RESPONSE_READ_KEY_DATA_TYPEREF ).data ();
152192 }
153193
154194 @ Override
155195 public void rotateKek (String keyId ) {
156196 var request = createVaultPost (getVaultUrl ().resolve ((KEYS_PATH + "/rotate" ).formatted (encode (keyId , UTF_8 ))), HttpRequest .BodyPublishers .noBody ());
157- sendRequest (keyId , request , VAULT_RESPONSE_READ_KEY_DATA_TYPEREF );
197+ sendRequestForKey (keyId , request , VAULT_RESPONSE_READ_KEY_DATA_TYPEREF );
158198 }
159199 }
160200
@@ -184,7 +224,28 @@ private HttpRequest.Builder createVaultRequest() {
184224 .header ("Accept" , "application/json" );
185225 }
186226
187- private <R > R sendRequest (String key , HttpRequest request , TypeReference <R > valueTypeRef ) {
227+ private boolean sendRequestExpectingOk (HttpRequest request ) {
228+ try {
229+ HttpResponse <String > response = vaultClient .send (request , HttpResponse .BodyHandlers .ofString ());
230+ if (response .statusCode () == 200 ) {
231+ return true ;
232+ }
233+ else {
234+ LOGGER .atWarn ().addArgument (response .statusCode ()).addArgument (response .uri ()).addArgument (response .body ())
235+ .log ("Received unexpected status code: {} from: {}. Response body: {}" );
236+ return false ;
237+ }
238+ }
239+ catch (IOException e ) {
240+ throw new UncheckedIOException (e );
241+ }
242+ catch (InterruptedException e ) {
243+ Thread .currentThread ().interrupt ();
244+ return false ;
245+ }
246+ }
247+
248+ private <R > R sendRequestForKey (String key , HttpRequest request , TypeReference <R > valueTypeRef ) {
188249 try {
189250 HttpResponse <byte []> response = vaultClient .send (request , HttpResponse .BodyHandlers .ofByteArray ());
190251 if (response .statusCode () == 404 ) {
@@ -210,8 +271,10 @@ else if (response.statusCode() != 200) {
210271
211272 private void sendRequestExpectingNoContentResponse (HttpRequest request ) {
212273 try {
213- var response = vaultClient .send (request , HttpResponse .BodyHandlers .discarding ());
274+ var response = vaultClient .send (request , HttpResponse .BodyHandlers .ofString ());
214275 if (response .statusCode () != 204 ) {
276+ LOGGER .atWarn ().addArgument (response .statusCode ()).addArgument (response .uri ()).addArgument (response .body ())
277+ .log ("Received unexpected status code: {} from: {}. Response body: {}" );
215278 throw new IllegalStateException ("Unexpected response : %d to request %s" .formatted (response .statusCode (), request .uri ()));
216279 }
217280 }
0 commit comments