Skip to content

Commit 3a97fa8

Browse files
Adding a ShellCommandLocalSocketExecutorServer for the ShellExecutor to talk to the ShellMain.
PiperOrigin-RevId: 693905871
1 parent 35bdab8 commit 3a97fa8

File tree

4 files changed

+369
-0
lines changed

4 files changed

+369
-0
lines changed

services/shellexecutor/java/androidx/test/services/shellexecutor/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ kt_android_library(
6464
"ShellCommandExecutor.java",
6565
"ShellCommandExecutorServer.java",
6666
"ShellCommandFileObserverExecutorServer.kt",
67+
"ShellCommandLocalSocketExecutorServer.kt",
6768
"ShellExecSharedConstants.java",
6869
"ShellMain.java",
6970
],
@@ -72,6 +73,8 @@ kt_android_library(
7273
deps = [
7374
":coroutine_file_observer",
7475
":file_observer_protocol",
76+
":local_socket_protocol",
77+
":local_socket_protocol_pb_java_proto_lite",
7578
"//services/speakeasy/java/androidx/test/services/speakeasy:protocol",
7679
"//services/speakeasy/java/androidx/test/services/speakeasy/client",
7780
"//services/speakeasy/java/androidx/test/services/speakeasy/client:tool_connection",
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
/*
2+
* Copyright (C) 2024 The Android Open Source Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package androidx.test.services.shellexecutor
18+
19+
import android.net.LocalServerSocket
20+
import android.net.LocalSocket
21+
import android.net.LocalSocketAddress
22+
import android.os.Process as AndroidProcess
23+
import android.util.Log
24+
import androidx.test.services.shellexecutor.LocalSocketProtocol.asBinderKey
25+
import androidx.test.services.shellexecutor.LocalSocketProtocol.readRequest
26+
import androidx.test.services.shellexecutor.LocalSocketProtocol.sendResponse
27+
import androidx.test.services.shellexecutor.LocalSocketProtocolProto.RunCommandRequest
28+
import java.io.IOException
29+
import java.io.InterruptedIOException
30+
import java.security.SecureRandom
31+
import java.util.concurrent.Executors
32+
import java.util.concurrent.atomic.AtomicBoolean
33+
import kotlin.time.Duration
34+
import kotlin.time.Duration.Companion.milliseconds
35+
import kotlinx.coroutines.CoroutineScope
36+
import kotlinx.coroutines.Job
37+
import kotlinx.coroutines.SupervisorJob
38+
import kotlinx.coroutines.TimeoutCancellationException
39+
import kotlinx.coroutines.asCoroutineDispatcher
40+
import kotlinx.coroutines.async
41+
import kotlinx.coroutines.coroutineScope
42+
import kotlinx.coroutines.delay
43+
import kotlinx.coroutines.launch
44+
import kotlinx.coroutines.runBlocking
45+
import kotlinx.coroutines.runInterruptible
46+
import kotlinx.coroutines.withTimeout
47+
48+
/** Server that run shell commands for a client talking over a LocalSocket. */
49+
final class ShellCommandLocalSocketExecutorServer
50+
@JvmOverloads
51+
constructor(
52+
private val scope: CoroutineScope =
53+
CoroutineScope(Executors.newCachedThreadPool().asCoroutineDispatcher())
54+
) {
55+
// Use the same secret generation as SpeakEasy does.
56+
private val secret = java.lang.Long.toHexString(SecureRandom().nextLong())
57+
lateinit var socket: LocalServerSocket
58+
lateinit var address: LocalSocketAddress
59+
// Since LocalServerSocket.accept() has to be interrupted, we keep that in its own Job...
60+
lateinit var serverJob: Job
61+
// ...while all the child jobs are under a single SupervisorJob that we can join later.
62+
val shellJobs = SupervisorJob()
63+
val running = AtomicBoolean(true)
64+
65+
/** Returns the binder key to pass to client processes. */
66+
fun binderKey(): String {
67+
// The address can contain spaces, and since it gets passed through a command line, we need to
68+
// encode it. java.net.URLEncoder is conveniently available in all SDK versions.
69+
return address.asBinderKey(secret)
70+
}
71+
72+
/** Runs a simple server. */
73+
private suspend fun server() = coroutineScope {
74+
while (running.get()) {
75+
val connection =
76+
try {
77+
runInterruptible { socket.accept() }
78+
} catch (x: Exception) {
79+
// None of my tests have managed to trigger this one.
80+
Log.e(TAG, "LocalServerSocket.accept() failed", x)
81+
break
82+
}
83+
launch(scope.coroutineContext + shellJobs) { handleConnection(connection) }
84+
}
85+
}
86+
87+
/**
88+
* Relays the output of process to connection with a series of RunCommandResponses.
89+
*
90+
* @param process The process to relay output from.
91+
* @param connection The connection to relay output to.
92+
* @return false if there was a problem, true otherwise.
93+
*/
94+
private suspend fun relay(process: Process, connection: LocalSocket): Boolean {
95+
// Experiment shows that 64K is *much* faster than 4K, especially on API 21-23. Streaming 1MB
96+
// takes 3s with 4K buffers and 2s with 64K on API 23. 22 is a bit faster (2.6s -> 1.5s),
97+
// 21 faster still (630ms -> 545ms). Higher API levels are *much* faster (24 is 119 ms ->
98+
// 75ms).
99+
val buffer = ByteArray(65536)
100+
var size: Int
101+
102+
// LocalSocket.isOutputShutdown() throws UnsupportedOperationException, so we can't use
103+
// that as our loop constraint.
104+
while (true) {
105+
try {
106+
size = runInterruptible { process.inputStream.read(buffer) }
107+
if (size < 0) return true // EOF
108+
if (size == 0) {
109+
delay(1.milliseconds)
110+
continue
111+
}
112+
} catch (x: InterruptedIOException) {
113+
// We start getting these at API 24 when the timeout handling kicks in.
114+
Log.i(TAG, "Interrupted while reading from ${process}: ${x.message}")
115+
return false
116+
} catch (x: IOException) {
117+
Log.i(TAG, "Error reading from ${process}; did it time out?", x)
118+
return false
119+
}
120+
121+
if (!connection.sendResponse(buffer = buffer, size = size)) {
122+
return false
123+
}
124+
}
125+
}
126+
127+
/** Handle one connection. */
128+
private suspend fun handleConnection(connection: LocalSocket) {
129+
// connection.localSocketAddress is always null, so no point in logging it.
130+
131+
// Close the connection when done.
132+
connection.use {
133+
val request = connection.readRequest()
134+
135+
if (request.secret.compareTo(secret) != 0) {
136+
Log.w(TAG, "Ignoring request with wrong secret: $request")
137+
return
138+
}
139+
140+
val pb = request.toProcessBuilder()
141+
pb.redirectErrorStream(true)
142+
143+
val process: Process
144+
try {
145+
process = pb.start()
146+
} catch (x: IOException) {
147+
Log.e(TAG, "Failed to start process", x)
148+
connection.sendResponse(
149+
buffer = x.stackTraceToString().toByteArray(),
150+
exitCode = EXIT_CODE_FAILED_TO_START,
151+
)
152+
return
153+
}
154+
155+
// We will not be writing anything to the process' stdin.
156+
process.outputStream.close()
157+
158+
// Close the process' stdout when we're done reading.
159+
process.inputStream.use {
160+
// Launch a coroutine to relay the process' output to the client. If it times out, kill the
161+
// process and cancel the job. This is more coroutine-friendly than using waitFor() to
162+
// handle timeouts.
163+
val ioJob = scope.async { relay(process, connection) }
164+
165+
try {
166+
withTimeout(request.timeout()) {
167+
if (!ioJob.await()) {
168+
Log.w(TAG, "Relaying ${process} output failed")
169+
}
170+
runInterruptible { process.waitFor() }
171+
}
172+
} catch (x: TimeoutCancellationException) {
173+
Log.e(TAG, "Process ${process} timed out after ${request.timeout()}")
174+
process.destroy()
175+
ioJob.cancel()
176+
connection.sendResponse(exitCode = EXIT_CODE_TIMED_OUT)
177+
return
178+
}
179+
180+
connection.sendResponse(exitCode = process.exitValue())
181+
}
182+
}
183+
}
184+
185+
/** Starts the server. */
186+
fun start() {
187+
socket = LocalServerSocket("androidx.test.services ${AndroidProcess.myPid()}")
188+
address = socket.localSocketAddress
189+
Log.i(TAG, "Starting server on ${address.name}")
190+
191+
// Launch a coroutine to call socket.accept()
192+
serverJob = scope.launch { server() }
193+
}
194+
195+
/** Stops the server. */
196+
fun stop(timeout: Duration) {
197+
running.set(false)
198+
// Closing the socket does not interrupt accept()...
199+
socket.close()
200+
runBlocking(scope.coroutineContext) {
201+
try {
202+
// ...so we simply cancel that job...
203+
serverJob.cancel()
204+
// ...and play nicely with all the shell jobs underneath.
205+
withTimeout(timeout) {
206+
shellJobs.complete()
207+
shellJobs.join()
208+
}
209+
} catch (x: TimeoutCancellationException) {
210+
Log.w(TAG, "Shell jobs did not stop after $timeout", x)
211+
shellJobs.cancel()
212+
}
213+
}
214+
}
215+
216+
private fun RunCommandRequest.timeout(): Duration =
217+
if (timeoutMs <= 0) {
218+
Duration.INFINITE
219+
} else {
220+
timeoutMs.milliseconds
221+
}
222+
223+
/**
224+
* Sets up a ProcessBuilder with information from the request; other configuration is up to the
225+
* caller.
226+
*/
227+
private fun RunCommandRequest.toProcessBuilder(): ProcessBuilder {
228+
val pb = ProcessBuilder(argvList)
229+
val redacted = argvList.map { it.replace(secret, "(SECRET)") } // Don't log the secret!
230+
Log.i(TAG, "Command to execute: [${redacted.joinToString("] [")}] within ${timeout()}")
231+
if (environmentMap.isNotEmpty()) {
232+
pb.environment().putAll(environmentMap)
233+
val env = environmentMap.entries.map { (k, v) -> "$k=$v" }.joinToString(", ")
234+
Log.i(TAG, "Environment: $env")
235+
}
236+
return pb
237+
}
238+
239+
private companion object {
240+
const val TAG = "SCLSEServer" // up to 23 characters
241+
242+
const val EXIT_CODE_FAILED_TO_START = -1
243+
const val EXIT_CODE_TIMED_OUT = -2
244+
}
245+
}

services/shellexecutor/javatests/androidx/test/services/shellexecutor/BUILD

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,23 @@ axt_android_library_test(
109109
],
110110
)
111111

112+
axt_android_library_test(
113+
name = "ShellCommandLocalSocketExecutorServerTest",
114+
srcs = [
115+
"ShellCommandLocalSocketExecutorServerTest.kt",
116+
],
117+
deps = [
118+
"//runner/monitor",
119+
"//services/shellexecutor:exec_server",
120+
"//services/shellexecutor/java/androidx/test/services/shellexecutor:local_socket_protocol",
121+
"//services/shellexecutor/java/androidx/test/services/shellexecutor:local_socket_protocol_pb_java_proto_lite",
122+
"@com_google_protobuf//:protobuf_javalite",
123+
"@maven//:com_google_truth_truth",
124+
"@maven//:junit_junit",
125+
"@maven//:org_jetbrains_kotlinx_kotlinx_coroutines_android",
126+
],
127+
)
128+
112129
axt_android_library_test(
113130
name = "ShellExecutorTest",
114131
srcs = [
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package androidx.test.services.shellexecutor
2+
3+
import android.net.LocalSocket
4+
import android.os.Build
5+
import androidx.test.services.shellexecutor.LocalSocketProtocol.addressFromBinderKey
6+
import androidx.test.services.shellexecutor.LocalSocketProtocol.hasExited
7+
import androidx.test.services.shellexecutor.LocalSocketProtocol.readResponse
8+
import androidx.test.services.shellexecutor.LocalSocketProtocol.secretFromBinderKey
9+
import androidx.test.services.shellexecutor.LocalSocketProtocol.sendRequest
10+
import androidx.test.services.shellexecutor.LocalSocketProtocolProto.RunCommandResponse
11+
import com.google.common.truth.Truth.assertThat
12+
import kotlin.time.Duration.Companion.milliseconds
13+
import kotlinx.coroutines.runBlocking
14+
import org.junit.Test
15+
import org.junit.runner.RunWith
16+
import org.junit.runners.JUnit4
17+
18+
@RunWith(JUnit4::class)
19+
class ShellCommandLocalSocketExecutorServerTest {
20+
21+
@Test
22+
fun success_simple() {
23+
val responses = mutableListOf<RunCommandResponse>()
24+
runBlocking {
25+
val server = ShellCommandLocalSocketExecutorServer()
26+
server.start()
27+
val client = LocalSocket(LocalSocket.SOCKET_STREAM)
28+
client.connect(addressFromBinderKey(server.binderKey()))
29+
client.sendRequest(
30+
secretFromBinderKey(server.binderKey()),
31+
listOf("echo", "\${POTRZEBIE}"),
32+
mapOf("POTRZEBIE" to "furshlugginer"),
33+
1000.milliseconds,
34+
)
35+
do {
36+
client.readResponse()?.let { responses.add(it) }
37+
} while (!responses.last().hasExited())
38+
server.stop(100.milliseconds)
39+
}
40+
if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.LOLLIPOP_MR1) {
41+
// On API 21 and 22, echo only exists as a shell builtin!
42+
assertThat(responses).hasSize(1)
43+
assertThat(responses[0].exitCode).isEqualTo(-1)
44+
assertThat(responses[0].buffer.toStringUtf8()).contains("Permission denied")
45+
} else {
46+
// On rare occasions, the output of the command will come back in two packets! So to keep
47+
// this test from being 1% flaky:
48+
val stdout = buildString {
49+
for (response in responses) {
50+
if (response.buffer.size() > 0) append(response.buffer.toStringUtf8())
51+
}
52+
}
53+
assertThat(stdout).isEqualTo("\${POTRZEBIE}\n")
54+
assertThat(responses.last().hasExited()).isTrue()
55+
assertThat(responses.last().exitCode).isEqualTo(0)
56+
}
57+
}
58+
59+
@Test
60+
fun success_shell_expansion() {
61+
val responses = mutableListOf<RunCommandResponse>()
62+
runBlocking {
63+
val server = ShellCommandLocalSocketExecutorServer()
64+
server.start()
65+
val client = LocalSocket(LocalSocket.SOCKET_STREAM)
66+
client.connect(addressFromBinderKey(server.binderKey()))
67+
client.sendRequest(
68+
secretFromBinderKey(server.binderKey()),
69+
listOf("sh", "-c", "echo \${POTRZEBIE}"),
70+
mapOf("POTRZEBIE" to "furshlugginer"),
71+
1000.milliseconds,
72+
)
73+
do {
74+
client.readResponse()?.let { responses.add(it) }
75+
} while (!responses.last().hasExited())
76+
server.stop(100.milliseconds)
77+
}
78+
val stdout = buildString {
79+
for (response in responses) {
80+
if (response.buffer.size() > 0) append(response.buffer.toStringUtf8())
81+
}
82+
}
83+
assertThat(stdout).isEqualTo("furshlugginer\n")
84+
assertThat(responses.last().hasExited()).isTrue()
85+
assertThat(responses.last().exitCode).isEqualTo(0)
86+
}
87+
88+
@Test
89+
fun failure_bad_secret() {
90+
runBlocking {
91+
val server = ShellCommandLocalSocketExecutorServer()
92+
server.start()
93+
val client = LocalSocket(LocalSocket.SOCKET_STREAM)
94+
client.connect(addressFromBinderKey(server.binderKey()))
95+
client.sendRequest(
96+
"potrzebie!",
97+
listOf("sh", "-c", "echo \${POTRZEBIE}"),
98+
mapOf("POTRZEBIE" to "furshlugginer"),
99+
1000.milliseconds,
100+
)
101+
assertThat(client.inputStream.read()).isEqualTo(-1)
102+
}
103+
}
104+
}

0 commit comments

Comments
 (0)