diff --git a/sdk/cosmos/azure-cosmos-kafka-connect/CHANGELOG.md b/sdk/cosmos/azure-cosmos-kafka-connect/CHANGELOG.md index 0f8d76dfe6a2..dd1de2564019 100644 --- a/sdk/cosmos/azure-cosmos-kafka-connect/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-kafka-connect/CHANGELOG.md @@ -7,6 +7,7 @@ #### Breaking Changes #### Bugs Fixed +* Added filtering to deserialization of `"azure.cosmos.client.metadata.caches.snapshot"` - See [PR 47594](https://github.com/Azure/azure-sdk-for-java/pull/47594) #### Other Changes diff --git a/sdk/cosmos/azure-cosmos-kafka-connect/src/main/java/com/azure/cosmos/kafka/connect/implementation/KafkaCosmosUtils.java b/sdk/cosmos/azure-cosmos-kafka-connect/src/main/java/com/azure/cosmos/kafka/connect/implementation/KafkaCosmosUtils.java index 6009821ea188..c0b7daca499a 100644 --- a/sdk/cosmos/azure-cosmos-kafka-connect/src/main/java/com/azure/cosmos/kafka/connect/implementation/KafkaCosmosUtils.java +++ b/sdk/cosmos/azure-cosmos-kafka-connect/src/main/java/com/azure/cosmos/kafka/connect/implementation/KafkaCosmosUtils.java @@ -12,19 +12,37 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.InvalidClassException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.io.ObjectStreamClass; import java.util.Base64; +import java.util.HashSet; +import java.util.Set; public class KafkaCosmosUtils { private static final Logger LOGGER = LoggerFactory.getLogger(KafkaCosmosUtils.class); + private static final Set<String> ALLOWED_CLASSES = new HashSet<>(); + static { + ALLOWED_CLASSES.add(CosmosClientMetadataCachesSnapshot.class.getName()); + ALLOWED_CLASSES.add(byte[].class.getName()); + } public static CosmosClientMetadataCachesSnapshot getCosmosClientMetadataFromString(String metadataCacheString) { if (StringUtils.isNotEmpty(metadataCacheString)) { byte[] inputByteArray = Base64.getDecoder().decode(metadataCacheString); try (ObjectInputStream objectInputStream = - new 
ObjectInputStream(new ByteArrayInputStream(inputByteArray))) { - + new ObjectInputStream(new ByteArrayInputStream(inputByteArray)) { + @Override + protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { + // Whitelist only allowed classes to prevent RCE from arbitrary classes + if (!ALLOWED_CLASSES.contains(desc.getName())) { + LOGGER.error("Blocked deserialization of unauthorized class {}", desc.getName()); + throw new InvalidClassException("Unauthorized deserialization attempt", desc.getName()); + } + return super.resolveClass(desc); + } + }) { return (CosmosClientMetadataCachesSnapshot) objectInputStream.readObject(); } catch (IOException | ClassNotFoundException e) { LOGGER.warn("Failed to deserialize cosmos client metadata cache snapshot"); diff --git a/sdk/cosmos/azure-cosmos-kafka-connect/src/test/java/com/azure/cosmos/kafka/connect/CosmosSinkConnectorTest.java b/sdk/cosmos/azure-cosmos-kafka-connect/src/test/java/com/azure/cosmos/kafka/connect/CosmosSinkConnectorTest.java index 5782819f79a5..e4488339a55c 100644 --- a/sdk/cosmos/azure-cosmos-kafka-connect/src/test/java/com/azure/cosmos/kafka/connect/CosmosSinkConnectorTest.java +++ b/sdk/cosmos/azure-cosmos-kafka-connect/src/test/java/com/azure/cosmos/kafka/connect/CosmosSinkConnectorTest.java @@ -29,14 +29,21 @@ import org.testng.annotations.Test; import reactor.core.publisher.Mono; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; +import java.util.Base64; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; @@ -214,6 +221,23 @@ public void 
taskConfigsForClientMetadataCachesSnapshot() { } } + @Test(groups = "unit") + public void evilDeserializationIsBlocked() throws Exception { + AtomicReference<String> payload = new AtomicReference<>("Test RCE payload"); + Evil evil = new Evil(payload); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (ObjectOutputStream oos = new ObjectOutputStream(baos)) { + oos.writeObject(evil); + } + String evilBase64 = Base64.getEncoder().encodeToString(baos.toByteArray()); + + // Through KafkaCosmosUtils: should be blocked and return null + CosmosClientMetadataCachesSnapshot snapshot = + KafkaCosmosUtils.getCosmosClientMetadataFromString(evilBase64); + assertThat(snapshot).isNull(); + assertThat(payload.get()).isEqualTo("Test RCE payload"); + } + @Test(groups = "unit") public void misFormattedConfig() { CosmosSinkConnector sinkConnector = new CosmosSinkConnector(); @@ -471,4 +495,26 @@ public static class SinkConfigs { true) ); } + + public static class Evil implements Serializable { + private static final long serialVersionUID = 1L; + + private final AtomicReference<String> payload; + + public Evil(AtomicReference<String> payload) { + this.payload = payload; + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + System.out.println("Payload executed"); + payload.set("Payload executed"); + } + + @Override + public String toString() { + return "Evil{payload='" + payload.get() + "'}"; + } + } + }