Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-kafka-connect/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#### Breaking Changes

#### Bugs Fixed
* Added filtering to deserialization of `"azure.cosmos.client.metadata.caches.snapshot"` - See [PR 47594](https://github.com/Azure/azure-sdk-for-java/pull/47594)

#### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,37 @@
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InvalidClassException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.util.Base64;
import java.util.HashSet;
import java.util.Set;

public class KafkaCosmosUtils {
private static final Logger LOGGER = LoggerFactory.getLogger(KafkaCosmosUtils.class);
private static final Set<String> ALLOWED_CLASSES = new HashSet<>();
static {
ALLOWED_CLASSES.add(CosmosClientMetadataCachesSnapshot.class.getName());
ALLOWED_CLASSES.add(byte[].class.getName());
}

public static CosmosClientMetadataCachesSnapshot getCosmosClientMetadataFromString(String metadataCacheString) {
if (StringUtils.isNotEmpty(metadataCacheString)) {
byte[] inputByteArray = Base64.getDecoder().decode(metadataCacheString);
try (ObjectInputStream objectInputStream =
new ObjectInputStream(new ByteArrayInputStream(inputByteArray))) {

new ObjectInputStream(new ByteArrayInputStream(inputByteArray)) {
@Override
protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
// Whitelist only allowed classes to prevent rce from arbitrary classes
if (!ALLOWED_CLASSES.contains(desc.getName())) {
LOGGER.error(desc.getName());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we add some message here, like "malicious attempt to deserialize" or maybe less scary message like "invalid class type for deserialization", or something like that?

throw new InvalidClassException("Unauthorized deserialization attempt", desc.getName());
}
return super.resolveClass(desc);
}
}) {
return (CosmosClientMetadataCachesSnapshot) objectInputStream.readObject();
} catch (IOException | ClassNotFoundException e) {
LOGGER.warn("Failed to deserialize cosmos client metadata cache snapshot");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,21 @@
import org.testng.annotations.Test;
import reactor.core.publisher.Mono;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
Expand Down Expand Up @@ -214,6 +221,23 @@ public void taskConfigsForClientMetadataCachesSnapshot() {
}
}

@Test(groups = "unit")
public void evilDeserializationIsBlocked() throws Exception {
AtomicReference<String> payload = new AtomicReference<>("Test RCE payload");
Evil evil = new Evil(payload);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(evil);
}
String evilBase64 = Base64.getEncoder().encodeToString(baos.toByteArray());

// Through KafkaCosmosUtils: should be blocked and return null
CosmosClientMetadataCachesSnapshot snapshot =
KafkaCosmosUtils.getCosmosClientMetadataFromString(evilBase64);
assertThat(snapshot).isNull();
assertThat(payload.get()).isEqualTo("Test RCE payload");
}

@Test(groups = "unit")
public void misFormattedConfig() {
CosmosSinkConnector sinkConnector = new CosmosSinkConnector();
Expand Down Expand Up @@ -471,4 +495,26 @@ public static class SinkConfigs {
true)
);
}

public static class Evil implements Serializable {
private static final long serialVersionUID = 1L;

private final AtomicReference<String> payload;

public Evil(AtomicReference<String> payload) {
this.payload = payload;
}

private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
in.defaultReadObject();
System.out.println("Payload executed");
payload.set("Payload executed");
}

@Override
public String toString() {
return "Evil{payload='" + payload.get() + "'}";
}
}

}
Loading