Skip to content

Commit 35d363c

Browse files
author
Myron Scott
authored
Merge pull request #1283 from chrisdennis/safe-callbacks
Provide protection from untrusted callbacks when calling user code
2 parents de64ea2 + 15c8d3e commit 35d363c

File tree

9 files changed

+199
-19
lines changed

9 files changed

+199
-19
lines changed

tc-client/src/main/java/com/tc/object/BinaryInvocationCallback.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,14 @@
1919
package com.tc.object;
2020

2121
import org.terracotta.entity.EntityResponse;
22-
import org.terracotta.entity.InvocationCallback;
2322
import org.terracotta.entity.MessageCodec;
2423
import org.terracotta.entity.MessageCodecException;
2524

26-
public class BinaryInvocationCallback<R extends EntityResponse> implements InvocationCallback<byte[]> {
25+
public class BinaryInvocationCallback<R extends EntityResponse> implements SafeInvocationCallback<byte[]> {
2726
private final MessageCodec<?, R> codec;
28-
private final InvocationCallback<R> callback;
27+
private final SafeInvocationCallback<R> callback;
2928

30-
public BinaryInvocationCallback(MessageCodec<?, R> codec, InvocationCallback<R> callback) {
29+
public BinaryInvocationCallback(MessageCodec<?, R> codec, SafeInvocationCallback<R> callback) {
3130
this.codec = codec;
3231
this.callback = callback;
3332
}

tc-client/src/main/java/com/tc/object/ClientEntityManagerImpl.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
import org.terracotta.exception.EntityServerUncaughtException;
8080

8181
import static com.tc.object.EntityDescriptor.createDescriptorForLifecycle;
82+
import static com.tc.object.SafeInvocationCallback.safe;
8283
import static java.util.stream.Collectors.toCollection;
8384
import static org.terracotta.entity.Invocation.uninterruptiblyGet;
8485

@@ -196,7 +197,7 @@ private byte[] lifecycleAndRetire(EntityID entityId, long version, VoltronEntity
196197
}
197198

198199
private Invocation<byte[]> lifecycle(EntityID entityID, EntityDescriptor entityDescriptor, VoltronEntityMessage.Type type, byte[] message) {
199-
return (callback, callbacks) -> ClientEntityManagerImpl.this.invoke(entityID, entityDescriptor, callbacks, callback, true, type, message);
200+
return (callback, callbacks) -> ClientEntityManagerImpl.this.invoke(entityID, entityDescriptor, callbacks, safe(callback), true, type, message);
200201
}
201202

202203
@Override
@@ -228,12 +229,12 @@ private Set<VoltronEntityMessage.Acks> makeServerAcks(Set<InvocationCallback.Typ
228229

229230
@Override
230231
public Invocation.Task invokeAction(EntityID eid, EntityDescriptor entityDescriptor, Set<InvocationCallback.Types> requestedCallbacks,
231-
InvocationCallback<byte[]> callback, boolean requiresReplication, byte[] payload) {
232+
SafeInvocationCallback<byte[]> callback, boolean requiresReplication, byte[] payload) {
232233
return invoke(eid, entityDescriptor, requestedCallbacks, callback, requiresReplication, VoltronEntityMessage.Type.INVOKE_ACTION, payload);
233234
}
234235

235236
private Invocation.Task invoke(EntityID eid, EntityDescriptor entityDescriptor, Set<InvocationCallback.Types> requestedCallbacks,
236-
InvocationCallback<byte[]> callback, boolean requiresReplication, VoltronEntityMessage.Type type, byte[] payload) {
237+
SafeInvocationCallback<byte[]> callback, boolean requiresReplication, VoltronEntityMessage.Type type, byte[] payload) {
237238
Set<VoltronEntityMessage.Acks> requestedAcks = makeServerAcks(requestedCallbacks);
238239
return queueInFlightMessage(eid, () -> createMessageWithDescriptor(eid, entityDescriptor, requiresReplication, payload, type, requestedAcks), callback);
239240
}
@@ -476,7 +477,7 @@ private byte[] internalRetrieve(EntityDescriptor entityDescriptor) throws Entity
476477
return lifecycleAndComplete(entityDescriptor.getEntityID(), entityDescriptor, VoltronEntityMessage.Type.FETCH_ENTITY);
477478
}
478479

479-
private Invocation.Task queueInFlightMessage(EntityID eid, Supplier<NetworkVoltronEntityMessage> message, InvocationCallback<byte[]> callback) {
480+
private Invocation.Task queueInFlightMessage(EntityID eid, Supplier<NetworkVoltronEntityMessage> message, SafeInvocationCallback<byte[]> callback) {
480481
boolean queued;
481482
try {
482483
InFlightMessage inFlight = new InFlightMessage(eid, message, callback);

tc-client/src/main/java/com/tc/object/EntityClientEndpointImpl.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
import org.slf4j.Logger;
4040
import org.slf4j.LoggerFactory;
4141

42+
import static com.tc.object.SafeInvocationCallback.safe;
43+
4244

4345
public class EntityClientEndpointImpl<M extends EntityMessage, R extends EntityResponse> implements EntityClientEndpoint<M, R> {
4446

@@ -140,13 +142,13 @@ private InvocationImpl(M request) {
140142
public Task invoke(InvocationCallback<R> callback, Set<InvocationCallback.Types> callbacks) {
141143
checkInvoked();
142144
invoked = true;
143-
InvocationCallback<byte[]> binaryCallback = new BinaryInvocationCallback(codec, callback);
145+
SafeInvocationCallback<byte[]> binaryCallback = new BinaryInvocationCallback<>(codec, safe(callback));
144146
try {
145147
return invocationHandler.invokeAction(entityID, invokeDescriptor, callbacks, binaryCallback, true, codec.encodeMessage(request));
146148
} catch (MessageCodecException e) {
147-
callback.failure(e);
148-
callback.complete();
149-
callback.retired();
149+
binaryCallback.failure(e);
150+
binaryCallback.complete();
151+
binaryCallback.retired();
150152
return () -> false;
151153
}
152154
}

tc-client/src/main/java/com/tc/object/InFlightMessage.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
import java.util.concurrent.atomic.AtomicReference;
5151
import java.util.function.Supplier;
5252
import com.tc.net.protocol.tcm.TCAction;
53-
import org.terracotta.entity.InvocationCallback;
5453

5554

5655
/**
@@ -62,7 +61,7 @@
6261
public class InFlightMessage implements PrettyPrintable {
6362
private final VoltronEntityMessage message;
6463
private final EntityID eid;
65-
private final InvocationCallback<byte[]> callback;
64+
private final SafeInvocationCallback<byte[]> callback;
6665

6766
private final EnumSet<VoltronEntityMessage.Acks> outstandingAcks = EnumSet.allOf(VoltronEntityMessage.Acks.class);
6867

@@ -86,7 +85,7 @@ public class InFlightMessage implements PrettyPrintable {
8685

8786
private TCNetworkMessage networkMessage;
8887

89-
public InFlightMessage(EntityID eid, Supplier<? extends VoltronEntityMessage> message, InvocationCallback<byte[]> callback) {
88+
public InFlightMessage(EntityID eid, Supplier<? extends VoltronEntityMessage> message, SafeInvocationCallback<byte[]> callback) {
9089
this.eid = requireNonNull(eid);
9190
this.message = requireNonNull(message.get());
9291
this.callback = callback;

tc-client/src/main/java/com/tc/object/InvocationHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@
2929
* The minimal interface, provided to the EntityClientEndpoint, to handle invocations to send to the server.
3030
*/
3131
public interface InvocationHandler {
32-
Invocation.Task invokeAction(EntityID eid, EntityDescriptor entityDescriptor, Set<InvocationCallback.Types> callbacks, InvocationCallback<byte[]> callback, boolean requiresReplication, byte[] payload);
32+
Invocation.Task invokeAction(EntityID eid, EntityDescriptor entityDescriptor, Set<InvocationCallback.Types> callbacks, SafeInvocationCallback<byte[]> callback, boolean requiresReplication, byte[] payload);
3333
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
*
3+
* The contents of this file are subject to the Terracotta Public License Version
4+
* 2.0 (the "License"); You may not use this file except in compliance with the
5+
* License. You may obtain a copy of the License at
6+
*
7+
* http://terracotta.org/legal/terracotta-public-license.
8+
*
9+
* Software distributed under the License is distributed on an "AS IS" basis,
10+
* WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for
11+
* the specific language governing rights and limitations under the License.
12+
*
13+
* The Covered Software is Terracotta Core.
14+
*
15+
* The Initial Developer of the Covered Software is
16+
* Terracotta, Inc., a Software AG company
17+
*
18+
*/
19+
package com.tc.object;
20+
21+
import org.slf4j.Logger;
22+
import org.slf4j.LoggerFactory;
23+
import org.terracotta.entity.InvocationCallback;
24+
25+
public interface SafeInvocationCallback<R> extends InvocationCallback<R> {
26+
27+
static <R> SafeInvocationCallback<R> safe(InvocationCallback<R> callback) {
28+
if (callback instanceof SafeInvocationCallback<?>) {
29+
return (SafeInvocationCallback<R>) callback;
30+
} else {
31+
return new Guard<>(callback);
32+
}
33+
}
34+
35+
class Guard<R> implements SafeInvocationCallback<R> {
36+
private static final Logger LOGGER = LoggerFactory.getLogger(SafeInvocationCallback.class);
37+
38+
private final InvocationCallback<R> untrustedCallback;
39+
40+
private Guard(InvocationCallback<R> untrustedCallback) {
41+
this.untrustedCallback = untrustedCallback;
42+
}
43+
44+
@Override
45+
public void sent() {
46+
try {
47+
untrustedCallback.sent();
48+
} catch (Exception t) {
49+
LOGGER.warn("User-provided callback [" + untrustedCallback + "] threw exception", t);
50+
} catch (Throwable t) {
51+
LOGGER.warn("User-provided callback [" + untrustedCallback + "] threw throwable", t);
52+
}
53+
}
54+
55+
@Override
56+
public void received() {
57+
try {
58+
untrustedCallback.received();
59+
} catch (Exception t) {
60+
LOGGER.warn("User-provided callback [" + untrustedCallback + "] threw exception", t);
61+
} catch (Throwable t) {
62+
LOGGER.warn("User-provided callback [" + untrustedCallback + "] threw throwable", t);
63+
}
64+
}
65+
66+
@Override
67+
public void result(R response) {
68+
try {
69+
untrustedCallback.result(response);
70+
} catch (Exception t) {
71+
LOGGER.warn("User-provided callback [" + untrustedCallback + "] threw exception", t);
72+
} catch (Throwable t) {
73+
LOGGER.warn("User-provided callback [" + untrustedCallback + "] threw throwable", t);
74+
}
75+
}
76+
77+
@Override
78+
public void failure(Throwable failure) {
79+
try {
80+
untrustedCallback.failure(failure);
81+
} catch (Exception t) {
82+
LOGGER.warn("User-provided callback [" + untrustedCallback + "] threw exception", t);
83+
} catch (Throwable t) {
84+
LOGGER.warn("User-provided callback [" + untrustedCallback + "] threw throwable", t);
85+
}
86+
}
87+
88+
@Override
89+
public void complete() {
90+
try {
91+
untrustedCallback.complete();
92+
} catch (Exception t) {
93+
LOGGER.warn("User-provided callback [" + untrustedCallback + "] threw exception", t);
94+
} catch (Throwable t) {
95+
LOGGER.warn("User-provided callback [" + untrustedCallback + "] threw throwable", t);
96+
}
97+
}
98+
99+
@Override
100+
public void retired() {
101+
try {
102+
untrustedCallback.retired();
103+
} catch (Exception t) {
104+
LOGGER.warn("User-provided callback [" + untrustedCallback + "] threw exception", t);
105+
} catch (Throwable t) {
106+
LOGGER.warn("User-provided callback [" + untrustedCallback + "] threw throwable", t);
107+
}
108+
}
109+
}
110+
}

tc-client/src/main/java/com/terracotta/diagnostic/DiagnosticClientEntityManager.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import com.tc.object.EntityDescriptor;
2828
import com.tc.object.EntityID;
2929
import com.tc.object.InFlightMessage;
30+
import com.tc.object.SafeInvocationCallback;
3031
import com.tc.object.msg.ClientHandshakeMessage;
3132
import com.tc.object.tx.TransactionID;
3233
import com.tc.util.Assert;
@@ -154,7 +155,7 @@ public void shutdown() {
154155
}
155156

156157
@Override
157-
public Invocation.Task invokeAction(EntityID eid, EntityDescriptor entityDescriptor, Set<InvocationCallback.Types> callbacks, InvocationCallback<byte[]> callback, boolean requiresReplication, byte[] payload) {
158+
public Invocation.Task invokeAction(EntityID eid, EntityDescriptor entityDescriptor, Set<InvocationCallback.Types> callbacks, SafeInvocationCallback<byte[]> callback, boolean requiresReplication, byte[] payload) {
158159
DiagnosticMessage network = createMessage(payload);
159160
InFlightMessage message = new InFlightMessage(eid, ()->network, callback);
160161
waitingForAnswer.put(network.getTransactionID(), message);

tc-client/src/test/java/com/tc/object/ClientEntityManagerTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ public void testSingleInvoke() throws Exception {
549549
when(channel.createMessage(TCMessageType.VOLTRON_ENTITY_MESSAGE)).thenReturn(message);
550550

551551
CompletableFuture<byte[]> result = new CompletableFuture<>();
552-
InvocationCallback<byte[]> callback = new InvocationCallback<byte[]>() {
552+
SafeInvocationCallback<byte[]> callback = new SafeInvocationCallback<byte[]>() {
553553
@Override
554554
public void result(byte[] response) {
555555
result.complete(response);
@@ -570,7 +570,7 @@ public void testSingleInvokeTimeout() throws Exception {
570570
when(channel.createMessage(TCMessageType.VOLTRON_ENTITY_MESSAGE)).thenReturn(message);
571571

572572
CompletableFuture<byte[]> result = new CompletableFuture<>();
573-
InvocationCallback<byte[]> callback = new InvocationCallback<byte[]>() {
573+
SafeInvocationCallback<byte[]> callback = new SafeInvocationCallback<byte[]>() {
574574
@Override
575575
public void result(byte[] response) {
576576
result.complete(response);
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
*
3+
* The contents of this file are subject to the Terracotta Public License Version
4+
* 2.0 (the "License"); You may not use this file except in compliance with the
5+
* License. You may obtain a copy of the License at
6+
*
7+
* http://terracotta.org/legal/terracotta-public-license.
8+
*
9+
* Software distributed under the License is distributed on an "AS IS" basis,
10+
* WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for
11+
* the specific language governing rights and limitations under the License.
12+
*
13+
* The Covered Software is Terracotta Core.
14+
*
15+
* The Initial Developer of the Covered Software is
16+
* Terracotta, Inc., a Software AG company
17+
*
18+
*/
19+
package com.tc.object;
20+
21+
import org.junit.Test;
22+
import org.terracotta.entity.InvocationCallback;
23+
24+
import java.lang.reflect.InvocationTargetException;
25+
import java.lang.reflect.Method;
26+
27+
import static org.hamcrest.MatcherAssert.assertThat;
28+
import static org.hamcrest.Matchers.equalTo;
29+
import static org.hamcrest.Matchers.instanceOf;
30+
import static org.hamcrest.Matchers.is;
31+
import static org.junit.Assert.fail;
32+
import static org.mockito.Mockito.mock;
33+
import static org.mockito.Mockito.times;
34+
import static org.mockito.Mockito.verify;
35+
36+
public class SafeInvocationCallbackTest {
37+
38+
@Test
39+
public void testSafeInvocationCallbackIsTypedCorrectly() {
40+
assertThat(SafeInvocationCallback.safe(mock(InvocationCallback.class)), is(instanceOf(SafeInvocationCallback.class)));
41+
}
42+
43+
@Test
44+
public void testSafeInvocationCallbackCatchesAll() throws InvocationTargetException, IllegalAccessException {
45+
for (Method method : InvocationCallback.class.getDeclaredMethods()) {
46+
System.out.println(method);
47+
InvocationCallback<Object> callback = mock(InvocationCallback.class, inv -> {
48+
if (inv.getMethod().getDeclaringClass().equals(InvocationCallback.class)) {
49+
throw new Throwable();
50+
} else {
51+
return null;
52+
}
53+
});
54+
55+
SafeInvocationCallback<Object> safe = SafeInvocationCallback.safe(callback);
56+
57+
Object[] parameters = new Object[method.getParameterCount()];
58+
try {
59+
method.invoke(callback, parameters);
60+
fail("Expected Throwable");
61+
} catch (InvocationTargetException t) {
62+
assertThat(t.getCause().getClass(), is(equalTo(Throwable.class)));
63+
}
64+
method.invoke(safe, parameters);
65+
method.invoke(verify(callback, times(2)), parameters);
66+
}
67+
}
68+
}

0 commit comments

Comments
 (0)