diff --git a/MODULE.bazel b/MODULE.bazel index 62fd10e42..fe0f88325 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -15,6 +15,8 @@ bazel_dep(name = "rules_java", version = "8.9.0") bazel_dep(name = "rules_python", version = "0.23.1") bazel_dep(name = "rules_android", version = "0.6.1") bazel_dep(name = "bazel_features", version = "1.25.0") +bazel_dep(name = "bazel_worker_api", version = "0.0.4") +bazel_dep(name = "bazel_worker_java", version = "0.0.4") rules_java_toolchains = use_extension("@rules_java//java:extensions.bzl", "toolchains") use_repo(rules_java_toolchains, "remote_java_tools") diff --git a/src/main/kotlin/io/bazel/kotlin/builder/BUILD b/src/main/kotlin/io/bazel/kotlin/builder/BUILD index ae71308a8..b9479b46a 100644 --- a/src/main/kotlin/io/bazel/kotlin/builder/BUILD +++ b/src/main/kotlin/io/bazel/kotlin/builder/BUILD @@ -28,7 +28,6 @@ java_library( "//src/main/kotlin/io/bazel/worker", "//src/main/protobuf:deps_java_proto", "//src/main/protobuf:kotlin_model_java_proto", - "//src/main/protobuf:worker_protocol_java_proto", "//third_party:dagger", "@kotlin_rules_maven//:javax_annotation_javax_annotation_api", "@kotlin_rules_maven//:javax_inject_javax_inject", diff --git a/src/main/kotlin/io/bazel/kotlin/builder/tasks/BUILD.bazel b/src/main/kotlin/io/bazel/kotlin/builder/tasks/BUILD.bazel index d8acfc648..72b9f8404 100644 --- a/src/main/kotlin/io/bazel/kotlin/builder/tasks/BUILD.bazel +++ b/src/main/kotlin/io/bazel/kotlin/builder/tasks/BUILD.bazel @@ -31,7 +31,6 @@ kt_bootstrap_library( "//src/main/kotlin/io/bazel/worker", "//src/main/protobuf:deps_java_proto", "//src/main/protobuf:kotlin_model_java_proto", - "//src/main/protobuf:worker_protocol_java_proto", "@kotlin_rules_maven//:com_google_protobuf_protobuf_java", "@kotlin_rules_maven//:com_google_protobuf_protobuf_java_util", "@kotlin_rules_maven//:javax_inject_javax_inject", diff --git a/src/main/kotlin/io/bazel/kotlin/builder/utils/BUILD.bazel b/src/main/kotlin/io/bazel/kotlin/builder/utils/BUILD.bazel index 1f5982c55..5a5adc361 100644 --- a/src/main/kotlin/io/bazel/kotlin/builder/utils/BUILD.bazel +++ b/src/main/kotlin/io/bazel/kotlin/builder/utils/BUILD.bazel @@ -10,7 +10,6 @@ kt_bootstrap_library( deps = [ "//src/main/protobuf:deps_java_proto", "//src/main/protobuf:kotlin_model_java_proto", - "//src/main/protobuf:worker_protocol_java_proto", "@bazel_tools//tools/java/runfiles", "@kotlin_rules_maven//:com_google_protobuf_protobuf_java", "@kotlin_rules_maven//:com_google_protobuf_protobuf_java_util", diff --git a/src/main/kotlin/io/bazel/worker/BUILD.bazel b/src/main/kotlin/io/bazel/worker/BUILD.bazel index 246b618db..d2db05f89 100644 --- a/src/main/kotlin/io/bazel/worker/BUILD.bazel +++ b/src/main/kotlin/io/bazel/worker/BUILD.bazel @@ -1,16 +1,9 @@ -load("@rules_java//java:defs.bzl", "java_binary", "java_import") - -# General purpose Bazel worker implemented Kotlin. Best suited for jvm based tools. +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") load("//src/main/kotlin:bootstrap.bzl", "kt_bootstrap_library") -java_binary( - name = "worker_proto_bundle_bin", - runtime_deps = ["//src/main/protobuf:worker_protocol_java_proto"], -) - -java_import( - name = "worker_proto", - jars = [":worker_proto_bundle_bin_deploy.jar"], +java_proto_library( + name = "worker_protocol_java_proto", + deps = ["@bazel_worker_api//:worker_protocol_proto"], ) kt_bootstrap_library( @@ -20,6 +13,13 @@ kt_bootstrap_library( "//:__subpackages__", ], deps = [ - ":worker_proto", + ":worker_protocol_java_proto", + "@bazel_worker_java//src/main/java/com/google/devtools/build/lib/worker:work_request_handlers", + "@kotlin_rules_maven//:com_google_auto_value_auto_value", + "@kotlin_rules_maven//:com_google_code_findbugs_jsr305", + "@kotlin_rules_maven//:com_google_errorprone_error_prone_annotations", + "@kotlin_rules_maven//:com_google_guava_guava", + "@kotlin_rules_maven//:com_google_protobuf_protobuf_java", + "@kotlin_rules_maven//:com_google_protobuf_protobuf_java_util", ], ) diff --git a/src/main/kotlin/io/bazel/worker/CpuTimeBasedGcScheduler.kt b/src/main/kotlin/io/bazel/worker/CpuTimeBasedGcScheduler.kt deleted file mode 100644 index e4983f888..000000000 --- a/src/main/kotlin/io/bazel/worker/CpuTimeBasedGcScheduler.kt +++ /dev/null @@ -1,42 +0,0 @@ -package io.bazel.worker - -import com.sun.management.OperatingSystemMXBean -import src.main.kotlin.io.bazel.worker.GcScheduler -import java.lang.management.ManagementFactory -import java.time.Duration -import java.util.concurrent.atomic.AtomicReference - -// This class is intended to mirror https://github.com/Bencodes/bazel/blob/3835d9b21ad524d06873dfbf465ffd2dfb635ba8/src/main/java/com/google/devtools/build/lib/worker/WorkRequestHandler.java#L431-L474 -class CpuTimeBasedGcScheduler( - /** - * After this much CPU time has elapsed, we may force a GC run. Set to [Duration.ZERO] to - * disable. - */ - private val cpuUsageBeforeGc: Duration, -) : GcScheduler { - /** The total process CPU time at the last GC run (or from the start of the worker). */ - private val cpuTime: Duration - get() = if (cpuUsageBeforeGc.isZero) Duration.ZERO else Duration.ofNanos(bean.processCpuTime) - private val cpuTimeAtLastGc: AtomicReference = AtomicReference(cpuTime) - - /** Call occasionally to perform a GC if enough CPU time has been used. */ - override fun maybePerformGc() { - if (!cpuUsageBeforeGc.isZero) { - val currentCpuTime = cpuTime - val lastCpuTime = cpuTimeAtLastGc.get() - // Do GC when enough CPU time has been used, but only if nobody else beat us to it. - if (currentCpuTime.minus(lastCpuTime).compareTo(cpuUsageBeforeGc) > 0 && - cpuTimeAtLastGc.compareAndSet(lastCpuTime, currentCpuTime) - ) { - System.gc() - // Avoid counting GC CPU time against CPU time before next GC. - cpuTimeAtLastGc.compareAndSet(currentCpuTime, cpuTime) - } - } - } - - companion object { - /** Used to get the CPU time used by this process. */ - private val bean = ManagementFactory.getOperatingSystemMXBean() as OperatingSystemMXBean - } -} diff --git a/src/main/kotlin/io/bazel/worker/GcScheduler.kt b/src/main/kotlin/io/bazel/worker/GcScheduler.kt deleted file mode 100644 index 4e3d9beb4..000000000 --- a/src/main/kotlin/io/bazel/worker/GcScheduler.kt +++ /dev/null @@ -1,6 +0,0 @@ -package src.main.kotlin.io.bazel.worker - -/** GcScheduler for invoking garbage collection in a persistent worker. */ -fun interface GcScheduler { - fun maybePerformGc() -} diff --git a/src/main/kotlin/io/bazel/worker/IO.kt b/src/main/kotlin/io/bazel/worker/IO.kt deleted file mode 100644 index 9d899df5c..000000000 --- a/src/main/kotlin/io/bazel/worker/IO.kt +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright 2020 The Bazel Authors. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package io.bazel.worker - -import java.io.BufferedInputStream -import java.io.ByteArrayInputStream -import java.io.ByteArrayOutputStream -import java.io.Closeable -import java.io.InputStream -import java.io.OutputStream -import java.io.PrintStream -import java.nio.charset.StandardCharsets - -class IO( - val input: InputStream, - val output: PrintStream, - private val captured: CapturingOutputStream, - private val restore: () -> Unit = {}, -) : Closeable { - /** - * Reads the captured std out and err as a UTF-8 string and then resets the - * captured ByteArrayOutputStream. - * - * Resetting the ByteArrayOutputStream prevents the worker from returning - * the same console output multiple times - **/ - fun readCapturedAsUtf8String(): String { - captured.flush() - val out = captured.toByteArray().toString(StandardCharsets.UTF_8) - captured.reset() - return out - } - - companion object { - fun capture(): IO { - val stdErr = System.err - val stdIn = BufferedInputStream(System.`in`) - val stdOut = System.out - val inputBuffer = ByteArrayInputStream(ByteArray(0)) - val captured = CapturingOutputStream() - val outputBuffer = PrintStream(captured) - - // delegate the system defaults to capture execution information - System.setErr(outputBuffer) - System.setOut(outputBuffer) - System.setIn(inputBuffer) - - return IO(stdIn, stdOut, captured) { - System.setOut(stdOut) - System.setIn(stdIn) - System.setErr(stdErr) - } - } - } - - class CapturingOutputStream : OutputStream() { - private val backing = - object : ThreadLocal() { - override fun initialValue(): ByteArrayOutputStream = ByteArrayOutputStream() - } - - override fun write(b: Int) { - backing.get().write(b) - } - - fun reset() { - backing.get().reset() - } - - fun toByteArray(): ByteArray = backing.get().toByteArray() - } - - override fun close() { - restore.invoke() - } -} diff --git a/src/main/kotlin/io/bazel/worker/PersistentWorker.kt b/src/main/kotlin/io/bazel/worker/PersistentWorker.kt index 20d02c83f..f5916af32 100644 --- a/src/main/kotlin/io/bazel/worker/PersistentWorker.kt +++ b/src/main/kotlin/io/bazel/worker/PersistentWorker.kt @@ -17,111 +17,59 @@ package io.bazel.worker -import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest -import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse -import src.main.kotlin.io.bazel.worker.GcScheduler -import java.io.InputStream +import com.google.devtools.build.lib.worker.ProtoWorkerMessageProcessor +import com.google.devtools.build.lib.worker.WorkRequestHandler +import com.google.devtools.build.lib.worker.WorkerProtocol +import java.io.IOException +import java.io.PrintWriter import java.time.Duration -import java.util.concurrent.ExecutorCompletionService -import java.util.concurrent.ExecutorService -import java.util.concurrent.Executors -import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicLong /** * PersistentWorker satisfies Bazel persistent worker protocol for executing work. * * Supports multiplex (https://docs.bazel.build/versions/master/multiplex-worker.html) provided * the work is thread/coroutine safe. - * - * @param executor thread pool for executing tasks. - * @param captureIO to avoid writing stdout and stderr while executing. - * @param cpuTimeBasedGcScheduler to trigger gc cleanup. */ -class PersistentWorker( - private val captureIO: () -> IO, - private val executor: ExecutorService, - private val cpuTimeBasedGcScheduler: GcScheduler, -) : Worker { - constructor( - executor: ExecutorService, - captureIO: () -> IO, - ) : this( - captureIO, - executor, - GcScheduler {}, - ) - - constructor() : this( - IO.Companion::capture, - Executors.newCachedThreadPool(), - CpuTimeBasedGcScheduler(Duration.ofSeconds(10)), - ) - - override fun start(execute: Work) = - WorkerContext.run { - captureIO().use { io -> - val running = AtomicLong(0) - val completion = ExecutorCompletionService(executor) - val producer = - executor.submit { - io.input.readRequestAnd { request -> - running.incrementAndGet() - completion.submit { - doTask( +class PersistentWorker : Worker { + override fun start(execute: Work): Int { + return WorkerContext.run { + val realStdErr = System.err + try { + val workerHandler: WorkRequestHandler = + WorkRequestHandler + .WorkRequestHandlerBuilder( + WorkRequestHandler.WorkRequestCallback { + request: WorkerProtocol.WorkRequest, + pw: PrintWriter, + -> + return@WorkRequestCallback doTask( name = "request ${request.requestId}", task = request.workTo(execute), - ).asResponseTo(request.requestId, io) - } - } - } - val consumer = - executor.submit { - while (!producer.isDone || running.get() > 0) { - // poll time is how long before checking producer liveliness. Too long, worker hangs - // when being shutdown -- too short, and it starves the process. - completion.poll(1, TimeUnit.SECONDS)?.run { - running.decrementAndGet() - get().writeDelimitedTo(io.output) - io.output.flush() - } - cpuTimeBasedGcScheduler.maybePerformGc() - } - } - producer.get() - consumer.get() - io.output.close() + ).asResponse(pw) + }, + realStdErr, + ProtoWorkerMessageProcessor(System.`in`, System.out), + ).setCpuUsageBeforeGc(Duration.ofSeconds(10)) + .build() + workerHandler.processRequests() + } catch (e: IOException) { + this.error(e, { "Unknown IO exception" }) + e.printStackTrace(realStdErr) + return@run 1 } return@run 0 } + } - private fun WorkRequest.workTo(execute: Work): (sub: WorkerContext.TaskContext) -> Status = - { ctx -> execute(ctx, argumentsList.toList()) } - - private fun InputStream.readRequestAnd(action: (WorkRequest) -> Unit) { - while (true) { - WorkRequest - .parseDelimitedFrom(this) - ?.run(action) - ?: return + private fun WorkerProtocol.WorkRequest.workTo( + execute: Work, + ): (sub: WorkerContext.TaskContext) -> Status = + { ctx -> + execute(ctx, argumentsList.toList()) } - } - private fun TaskResult.asResponseTo( - id: Int, - io: IO, - ): WorkResponse = - WorkResponse - .newBuilder() - .apply { - val cap = io.readCapturedAsUtf8String() - // append whatever falls through standard out. - output = - listOf( - log.out.toString(), - cap, - ).joinToString("\n").trim() - exitCode = status.exit - requestId = id - }.build() + private fun TaskResult.asResponse(pw: PrintWriter): Int { + pw.print(log.out.toString()) + return status.exit + } } diff --git a/src/main/protobuf/BUILD b/src/main/protobuf/BUILD index 7993d0434..a684fcbfc 100644 --- a/src/main/protobuf/BUILD +++ b/src/main/protobuf/BUILD @@ -8,11 +8,6 @@ alias( actual = "@bazel_tools//src/main/protobuf:deps_java_proto", ) -alias( - name = "worker_protocol_java_proto", - actual = "@bazel_tools//src/main/protobuf:worker_protocol_java_proto", -) - proto_library( name = "kotlin_model_proto", srcs = [":kotlin_model.proto"], @@ -29,6 +24,5 @@ java_library( exports = [ ":deps_java_proto", ":kotlin_model_java_proto", - ":worker_protocol_java_proto", ], ) diff --git a/src/test/kotlin/io/bazel/worker/BUILD.bazel b/src/test/kotlin/io/bazel/worker/BUILD.bazel index d3a2040bd..e51dfe28c 100644 --- a/src/test/kotlin/io/bazel/worker/BUILD.bazel +++ b/src/test/kotlin/io/bazel/worker/BUILD.bazel @@ -1,17 +1,5 @@ load("//kotlin:jvm.bzl", "kt_jvm_library", "kt_jvm_test") -kt_jvm_test( - name = "IOTest", - srcs = [ - "IOTest.kt", - ], - test_class = "io.bazel.worker.IOTest", - deps = [ - "//src/main/kotlin/io/bazel/worker", - "@kotlin_rules_maven//:com_google_truth_truth", - ], -) - kt_jvm_test( name = "WorkerContextTest", srcs = [ @@ -39,9 +27,6 @@ kt_jvm_test( kt_jvm_library( name = "WorkerEnvironment", srcs = ["WorkerEnvironment.kt"], - deps = [ - "//src/main/protobuf:worker_protocol_java_proto", - ], ) kt_jvm_test( @@ -53,21 +38,6 @@ kt_jvm_test( deps = [ ":WorkerEnvironment", "//src/main/kotlin/io/bazel/worker", - "//src/main/protobuf:worker_protocol_java_proto", - "@kotlin_rules_maven//:com_google_truth_truth", - ], -) - -kt_jvm_test( - name = "JavaPersistentWorkerTest", - srcs = [ - "JavaPersistentWorkerTest.kt", - ], - test_class = "io.bazel.worker.JavaPersistentWorkerTest", - deps = [ - ":WorkerEnvironment", - "//src/main/kotlin/io/bazel/worker", - "//src/main/protobuf:worker_protocol_java_proto", "@kotlin_rules_maven//:com_google_truth_truth", ], ) @@ -75,12 +45,8 @@ kt_jvm_test( test_suite( name = "worker_tests", tests = [ - ":IOTest", ":InvocationWorkerTest", - ":JavaPersistentWorkerTest", ":WorkerContextTest", - # TODO(restingbull): Re-enable when not flaky. - #":WorkerEnvironmentTest", ], ) diff --git a/src/test/kotlin/io/bazel/worker/IOTest.kt b/src/test/kotlin/io/bazel/worker/IOTest.kt deleted file mode 100644 index 55cdf4116..000000000 --- a/src/test/kotlin/io/bazel/worker/IOTest.kt +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright 2020 The Bazel Authors. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package io.bazel.worker - -import com.google.common.truth.Truth.assertThat -import org.junit.After -import org.junit.Before -import org.junit.Test -import java.io.BufferedInputStream -import java.io.ByteArrayInputStream -import java.io.ByteArrayOutputStream -import java.io.PrintStream -import java.nio.charset.StandardCharsets - -class IOTest { - - private fun ByteArrayOutputStream.written() = String(toByteArray(), StandardCharsets.UTF_8) - - private val stdErr = System.err - private val stdIn = BufferedInputStream(System.`in`) - private val stdOut = System.out - private val inputBuffer = ByteArrayInputStream(ByteArray(0)) - private val captured = ByteArrayOutputStream() - private val outputBuffer = PrintStream(captured) - - @Before - fun captureSystem() { - // delegate the system defaults to capture execution information - System.setErr(outputBuffer) - System.setOut(outputBuffer) - System.setIn(inputBuffer) - } - - @After - fun restoreSystem() { - System.setErr(stdErr) - System.setIn(stdIn) - System.setOut(stdOut) - } - - @Test - fun capture() { - assertThat(captured.written()).isEmpty() - IO.capture().use { io -> - println("foo foo is on the loose") - assertThat(io.readCapturedAsUtf8String()).isEqualTo("foo foo is on the loose\n") - } - assertThat(captured.written()).isEmpty() - } - - @Test - fun captureDoesNotRepeatOutput() { - assertThat(captured.written()).isEmpty() - IO.capture().use { io -> - println("foo foo is on the loose") - assertThat(io.readCapturedAsUtf8String()).isEqualTo("foo foo is on the loose\n") - println("bar bar is on the loose") - assertThat(io.readCapturedAsUtf8String()).isEqualTo("bar bar is on the loose\n") - } - assertThat(captured.written()).isEmpty() - } -} diff --git a/src/test/kotlin/io/bazel/worker/JavaPersistentWorkerTest.kt b/src/test/kotlin/io/bazel/worker/JavaPersistentWorkerTest.kt deleted file mode 100644 index 937c3ea83..000000000 --- a/src/test/kotlin/io/bazel/worker/JavaPersistentWorkerTest.kt +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright 2020 The Bazel Authors. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package io.bazel.worker - -import com.google.common.truth.Truth.assertThat -import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest -import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse -import org.junit.Test -import java.nio.charset.StandardCharsets.UTF_8 -import java.util.concurrent.Executors - -class JavaPersistentWorkerTest { - - @Test - fun persistent() { - val requests = listOf( - WorkRequest.newBuilder().addAllArguments(listOf("--mammal", "bunny")).setRequestId(1) - .build(), - WorkRequest.newBuilder().addAllArguments(listOf("--mammal", "squirrel")).setRequestId(2) - .build() - ) - - val expectedResponses = mapOf( - 1 to WorkResponse - .newBuilder() - .setRequestId(1) - .setOutput("sidhe disciplined\n\nSqueek!") - .setExitCode(1), - 2 to WorkResponse.newBuilder().setRequestId(2).setOutput("sidhe commended").setExitCode(0) - ) - - val captured = IO.CapturingOutputStream() - - val actualResponses = WorkerEnvironment.inProcess { - task { stdIn, stdOut -> - PersistentWorker(Executors.newCachedThreadPool()) { - IO(stdIn, stdOut, captured) - }.start { ctx, args -> - when (args.toList()) { - listOf("--mammal", "bunny") -> { - ctx.info { "sidhe disciplined" } - captured.write("Squeek!".toByteArray(UTF_8)) - return@start Status.ERROR - } - listOf("--mammal", "squirrel") -> { - ctx.info { "sidhe commended" } - return@start Status.SUCCESS - } - else -> throw IllegalArgumentException("unexpected forest: $args") - } - } - } - requests.forEach { writeStdIn(it) } - closeStdIn() - waitForStdOut() - return@inProcess generateSequence { - readStdOut().apply { - println("sequence $this") - } - } - }.associateBy { workResponse -> - workResponse.requestId - } - - assertThat(actualResponses.keys).isEqualTo(expectedResponses.keys) - - expectedResponses.forEach { (resId, res) -> - assertThat(actualResponses[resId]?.output).contains(res.output) - assertThat(actualResponses[resId]?.exitCode).isEqualTo(res.exitCode) - } - } - - @Test - fun error() { - val captured = IO.CapturingOutputStream() - val actualResponses = WorkerEnvironment.inProcess { - task { stdIn, stdOut -> - PersistentWorker(Executors.newCachedThreadPool()) { - IO(stdIn, stdOut, captured) - }.start { _, _ -> - throw IllegalArgumentException("missing forest fairy") - } - } - writeStdIn( - WorkRequest.newBuilder() - .addAllArguments(listOf("--mammal", "bunny")) - .setRequestId(1) - .build() - ) - closeStdIn() - return@inProcess readStdOut() - } - - assertThat(actualResponses?.requestId).isEqualTo(1) - assertThat(actualResponses?.output) - .contains("java.lang.IllegalArgumentException: missing forest fairy") - } -}