Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-kafka-connect/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#### Breaking Changes

#### Bugs Fixed
* Added filtering to deserialization of `"azure.cosmos.client.metadata.caches.snapshot"` - See [PR 47594](https://github.com/Azure/azure-sdk-for-java/pull/47594)

#### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,31 @@
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InvalidClassException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.util.Base64;

public class KafkaCosmosUtils {
private static final Logger LOGGER = LoggerFactory.getLogger(KafkaCosmosUtils.class);
private static final String ALLOWED_CLASSES = "CosmosClientMetadataCachesSnapshot";

public static CosmosClientMetadataCachesSnapshot getCosmosClientMetadataFromString(String metadataCacheString) {
if (StringUtils.isNotEmpty(metadataCacheString)) {
byte[] inputByteArray = Base64.getDecoder().decode(metadataCacheString);
try (ObjectInputStream objectInputStream =
new ObjectInputStream(new ByteArrayInputStream(inputByteArray))) {

return (CosmosClientMetadataCachesSnapshot) objectInputStream.readObject();
new ObjectInputStream(new ByteArrayInputStream(inputByteArray)) {
@Override
protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
// Whitelist only allowed classes to prevent rce from arbitrary classes
if (!ALLOWED_CLASSES.contains(desc.getName())) {
throw new InvalidClassException("Unauthorized deserialization attempt", desc.getName());
}
return super.resolveClass(desc);
}
}){
return (CosmosClientMetadataCachesSnapshot) objectInputStream.readObject();
} catch (IOException | ClassNotFoundException e) {
LOGGER.warn("Failed to deserialize cosmos client metadata cache snapshot");
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,21 @@
import org.testng.annotations.Test;
import reactor.core.publisher.Mono;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
Expand Down Expand Up @@ -214,6 +221,23 @@ public void taskConfigsForClientMetadataCachesSnapshot() {
}
}

@Test(groups = "unit")
public void evilDeserializationIsBlocked() throws Exception {
AtomicReference<String> payload = new AtomicReference<>("Test RCE payload");
Evil evil = new Evil(payload);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(evil);
}
String evilBase64 = Base64.getEncoder().encodeToString(baos.toByteArray());

// Through KafkaCosmosUtils: should be blocked and return null
CosmosClientMetadataCachesSnapshot snapshot =
KafkaCosmosUtils.getCosmosClientMetadataFromString(evilBase64);
assertThat(snapshot).isNull();
assertThat(payload.get()).isEqualTo("Test RCE payload");
}

@Test(groups = "unit")
public void misFormattedConfig() {
CosmosSinkConnector sinkConnector = new CosmosSinkConnector();
Expand Down Expand Up @@ -471,4 +495,26 @@ public static class SinkConfigs {
true)
);
}

public static class Evil implements Serializable {
private static final long serialVersionUID = 1L;

private final AtomicReference<String> payload;

public Evil(AtomicReference<String> payload) {
this.payload = payload;
}

private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
in.defaultReadObject();
System.out.println("Payload executed");
payload.set("Payload executed");
}

@Override
public String toString() {
return "Evil{payload='" + payload.get() + "'}";
}
}

}
Loading