Skip to content

Commit c5f312a

Browse files
Marcelo Vanzin authored and dongjoon-hyun committed
[SPARK-30129][CORE] Set client's id in TransportClient after successful auth
The new auth code was missing this bit, so it was not possible to know which app a client belonged to when auth was on. I also refactored the SASL test that checks for this so it also checks the new protocol (test failed before the fix, passes now). Closes apache#26760 from vanzin/SPARK-30129. Authored-by: Marcelo Vanzin <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 29e09a8 commit c5f312a

File tree

4 files changed

+186
-106
lines changed

4 files changed

+186
-106
lines changed

common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ public void doBootstrap(TransportClient client, Channel channel) {
7878

7979
try {
8080
doSparkAuth(client, channel);
81+
client.setClientId(appId);
8182
} catch (GeneralSecurityException | IOException e) {
8283
throw Throwables.propagate(e);
8384
} catch (RuntimeException e) {

common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
125125
response.encode(responseData);
126126
callback.onSuccess(responseData.nioBuffer());
127127
engine.sessionCipher().addToChannel(channel);
128+
client.setClientId(challenge.appId);
128129
} catch (Exception e) {
129130
// This is a fatal error: authentication has failed. Close the channel explicitly.
130131
LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());

common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java

Lines changed: 0 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
import java.nio.ByteBuffer;
2222
import java.util.ArrayList;
2323
import java.util.Arrays;
24-
import java.util.concurrent.CountDownLatch;
25-
import java.util.concurrent.atomic.AtomicReference;
2624

2725
import org.junit.After;
2826
import org.junit.AfterClass;
@@ -34,8 +32,6 @@
3432

3533
import org.apache.spark.network.TestUtils;
3634
import org.apache.spark.network.TransportContext;
37-
import org.apache.spark.network.buffer.ManagedBuffer;
38-
import org.apache.spark.network.client.ChunkReceivedCallback;
3935
import org.apache.spark.network.client.RpcResponseCallback;
4036
import org.apache.spark.network.client.TransportClient;
4137
import org.apache.spark.network.client.TransportClientFactory;
@@ -44,15 +40,6 @@
4440
import org.apache.spark.network.server.StreamManager;
4541
import org.apache.spark.network.server.TransportServer;
4642
import org.apache.spark.network.server.TransportServerBootstrap;
47-
import org.apache.spark.network.shuffle.BlockFetchingListener;
48-
import org.apache.spark.network.shuffle.ExternalBlockHandler;
49-
import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver;
50-
import org.apache.spark.network.shuffle.OneForOneBlockFetcher;
51-
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
52-
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
53-
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
54-
import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
55-
import org.apache.spark.network.shuffle.protocol.StreamHandle;
5643
import org.apache.spark.network.util.JavaUtils;
5744
import org.apache.spark.network.util.MapConfigProvider;
5845
import org.apache.spark.network.util.TransportConf;
@@ -165,93 +152,6 @@ public void testNoSaslServer() {
165152
}
166153
}
167154

