|
| 1 | +package io.github.pixee.security; |
| 2 | + |
| 3 | +import io.github.pixee.security.ObjectInputFilters; |
| 4 | +import static org.hamcrest.CoreMatchers.*; |
| 5 | +import static org.hamcrest.MatcherAssert.assertThat; |
| 6 | +import static org.junit.jupiter.api.Assertions.assertThrows; |
| 7 | +import static org.junit.jupiter.api.Assertions.fail; |
| 8 | +import static org.mockito.Mockito.*; |
| 9 | + |
| 10 | +import java.io.*; |
| 11 | +import java.nio.file.Files; |
| 12 | +import org.apache.commons.fileupload.disk.DiskFileItem; |
| 13 | +import org.junit.jupiter.api.BeforeAll; |
| 14 | +import org.junit.jupiter.api.Test; |
| 15 | + |
| 16 | +final class ObjectInputFiltersTest { |
| 17 | + |
| 18 | + private static DiskFileItem gadget; // this is an evil gadget type |
| 19 | + private static byte[] serializedGadget; // this the serialized bytes of that gadget |
| 20 | + |
| 21 | + @BeforeAll |
| 22 | + static void setup() throws IOException { |
| 23 | + ByteArrayOutputStream baos = new ByteArrayOutputStream(); |
| 24 | + gadget = |
| 25 | + new DiskFileItem( |
| 26 | + "fieldName", |
| 27 | + "text/html", |
| 28 | + false, |
| 29 | + "foo.html", |
| 30 | + 100, |
| 31 | + Files.createTempDirectory("adi").toFile()); |
| 32 | + gadget.getOutputStream(); // needed to make the object serializable |
| 33 | + ObjectOutputStream oos = new ObjectOutputStream(baos); |
| 34 | + oos.writeObject(gadget); |
| 35 | + serializedGadget = baos.toByteArray(); |
| 36 | + } |
| 37 | + |
| 38 | + @Test |
| 39 | + void default_is_unprotected() throws Exception { |
| 40 | + ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(serializedGadget)); |
| 41 | + ObjectInputFilters.enableObjectFilterIfUnprotected(ois); |
| 42 | + Object o = ois.readObject(); |
| 43 | + assertThat(o, instanceOf(DiskFileItem.class)); |
| 44 | + } |
| 45 | + |
| 46 | + |
| 47 | + @Test |
| 48 | + void ois_harden_works() throws Exception { |
| 49 | + ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(serializedGadget)); |
| 50 | + ObjectInputFilters.enableObjectFilterIfUnprotected(ois); |
| 51 | + assertThrows( |
| 52 | + InvalidClassException.class, |
| 53 | + () -> { |
| 54 | + ois.readObject(); |
| 55 | + fail("this should have been blocked"); |
| 56 | + }); |
| 57 | + } |
| 58 | + |
| 59 | + @Test |
| 60 | + void objectinputfilter_works_when_none_present() throws Exception { |
| 61 | + ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(serializedGadget)); |
| 62 | + ois.setObjectInputFilter(ObjectInputFilters.getHardenedObjectFilter()); |
| 63 | + assertThrows( |
| 64 | + InvalidClassException.class, |
| 65 | + () -> { |
| 66 | + ois.readObject(); |
| 67 | + fail("this should have been blocked"); |
| 68 | + }); |
| 69 | + } |
| 70 | + |
| 71 | + /** |
| 72 | + * This test makes sure that if there's an existing {@link ObjectInputFilter}, that we honor it |
| 73 | + * while we also do our protection. It bans a BadType, and allows a GoodType, so that behavior |
| 74 | + * should still work as well as still reject our evil gadgets. |
| 75 | + */ |
| 76 | + @Test |
| 77 | + void objectinputfilter_works_and_honors_existing() throws Exception { |
| 78 | + ObjectInputFilter filter = |
| 79 | + ObjectInputFilter.Config.createFilter( |
| 80 | + "!" + BadType.class.getName() + ";" + GoodType.class.getName()); |
| 81 | + { |
| 82 | + ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(serializedGadget)); |
| 83 | + // this joins the default filter to the existing |
| 84 | + ois.setObjectInputFilter(ObjectInputFilters.createCombinedHardenedObjectFilter(filter)); |
| 85 | + assertThrows( |
| 86 | + InvalidClassException.class, |
| 87 | + () -> { |
| 88 | + ois.readObject(); |
| 89 | + fail("this should have been blocked"); |
| 90 | + }); |
| 91 | + } |
| 92 | + |
| 93 | + // make sure we still reject the bad type |
| 94 | + { |
| 95 | + byte[] serializedBadType = serialize(new BadType()); |
| 96 | + ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(serializedBadType)); |
| 97 | + ois.setObjectInputFilter( |
| 98 | + ObjectInputFilters.createCombinedHardenedObjectFilter(filter)); // this is our weave |
| 99 | + |
| 100 | + assertThrows( |
| 101 | + InvalidClassException.class, |
| 102 | + () -> { |
| 103 | + ois.readObject(); |
| 104 | + fail("this should have been blocked -- the original filter should have rejected it"); |
| 105 | + }); |
| 106 | + } |
| 107 | + |
| 108 | + // make we still allow the good type |
| 109 | + { |
| 110 | + byte[] serializedGoodType = serialize(new GoodType()); |
| 111 | + ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(serializedGoodType)); |
| 112 | + ois.setObjectInputFilter(ObjectInputFilters.createCombinedHardenedObjectFilter(filter)); |
| 113 | + GoodType goodType = (GoodType) ois.readObject(); |
| 114 | + assertThat(goodType, is(notNullValue())); |
| 115 | + } |
| 116 | + } |
| 117 | + |
| 118 | + @Test |
| 119 | + void the_filter_works_as_expected() { |
| 120 | + ObjectInputFilter filter = ObjectInputFilters.getHardenedObjectFilter(); |
| 121 | + ObjectInputFilter.FilterInfo filterInfo = mock(ObjectInputFilter.FilterInfo.class); |
| 122 | + |
| 123 | + // we never want to interfere with existing logic, so we never explicitly approve anything, even |
| 124 | + // innocent j.l.String |
| 125 | + doReturn(String.class).when(filterInfo).serialClass(); |
| 126 | + ObjectInputFilter.Status status = filter.checkInput(filterInfo); |
| 127 | + assertThat(status, is(ObjectInputFilter.Status.UNDECIDED)); |
| 128 | + |
| 129 | + // this confirm that the exact match of ProcessBuilder in our list is caught by the filter |
| 130 | + doReturn(ProcessBuilder.class).when(filterInfo).serialClass(); |
| 131 | + ObjectInputFilter.Status exactMatch = filter.checkInput(filterInfo); |
| 132 | + assertThat(exactMatch, is(ObjectInputFilter.Status.REJECTED)); |
| 133 | + |
| 134 | + // this confirms that although the Redirect type is not explicitly in the tokens, it's still |
| 135 | + // caught by the wildcard |
| 136 | + doReturn(ProcessBuilder.Redirect.class).when(filterInfo).serialClass(); |
| 137 | + ObjectInputFilter.Status partialMatch = filter.checkInput(filterInfo); |
| 138 | + assertThat(partialMatch, is(ObjectInputFilter.Status.REJECTED)); |
| 139 | + } |
| 140 | + |
| 141 | + byte[] serialize(Serializable s) throws IOException { |
| 142 | + ByteArrayOutputStream stream = new ByteArrayOutputStream(); |
| 143 | + new ObjectOutputStream(stream).writeObject(s); |
| 144 | + return stream.toByteArray(); |
| 145 | + } |
| 146 | + |
| 147 | + static class BadType implements Serializable {} |
| 148 | + |
| 149 | + static class GoodType implements Serializable {} |
| 150 | +} |
0 commit comments