Skip to content

Commit aa4ce4d

Browse files
committed
Make SSMProvider and AppConfigProvider thread-safe.
1 parent 5fc86a3 commit aa4ce4d

File tree

3 files changed

+118
-22
lines changed
  • powertools-parameters
    • powertools-parameters-appconfig/src/main/java/software/amazon/lambda/powertools/parameters/appconfig
    • powertools-parameters-ssm/src

3 files changed

+118
-22
lines changed

powertools-parameters/powertools-parameters-appconfig/src/main/java/software/amazon/lambda/powertools/parameters/appconfig/AppConfigProvider.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
package software.amazon.lambda.powertools.parameters.appconfig;
1616

17-
import java.util.HashMap;
1817
import java.util.Map;
18+
import java.util.concurrent.ConcurrentHashMap;
1919

2020
import software.amazon.awssdk.core.SdkBytes;
2121
import software.amazon.awssdk.services.appconfigdata.AppConfigDataClient;
@@ -46,7 +46,7 @@ public class AppConfigProvider extends BaseProvider {
4646
private final AppConfigDataClient client;
4747
private final String application;
4848
private final String environment;
49-
private final Map<String, EstablishedSession> establishedSessions = new HashMap<>();
49+
private final Map<String, EstablishedSession> establishedSessions = new ConcurrentHashMap<>();
5050

5151
AppConfigProvider(CacheManager cacheManager, TransformationManager transformationManager,
5252
AppConfigDataClient client, String environment, String application) {

powertools-parameters/powertools-parameters-ssm/src/main/java/software/amazon/lambda/powertools/parameters/ssm/SSMProvider.java

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@
6666
public class SSMProvider extends BaseProvider {
6767

6868
private final SsmClient client;
69-
private boolean decrypt = false;
70-
private boolean recursive = false;
69+
private final ThreadLocal<Boolean> decrypt = ThreadLocal.withInitial(() -> false);
70+
private final ThreadLocal<Boolean> recursive = ThreadLocal.withInitial(() -> false);
7171

7272
/**
7373
* Constructor with custom {@link SsmClient}. <br/>
@@ -109,7 +109,7 @@ public static SSMProvider create() {
109109
public String getValue(String key) {
110110
GetParameterRequest request = GetParameterRequest.builder()
111111
.name(key)
112-
.withDecryption(decrypt)
112+
.withDecryption(decrypt.get())
113113
.build();
114114
return client.getParameter(request).parameter().value();
115115
}
@@ -122,7 +122,7 @@ public String getValue(String key) {
122122
* @return the provider itself in order to chain calls (eg. <pre>provider.withDecryption().get("key")</pre>).
123123
*/
124124
public SSMProvider withDecryption() {
125-
this.decrypt = true;
125+
this.decrypt.set(true);
126126
return this;
127127
}
128128

@@ -133,7 +133,7 @@ public SSMProvider withDecryption() {
133133
* @return the provider itself in order to chain calls (eg. <pre>provider.recursive().getMultiple("key")</pre>).
134134
*/
135135
public SSMProvider recursive() {
136-
this.recursive = true;
136+
this.recursive.set(true);
137137
return this;
138138
}
139139

@@ -160,8 +160,8 @@ protected Map<String, String> getMultipleValues(String path) {
160160
private Map<String, String> getMultipleBis(String path, String nextToken) {
161161
GetParametersByPathRequest request = GetParametersByPathRequest.builder()
162162
.path(path)
163-
.withDecryption(decrypt)
164-
.recursive(recursive)
163+
.withDecryption(decrypt.get())
164+
.recursive(recursive.get())
165165
.nextToken(nextToken)
166166
.build();
167167

@@ -170,12 +170,12 @@ private Map<String, String> getMultipleBis(String path, String nextToken) {
170170
// not using the client.getParametersByPathPaginator() as hardly testable
171171
GetParametersByPathResponse res = client.getParametersByPath(request);
172172
if (res.hasParameters()) {
173-
res.parameters().forEach(parameter ->
174-
{
175-
/* Standardize the parameter name
176-
The parameter name returned by SSM will contain the full path.
177-
However, for readability, we should return only the part after
178-
the path.
173+
res.parameters().forEach(parameter -> {
174+
/*
175+
* Standardize the parameter name
176+
* The parameter name returned by SSM will contain the full path.
177+
* However, for readability, we should return only the part after
178+
* the path.
179179
*/
180180
String name = parameter.name();
181181
if (name.startsWith(path)) {
@@ -196,8 +196,8 @@ private Map<String, String> getMultipleBis(String path, String nextToken) {
196196
@Override
197197
protected void resetToDefaults() {
198198
super.resetToDefaults();
199-
recursive = false;
200-
decrypt = false;
199+
decrypt.remove();
200+
recursive.remove();
201201
}
202202

203203
// For tests purpose only

powertools-parameters/powertools-parameters-ssm/src/test/java/software/amazon/lambda/powertools/parameters/ssm/SSMProviderTest.java

Lines changed: 101 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import java.util.ArrayList;
2424
import java.util.List;
2525
import java.util.Map;
26+
import java.util.concurrent.CountDownLatch;
27+
2628
import org.assertj.core.data.MapEntry;
2729
import org.junit.jupiter.api.BeforeEach;
2830
import org.junit.jupiter.api.Test;
@@ -32,6 +34,7 @@
3234
import org.mockito.Mock;
3335
import org.mockito.Mockito;
3436
import org.mockito.MockitoAnnotations;
37+
3538
import software.amazon.awssdk.services.ssm.SsmClient;
3639
import software.amazon.awssdk.services.ssm.model.GetParameterRequest;
3740
import software.amazon.awssdk.services.ssm.model.GetParameterResponse;
@@ -165,8 +168,8 @@ public void getMultipleWithNextToken() {
165168
List<Parameter> parameters1 = new ArrayList<>();
166169
parameters1.add(Parameter.builder().name("/prod/app1/key1").value("foo1").build());
167170
parameters1.add(Parameter.builder().name("/prod/app1/key2").value("foo2").build());
168-
GetParametersByPathResponse response1 =
169-
GetParametersByPathResponse.builder().parameters(parameters1).nextToken("123abc").build();
171+
GetParametersByPathResponse response1 = GetParametersByPathResponse.builder().parameters(parameters1)
172+
.nextToken("123abc").build();
170173

171174
List<Parameter> parameters2 = new ArrayList<>();
172175
parameters2.add(Parameter.builder().name("/prod/app1/key3").value("foo3").build());
@@ -185,8 +188,7 @@ public void getMultipleWithNextToken() {
185188
GetParametersByPathRequest request1 = requestParams.get(0);
186189
GetParametersByPathRequest request2 = requestParams.get(1);
187190

188-
assertThat(asList(request1, request2)).allSatisfy(req ->
189-
{
191+
assertThat(asList(request1, request2)).allSatisfy(req -> {
190192
assertThat(req.path()).isEqualTo("/prod/app1");
191193
assertThat(req.withDecryption()).isFalse();
192194
assertThat(req.recursive()).isFalse();
@@ -203,7 +205,101 @@ public void testSSMProvider_withoutParameter_shouldHaveDefaultTransformationMana
203205
SSMProvider ssmProvider = SSMProvider.builder()
204206
.build();
205207
// Assert
206-
assertDoesNotThrow(()->ssmProvider.withTransformation(json));
208+
assertDoesNotThrow(() -> ssmProvider.withTransformation(json));
209+
}
210+
211+
@Test
212+
public void withDecryption_concurrentCalls_shouldBeThreadSafe() throws InterruptedException {
213+
// GIVEN
214+
Parameter param1 = Parameter.builder().value("value1").build();
215+
Parameter param2 = Parameter.builder().value("value2").build();
216+
GetParameterResponse response1 = GetParameterResponse.builder().parameter(param1).build();
217+
GetParameterResponse response2 = GetParameterResponse.builder().parameter(param2).build();
218+
CountDownLatch latch = new CountDownLatch(2);
219+
Mockito.when(client.getParameter(paramCaptor.capture()))
220+
.thenReturn(response1, response2);
221+
222+
// WHEN
223+
Thread thread1 = new Thread(() -> {
224+
try {
225+
latch.countDown();
226+
latch.await();
227+
provider.withDecryption().getValue("key1");
228+
} catch (InterruptedException e) {
229+
Thread.currentThread().interrupt();
230+
}
231+
});
232+
233+
Thread thread2 = new Thread(() -> {
234+
try {
235+
latch.countDown();
236+
latch.await();
237+
provider.getValue("key2");
238+
} catch (InterruptedException e) {
239+
Thread.currentThread().interrupt();
240+
}
241+
});
242+
243+
thread1.start();
244+
thread2.start();
245+
thread1.join();
246+
thread2.join();
247+
248+
// THEN
249+
List<GetParameterRequest> requests = paramCaptor.getAllValues();
250+
assertThat(requests).hasSize(2);
251+
boolean hasDecryptedRequest = requests.stream().anyMatch(GetParameterRequest::withDecryption);
252+
boolean hasNonDecryptedRequest = requests.stream().anyMatch(r -> !r.withDecryption());
253+
assertThat(hasDecryptedRequest).isTrue();
254+
assertThat(hasNonDecryptedRequest).isTrue();
255+
}
256+
257+
@Test
258+
public void recursive_concurrentCalls_shouldBeThreadSafe() throws InterruptedException {
259+
// GIVEN
260+
List<Parameter> params1 = new ArrayList<>();
261+
params1.add(Parameter.builder().name("/path1/key1").value("value1").build());
262+
List<Parameter> params2 = new ArrayList<>();
263+
params2.add(Parameter.builder().name("/path2/key2").value("value2").build());
264+
GetParametersByPathResponse response1 = GetParametersByPathResponse.builder().parameters(params1).build();
265+
GetParametersByPathResponse response2 = GetParametersByPathResponse.builder().parameters(params2).build();
266+
CountDownLatch latch = new CountDownLatch(2);
267+
Mockito.when(client.getParametersByPath(paramByPathCaptor.capture()))
268+
.thenReturn(response1, response2);
269+
270+
// WHEN
271+
Thread thread1 = new Thread(() -> {
272+
try {
273+
latch.countDown();
274+
latch.await();
275+
provider.recursive().getMultiple("/path1");
276+
} catch (InterruptedException e) {
277+
Thread.currentThread().interrupt();
278+
}
279+
});
280+
281+
Thread thread2 = new Thread(() -> {
282+
try {
283+
latch.countDown();
284+
latch.await();
285+
provider.getMultiple("/path2");
286+
} catch (InterruptedException e) {
287+
Thread.currentThread().interrupt();
288+
}
289+
});
290+
291+
thread1.start();
292+
thread2.start();
293+
thread1.join();
294+
thread2.join();
295+
296+
// THEN
297+
List<GetParametersByPathRequest> requests = paramByPathCaptor.getAllValues();
298+
assertThat(requests).hasSize(2);
299+
boolean hasRecursiveRequest = requests.stream().anyMatch(GetParametersByPathRequest::recursive);
300+
boolean hasNonRecursiveRequest = requests.stream().anyMatch(r -> !r.recursive());
301+
assertThat(hasRecursiveRequest).isTrue();
302+
assertThat(hasNonRecursiveRequest).isTrue();
207303
}
208304

209305
private void initMock(String expectedValue) {

0 commit comments

Comments
 (0)