168-
/**
169-
* This test is not actually testing SASL behavior, but testing that the shuffle service
170-
* performs correct authorization checks based on the SASL authentication data.
171-
*/
172-
@Test
173-
public void testAppIsolation() throws Exception {
174-
// Start a new server with the correct RPC handler to serve block data.
175-
ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
176-
ExternalBlockHandler blockHandler = new ExternalBlockHandler(
177-
new OneForOneStreamManager(), blockResolver);
178-
TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
179-
180-
try (
181-
TransportContext blockServerContext = new TransportContext(conf, blockHandler);
182-
TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
183-
// Create a client, and make a request to fetch blocks from a different app.
184-
TransportClientFactory clientFactory1 = blockServerContext.createClientFactory(
185-
Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
186-
TransportClient client1 = clientFactory1.createClient(
187-
TestUtils.getLocalHost(), blockServer.getPort())) {
188-
189-
AtomicReference<Throwable> exception = new AtomicReference<>();
190-
191-
CountDownLatch blockFetchLatch = new CountDownLatch(1);
192-
BlockFetchingListener listener = new BlockFetchingListener() {
193-
@Override
194-
public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
195-
blockFetchLatch.countDown();
196-
}
197-
@Override
198-
public void onBlockFetchFailure(String blockId, Throwable t) {
199-
exception.set(t);
200-
blockFetchLatch.countDown();
201-
}
202-
};
203-
204-
String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" };
205-
OneForOneBlockFetcher fetcher =
206-
new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf);
207-
fetcher.start();
208-
blockFetchLatch.await();
209-
checkSecurityException(exception.get());
210-
211-
// Register an executor so that the next steps work.
212-
ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
213-
new String[] { System.getProperty("java.io.tmpdir") }, 1,
214-
"org.apache.spark.shuffle.sort.SortShuffleManager");
215-
RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
216-
client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);
217-
218-
// Make a successful request to fetch blocks, which creates a new stream. But do not actually
219-
// fetch any blocks, to keep the stream open.
220-
OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
221-
ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
222-
StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
223-
long streamId = stream.streamId;
224-
225-
try (
226-
// Create a second client, authenticated with a different app ID, and try to read from
227-
// the stream created for the previous app.
228-
TransportClientFactory clientFactory2 = blockServerContext.createClientFactory(
229-
Arrays.asList(new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
230-
TransportClient client2 = clientFactory2.createClient(
231-
TestUtils.getLocalHost(), blockServer.getPort())
232-
) {
233-
CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
234-
ChunkReceivedCallback callback = new ChunkReceivedCallback() {
235-
@Override
236-
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
237-
chunkReceivedLatch.countDown();
238-
}
239-
240-
@Override
241-
public void onFailure(int chunkIndex, Throwable t) {
242-
exception.set(t);
243-
chunkReceivedLatch.countDown();
244-
}
245-
};
246-
247-
exception.set(null);
248-
client2.fetchChunk(streamId, 0, callback);
249-
chunkReceivedLatch.await();
250-
checkSecurityException(exception.get());
251-
}
252-
}
253-
}
254-
255155
/** RPC handler which simply responds with the message it received. */
256156
public static class TestRpcHandler extends RpcHandler {
257157
@Override
@@ -264,10 +164,4 @@ public StreamManager getStreamManager() {
264164
return new OneForOneStreamManager();
265165
}
266166
}
267-
268-
private static void checkSecurityException(Throwable t) {
269-
assertNotNull("No exception was caught.", t);
270-
assertTrue("Expected SecurityException.",
271-
t.getMessage().contains(SecurityException.class.getName()));
272-
}
273167
}
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.shuffle;

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;

import org.junit.BeforeClass;
import org.junit.Test;

import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

import org.apache.spark.network.TestUtils;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.crypto.AuthClientBootstrap;
import org.apache.spark.network.crypto.AuthServerBootstrap;
import org.apache.spark.network.sasl.SaslClientBootstrap;
import org.apache.spark.network.sasl.SaslServerBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

/**
 * Checks that the shuffle service enforces application isolation based on the
 * authentication data: a client authenticated as one app must not be able to
 * fetch blocks or read streams registered by a different app. The same scenario
 * is exercised against both auth protocols (SASL and the AES-based auth engine).
 */
public class AppIsolationSuite {

  // Use a long timeout to account for slow / overloaded build machines. In the normal case,
  // tests should finish way before the timeout expires.
  private static final long TIMEOUT_MS = 10_000;

  private static SecretKeyHolder secretKeyHolder;
  private static TransportConf conf;

  @BeforeClass
  public static void beforeAll() {
    // saslFallback is disabled so the auth-engine test cannot silently pass by
    // falling back to SASL.
    Map<String, String> confMap = new HashMap<>();
    confMap.put("spark.network.crypto.enabled", "true");
    confMap.put("spark.network.crypto.saslFallback", "false");
    conf = new TransportConf("shuffle", new MapConfigProvider(confMap));

    // Each app authenticates with its own user / secret.
    secretKeyHolder = mock(SecretKeyHolder.class);
    when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
    when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
    when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
    when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
  }

  @Test
  public void testSaslAppIsolation() throws Exception {
    testAppIsolation(
      () -> new SaslServerBootstrap(conf, secretKeyHolder),
      appId -> new SaslClientBootstrap(conf, appId, secretKeyHolder));
  }

