Skip to content

Commit ef48cd3

Browse files
committed
Fix unsafe deserialization
1 parent 8295935 commit ef48cd3

File tree

2 files changed

+62
-17
lines changed

2 files changed

+62
-17
lines changed

sdk/cosmos/azure-cosmos-kafka-connect/src/main/java/com/azure/cosmos/kafka/connect/implementation/KafkaCosmosUtils.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,31 @@
1212
import java.io.ByteArrayInputStream;
1313
import java.io.ByteArrayOutputStream;
1414
import java.io.IOException;
15+
import java.io.InvalidClassException;
1516
import java.io.ObjectInputStream;
1617
import java.io.ObjectOutputStream;
18+
import java.io.ObjectStreamClass;
1719
import java.util.Base64;
1820

1921
public class KafkaCosmosUtils {
2022
private static final Logger LOGGER = LoggerFactory.getLogger(KafkaCosmosUtils.class);
23+
private static final String ALLOWED_CLASSES = "CosmosClientMetadataCachesSnapshot";
2124

2225
public static CosmosClientMetadataCachesSnapshot getCosmosClientMetadataFromString(String metadataCacheString) {
2326
if (StringUtils.isNotEmpty(metadataCacheString)) {
2427
byte[] inputByteArray = Base64.getDecoder().decode(metadataCacheString);
2528
try (ObjectInputStream objectInputStream =
26-
new ObjectInputStream(new ByteArrayInputStream(inputByteArray))) {
27-
28-
return (CosmosClientMetadataCachesSnapshot) objectInputStream.readObject();
29+
new ObjectInputStream(new ByteArrayInputStream(inputByteArray)) {
30+
@Override
31+
protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
32+
// Whitelist only allowed classes to prevent rce from arbitrary classes
33+
if (!ALLOWED_CLASSES.contains(desc.getName())) {
34+
throw new InvalidClassException("Unauthorized deserialization attempt", desc.getName());
35+
}
36+
return super.resolveClass(desc);
37+
}
38+
}){
39+
return (CosmosClientMetadataCachesSnapshot) objectInputStream.readObject();
2940
} catch (IOException | ClassNotFoundException e) {
3041
LOGGER.warn("Failed to deserialize cosmos client metadata cache snapshot");
3142
return null;

sdk/cosmos/azure-cosmos-kafka-connect/src/test/java/com/azure/cosmos/kafka/connect/CosmosSinkConnectorTest.java

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import com.azure.cosmos.kafka.connect.implementation.sink.CosmosSinkTaskConfig;
1919
import com.azure.cosmos.kafka.connect.implementation.sink.IdStrategyType;
2020
import com.azure.cosmos.kafka.connect.implementation.sink.ItemWriteStrategy;
21-
import com.azure.cosmos.kafka.connect.implementation.sink.patch.KafkaCosmosPatchOperationType;
2221
import com.azure.cosmos.models.CosmosContainerProperties;
2322
import org.apache.kafka.common.config.Config;
2423
import org.apache.kafka.common.config.ConfigDef;
@@ -29,14 +28,23 @@
2928
import org.testng.annotations.Test;
3029
import reactor.core.publisher.Mono;
3130

31+
import java.io.ByteArrayInputStream;
32+
import java.io.ByteArrayOutputStream;
33+
import java.io.IOException;
34+
import java.io.InvalidClassException;
35+
import java.io.ObjectInputStream;
36+
import java.io.ObjectOutputStream;
37+
import java.io.Serializable;
3238
import java.util.ArrayList;
3339
import java.util.Arrays;
40+
import java.util.Base64;
3441
import java.util.Collections;
3542
import java.util.HashMap;
3643
import java.util.List;
3744
import java.util.Map;
3845
import java.util.UUID;
3946
import java.util.concurrent.atomic.AtomicInteger;
47+
import java.util.concurrent.atomic.AtomicReference;
4048
import java.util.stream.Collectors;
4149

4250
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
@@ -214,6 +222,23 @@ public void taskConfigsForClientMetadataCachesSnapshot() {
214222
}
215223
}
216224

225+
@Test(groups = "unit")
226+
public void evilDeserializationIsBlocked() throws Exception {
227+
AtomicReference<String> payload = new AtomicReference<>("Test RCE payload");
228+
Evil evil = new Evil(payload);
229+
ByteArrayOutputStream baos = new ByteArrayOutputStream();
230+
try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
231+
oos.writeObject(evil);
232+
}
233+
String evilBase64 = Base64.getEncoder().encodeToString(baos.toByteArray());
234+
235+
// Through KafkaCosmosUtils: should be blocked and return null
236+
CosmosClientMetadataCachesSnapshot snapshot =
237+
KafkaCosmosUtils.getCosmosClientMetadataFromString(evilBase64);
238+
assertThat(snapshot).isNull();
239+
assertThat(payload.get()).isEqualTo("Test RCE payload");
240+
}
241+
217242
@Test(groups = "unit")
218243
public void misFormattedConfig() {
219244
CosmosSinkConnector sinkConnector = new CosmosSinkConnector();
@@ -450,19 +475,6 @@ public static class SinkConfigs {
450475
"azure.cosmos.sink.write.strategy",
451476
ItemWriteStrategy.ITEM_OVERWRITE.getName(),
452477
true),
453-
new KafkaCosmosConfigEntry<String>(
454-
"azure.cosmos.sink.write.patch.operationType.default",
455-
KafkaCosmosPatchOperationType.SET.getName(),
456-
true),
457-
new KafkaCosmosConfigEntry<String>(
458-
"azure.cosmos.sink.write.patch.property.configs",
459-
Strings.Emtpy,
460-
true),
461-
new KafkaCosmosConfigEntry<String>(
462-
"azure.cosmos.sink.write.patch.filter",
463-
Strings.Emtpy,
464-
true),
465-
new KafkaCosmosConfigEntry<Integer>("azure.cosmos.sink.maxRetryCount", 10, true),
466478
new KafkaCosmosConfigEntry<String>("azure.cosmos.sink.database.name", null, false),
467479
new KafkaCosmosConfigEntry<String>("azure.cosmos.sink.containers.topicMap", null, false),
468480
new KafkaCosmosConfigEntry<String>(
@@ -471,4 +483,26 @@ public static class SinkConfigs {
471483
true)
472484
);
473485
}
486+
487+
public static class Evil implements Serializable {
488+
private static final long serialVersionUID = 1L;
489+
490+
private final AtomicReference<String> payload;
491+
492+
public Evil(AtomicReference<String> payload) {
493+
this.payload = payload;
494+
}
495+
496+
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
497+
in.defaultReadObject();
498+
System.out.println("Payload executed");
499+
payload.set("Payload executed");
500+
}
501+
502+
@Override
503+
public String toString() {
504+
return "Evil{payload='" + payload.get() + "'}";
505+
}
506+
}
507+
474508
}

0 commit comments

Comments
 (0)