Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import com.aws.greengrass.util.Permissions;
import com.aws.greengrass.util.RegionUtils;
import com.aws.greengrass.util.RootCAUtils;
import com.aws.greengrass.util.SdkClientWrapper;
import com.aws.greengrass.util.StsSdkClientFactory;
import com.aws.greengrass.util.Utils;
import com.aws.greengrass.util.exceptions.InvalidEnvironmentStageException;
Expand Down Expand Up @@ -118,6 +119,7 @@ public class DeviceProvisioningHelper {
private final IotClient iotClient;
private final IamClient iamClient;
private final StsClient stsClient;
private final SdkClientWrapper<StsClient> stsWrapper;
private final GreengrassV2Client greengrassClient;
private EnvironmentStage envStage = EnvironmentStage.PROD;
private boolean thingGroupExists = false;
Expand All @@ -139,7 +141,10 @@ public DeviceProvisioningHelper(String awsRegion, String environmentStage, Print
: EnvironmentStage.fromString(environmentStage);
this.iotClient = IotSdkClientFactory.getIotClient(awsRegion, envStage);
this.iamClient = IamSdkClientFactory.getIamClient(awsRegion);
//TODO: Need to remove this
this.stsClient = StsSdkClientFactory.getStsClient(awsRegion);
this.stsWrapper = new SdkClientWrapper<>(() ->
StsSdkClientFactory.getStsClient(awsRegion));
this.greengrassClient = GreengrassV2Client.builder().endpointOverride(
URI.create(RegionUtils.getGreengrassControlPlaneEndpoint(awsRegion, this.envStage)))
.region(Region.of(awsRegion))
Expand All @@ -162,6 +167,8 @@ public DeviceProvisioningHelper(String awsRegion, String environmentStage, Print
this.iamClient = iamClient;
this.stsClient = stsClient;
this.greengrassClient = greengrassClient;
//TODO: change unit tests
stsWrapper = null;
}

/**
Expand Down Expand Up @@ -454,7 +461,8 @@ private Optional<String> getPolicyArn(String policyName, Region awsRegion) {
}

private String getAccountId() {
return stsClient.getCallerIdentity(GetCallerIdentityRequest.builder().build()).account();
return stsWrapper.execute(client ->
client.getCallerIdentity(GetCallerIdentityRequest.builder().build()).account());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommendation generated by Amazon CodeGuru Reviewer. Leave feedback on this recommendation by replying to the comment or by reacting to the comment using emoji.

software.amazon.awssdk.services.sts.StsClient.getCallerIdentity API can also throw the following exception types: SdkClientException, StsException, SdkException, AwsServiceException, UnsupportedOperationException. We recommend handling these uncaught exceptions as well.

}

/**
Expand Down
91 changes: 91 additions & 0 deletions src/main/java/com/aws/greengrass/util/SdkClientWrapper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package com.aws.greengrass.util;

import com.aws.greengrass.logging.api.Logger;
import com.aws.greengrass.logging.impl.LogManager;
import org.apache.http.NoHttpResponseException;
import software.amazon.awssdk.core.SdkClient;
import software.amazon.awssdk.core.exception.SdkClientException;

import java.net.SocketException;
import java.util.concurrent.locks.Lock;
import java.util.function.Function;
import java.util.function.Supplier;

public final class SdkClientWrapper<T extends SdkClient> {
private static final Logger logger = LogManager.getLogger(SdkClientWrapper.class);

private volatile T client;
private final Supplier<T> clientFactory;
private final Lock lock = LockFactory.newReentrantLock(this);

public SdkClientWrapper(Supplier<T> clientFactory) {
this.clientFactory = clientFactory;
this.client = clientFactory.get();
}

/**
* Executes the given operation on the client, handling potential SDK client exceptions.
*
* <p>This method applies the provided operation to the client. If an {@link SdkClientException}
* occurs and the client needs refreshing (as determined by {@link #shouldRefreshClient(SdkClientException)}),
* it will attempt to refresh the client and retry the operation once.</p>
*
* @param <R> The return type of the operation
* @param operation A function that takes the client of type T and returns a result of type R
* @return The result of the operation
* @throws SdkClientException If the operation fails and the client cannot be refreshed or fails after refresh
* @throws RuntimeException If an unexpected error occurs during execution
*/
public <R> R execute(final Function<T, R> operation) {
try {
return operation.apply(client);
} catch (SdkClientException e) {
if (shouldRefreshClient(e)) {
logger.atDebug().log("Client needs refresh due to: {}", e.getMessage());
try {
refreshClient();
return operation.apply(client);
} catch (SdkClientException retryException) {
logger.atError().log("Failed to execute operation after client refresh", retryException);
throw retryException;
}
}
logger.atError().log("SDK client operation failed", e);
throw e;
}
}

private void refreshClient() {
try (LockScope ls = LockScope.lock(lock)) {
if (client != null) {
try {
client.close();
} catch (SdkClientException e) {
logger.atError().log("Error closing client: " + e.getMessage());
}
}
// Creates new client when refresh needed
client = clientFactory.get();
}
}

private boolean shouldRefreshClient(SdkClientException e) {
Throwable cause = e;
while (cause != null) {
if (cause instanceof SocketException && "Connection reset".equals(cause.getMessage())) {
return true;
}
if (cause instanceof NoHttpResponseException) {
return true;
}
// Add other conditions that should trigger a client refresh here
cause = cause.getCause();
}
return false;
}
}
Loading