  @Test
  public void testAuthEngineAppIsolation() throws Exception {
    testAppIsolation(
      () -> new AuthServerBootstrap(conf, secretKeyHolder),
      appId -> new AuthClientBootstrap(conf, appId, secretKeyHolder));
  }

  /**
   * Runs the isolation scenario with the given server / client bootstraps:
   * first a block fetch on behalf of the wrong app, then a cross-app read of a
   * stream opened by another client. Both are expected to fail with a
   * SecurityException from the server.
   *
   * @param serverBootstrap creates the server-side auth bootstrap under test.
   * @param clientBootstrapFactory creates a client bootstrap for a given app id.
   */
  private void testAppIsolation(
      Supplier<TransportServerBootstrap> serverBootstrap,
      Function<String, TransportClientBootstrap> clientBootstrapFactory) throws Exception {
    // Start a new server with the correct RPC handler to serve block data.
    ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
    ExternalBlockHandler blockHandler = new ExternalBlockHandler(
      new OneForOneStreamManager(), blockResolver);
    TransportServerBootstrap bootstrap = serverBootstrap.get();

    try (
      TransportContext blockServerContext = new TransportContext(conf, blockHandler);
      TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
      // Create a client, and make a request to fetch blocks from a different app.
      TransportClientFactory clientFactory1 = blockServerContext.createClientFactory(
        Arrays.asList(clientBootstrapFactory.apply("app-1")));
      TransportClient client1 = clientFactory1.createClient(
        TestUtils.getLocalHost(), blockServer.getPort())) {

      AtomicReference<Throwable> exception = new AtomicReference<>();

      CountDownLatch blockFetchLatch = new CountDownLatch(1);
      BlockFetchingListener listener = new BlockFetchingListener() {
        @Override
        public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
          blockFetchLatch.countDown();
        }
        @Override
        public void onBlockFetchFailure(String blockId, Throwable t) {
          exception.set(t);
          blockFetchLatch.countDown();
        }
      };

      String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" };
      OneForOneBlockFetcher fetcher =
        new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf);
      fetcher.start();
      // Bounded wait: if neither callback fires, fail the test instead of
      // hanging the build forever.
      assertTrue("Timed out waiting for the block fetch callback.",
        blockFetchLatch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
      checkSecurityException(exception.get());

      // Register an executor so that the next steps work.
      ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
        new String[] { System.getProperty("java.io.tmpdir") }, 1,
        "org.apache.spark.shuffle.sort.SortShuffleManager");
      RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
      client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);

      // Make a successful request to fetch blocks, which creates a new stream. But do not actually
      // fetch any blocks, to keep the stream open.
      OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
      ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
      StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
      long streamId = stream.streamId;

      try (
        // Create a second client, authenticated with a different app ID, and try to read from
        // the stream created for the previous app.
        TransportClientFactory clientFactory2 = blockServerContext.createClientFactory(
          Arrays.asList(clientBootstrapFactory.apply("app-2")));
        TransportClient client2 = clientFactory2.createClient(
          TestUtils.getLocalHost(), blockServer.getPort())
      ) {
        CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
        ChunkReceivedCallback callback = new ChunkReceivedCallback() {
          @Override
          public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
            chunkReceivedLatch.countDown();
          }

          @Override
          public void onFailure(int chunkIndex, Throwable t) {
            exception.set(t);
            chunkReceivedLatch.countDown();
          }
        };

        exception.set(null);
        client2.fetchChunk(streamId, 0, callback);
        // Bounded wait, same rationale as above.
        assertTrue("Timed out waiting for the chunk fetch callback.",
          chunkReceivedLatch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
        checkSecurityException(exception.get());
      }
    }
  }

  /**
   * Asserts that the given throwable represents a server-side SecurityException.
   * The server's exception arrives as a message string on the client side, so the
   * check is on the message text, not the exception type.
   */
  private static void checkSecurityException(Throwable t) {
    assertNotNull("No exception was caught.", t);
    // Guard against a null detail message so the assertion reports a clear
    // failure instead of an NPE inside the test helper.
    String message = t.getMessage();
    assertNotNull("Caught exception had no message: " + t.getClass().getName(), message);
    assertTrue("Expected SecurityException.",
      message.contains(SecurityException.class.getName()));
  }
}

0 commit comments

Comments (0)