Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.io.ObjectInputFilter;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Collections;
Expand Down Expand Up @@ -144,8 +145,15 @@ public Object getAttribute(String name) {
oos.writeObject(obj);
oos.close();

// Create filter from user configuration for secure deserialization
String filterPattern = getServletContext()
.getInitParameter("serializable-object-filter");
ObjectInputFilter filter = filterPattern != null
? ObjectInputFilter.Config.createFilter(filterPattern)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we concerned with any duplicate object creations? Should we guard them against races?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I appreciate your excellent point, @sboorlagadda. I've implemented filter caching using the double-checked locking pattern with volatile fields. The changes include:

  • Added private volatile ObjectInputFilter cachedFilter to cache the filter instance
  • Added private volatile boolean filterLogged to ensure one-time logging
  • Implemented getOrCreateFilter() method that creates and caches the filter on first use
  • The double-checked locking ensures thread-safety without synchronization overhead on subsequent calls

This eliminates both the performance overhead of recreating the filter on every deserialization and prevents race conditions in multi-threaded servlet environments.

: null;

ObjectInputStream ois = new ClassLoaderObjectInputStream(
new ByteArrayInputStream(baos.toByteArray()), loader);
new ByteArrayInputStream(baos.toByteArray()), loader, filter);
tmpObj = ois.readObject();
} catch (IOException | ClassNotFoundException e) {
LOG.error("Exception while recreating attribute '" + name + "'", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,41 @@

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputFilter;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;

/**
* This class is used when session attributes need to be reconstructed with a new classloader.
* It now supports ObjectInputFilter for secure deserialization.
*/
public class ClassLoaderObjectInputStream extends ObjectInputStream {

private final ClassLoader loader;

/**
* Constructs a ClassLoaderObjectInputStream with an ObjectInputFilter for secure deserialization.
*
* @param in the input stream to read from
* @param loader the ClassLoader to use for class resolution
* @param filter the ObjectInputFilter to validate deserialized classes (required for security)
* @throws IOException if an I/O error occurs
*/
public ClassLoaderObjectInputStream(InputStream in, ClassLoader loader, ObjectInputFilter filter)
throws IOException {
super(in);
this.loader = loader;
setObjectInputFilter(filter);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a null check?

 if (filter != null) {
    setObjectInputFilter(filter);
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch! Added a null check.

}

/**
* Legacy constructor for backward compatibility.
*
* @deprecated Use
* {@link #ClassLoaderObjectInputStream(InputStream, ClassLoader, ObjectInputFilter)}
* with a filter for secure deserialization
*/
@Deprecated
public ClassLoaderObjectInputStream(InputStream in, ClassLoader loader) throws IOException {
super(in);
this.loader = loader;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InvalidClassException;
import java.io.ObjectInputFilter;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
Expand Down Expand Up @@ -162,4 +164,142 @@ File getTempFile() {
return null;
}
}

@Test
public void filterRejectsUnauthorizedClasses() throws Exception {
// Arrange: Create filter that only allows java.lang and java.util classes
ObjectInputFilter filter = ObjectInputFilter.Config.createFilter("java.lang.*;java.util.*;!*");
TestSerializable testObject = new TestSerializable("test");
byte[] serializedData = serialize(testObject);

// Act & Assert: Deserialization should be rejected by filter
assertThatThrownBy(() -> {
try (ClassLoaderObjectInputStream ois = new ClassLoaderObjectInputStream(
new ByteArrayInputStream(serializedData),
Thread.currentThread().getContextClassLoader(),
filter)) {
ois.readObject();
}
}).isInstanceOf(InvalidClassException.class);
}

@Test
public void filterAllowsAuthorizedClasses() throws Exception {
// Arrange: Create filter that allows this test class package
ObjectInputFilter filter = ObjectInputFilter.Config.createFilter(
"java.lang.*;java.util.*;org.apache.geode.modules.util.**;!*");
TestSerializable testObject = new TestSerializable("test data");
byte[] serializedData = serialize(testObject);

// Act: Deserialize with filter
Object deserialized;
try (ClassLoaderObjectInputStream ois = new ClassLoaderObjectInputStream(
new ByteArrayInputStream(serializedData),
Thread.currentThread().getContextClassLoader(),
filter)) {
deserialized = ois.readObject();
}

// Assert: Object should be successfully deserialized
assertThat(deserialized).isInstanceOf(TestSerializable.class);
assertThat(((TestSerializable) deserialized).getData()).isEqualTo("test data");
}

@Test
public void nullFilterAllowsAllClasses() throws Exception {
// Arrange: Null filter means no filtering (backward compatibility)
TestSerializable testObject = new TestSerializable("unfiltered data");
byte[] serializedData = serialize(testObject);

// Act: Deserialize with null filter
Object deserialized;
try (ClassLoaderObjectInputStream ois = new ClassLoaderObjectInputStream(
new ByteArrayInputStream(serializedData),
Thread.currentThread().getContextClassLoader(),
null)) {
deserialized = ois.readObject();
}

// Assert: Object should be successfully deserialized
assertThat(deserialized).isInstanceOf(TestSerializable.class);
assertThat(((TestSerializable) deserialized).getData()).isEqualTo("unfiltered data");
}

@Test
public void deprecatedConstructorStillWorks() throws Exception {
// Arrange: Use deprecated constructor without filter
TestSerializable testObject = new TestSerializable("legacy code");
byte[] serializedData = serialize(testObject);

// Act: Deserialize using deprecated constructor
Object deserialized;
try (ClassLoaderObjectInputStream ois = new ClassLoaderObjectInputStream(
new ByteArrayInputStream(serializedData),
Thread.currentThread().getContextClassLoader())) {
deserialized = ois.readObject();
}

// Assert: Object should be successfully deserialized (backward compatibility)
assertThat(deserialized).isInstanceOf(TestSerializable.class);
assertThat(((TestSerializable) deserialized).getData()).isEqualTo("legacy code");
}

@Test
public void filterEnforcesResourceLimits() throws Exception {
// Arrange: Create filter with very low depth limit
ObjectInputFilter filter = ObjectInputFilter.Config.createFilter("maxdepth=2;*");
NestedSerializable nested = new NestedSerializable(
new NestedSerializable(
new NestedSerializable(null))); // Depth of 3
byte[] serializedData = serialize(nested);

// Act & Assert: Should reject due to depth limit
assertThatThrownBy(() -> {
try (ClassLoaderObjectInputStream ois = new ClassLoaderObjectInputStream(
new ByteArrayInputStream(serializedData),
Thread.currentThread().getContextClassLoader(),
filter)) {
ois.readObject();
}
}).isInstanceOf(InvalidClassException.class);
}

/**
* Helper method to serialize an object to byte array
*/
private byte[] serialize(Object obj) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(obj);
}
return baos.toByteArray();
}

/**
* Test class for serialization testing
*/
static class TestSerializable implements Serializable {
private static final long serialVersionUID = 1L;
private final String data;

TestSerializable(String data) {
this.data = data;
}

String getData() {
return data;
}
}

/**
* Nested test class for depth limit testing
*/
static class NestedSerializable implements Serializable {
private static final long serialVersionUID = 1L;
private final NestedSerializable nested;

NestedSerializable(NestedSerializable nested) {
this.nested = nested;
}
}
}
Loading
Loading