diff --git a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java index ab73946534cd..112d1dd9229b 100644 --- a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java +++ b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java @@ -46,6 +46,7 @@ import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.joda.time.Duration; @@ -613,12 +614,32 @@ private void executeAsync(Callable callable) throws UserCodeExecutionExcep private static void parseAndThrow(Future future, ExecutionException e) throws UserCodeExecutionException { future.cancel(true); - if (e.getCause() == null) { - throw new UserCodeExecutionException(e); + + try { + UserCodeExecutionException genericException = null; + for (Throwable throwable : Throwables.getCausalChain(e)) { + if (throwable instanceof UserCodeQuotaException) { + throw (UserCodeQuotaException) throwable; + } else if (throwable instanceof UserCodeTimeoutException) { + throw (UserCodeTimeoutException) throwable; + } else if (throwable instanceof UserCodeRemoteSystemException) { + throw (UserCodeRemoteSystemException) throwable; + } else if (genericException == null && throwable instanceof UserCodeExecutionException) { + genericException = (UserCodeExecutionException) throwable; + } + } + if (genericException != null) { + throw genericException; + } + } catch (IllegalArgumentException iae) { + // Circular reference detected in causal chain + throw new UserCodeExecutionException( + "circular reference detected in exception causal chain", e); } - Throwable cause = checkStateNotNull(e.getCause()); - if (cause instanceof UserCodeQuotaException) { - throw new UserCodeQuotaException(cause); + + Throwable cause = e.getCause(); + if (cause == null) { + throw new UserCodeExecutionException(e); } throw new UserCodeExecutionException(cause); } diff --git a/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/CallTest.java b/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/CallTest.java index 0e572bdd2d64..4784f40f1822 100644 --- a/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/CallTest.java +++ b/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/CallTest.java @@ -104,6 +104,34 @@ public void givenCallerThrowsUserCodeExecutionException_emitsIntoFailurePCollect pipeline.run(); } + @Test + public void givenCallerThrowsNonUserCodeException_emitsWrappedUserCodeExecutionException() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply(Call.of(new CallerThrowsRuntimeException(), NON_DETERMINISTIC_RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(1L); + + pipeline.run(); + } + + @Test + public void givenCallerThrowsCircularCausalChain_emitsUserCodeExecutionException() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply(Call.of(new CallerThrowsCircularCause(), NON_DETERMINISTIC_RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(1L); + + pipeline.run(); + } + @Test public void givenCallerThrowsQuotaException_emitsIntoFailurePCollection() { Result result = @@ -142,7 +170,7 @@ public void givenCallerTimeout_emitsFailurePCollection() { } @Test - public void givenCallerThrowsTimeoutException_emitsFailurePCollection() { + public void givenCallerThrowsTimeoutException_thenPreservesExceptionType() { Result result = pipeline .apply(Create.of(new Request("a"))) @@ -150,7 +178,7 @@ public void givenCallerThrowsTimeoutException_emitsFailurePCollection() { PCollection failures = result.getFailures(); PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) - .isEqualTo(1L); + .isEqualTo(0L); PAssert.thatSingleton(countStackTracesOf(failures, UserCodeQuotaException.class)).isEqualTo(0L); PAssert.thatSingleton(countStackTracesOf(failures, UserCodeTimeoutException.class)) .isEqualTo(1L); @@ -158,6 +186,153 @@ public void givenCallerThrowsTimeoutException_emitsFailurePCollection() { pipeline.run(); } + @Test + public void givenCallerThrowsRemoteSystemException_thenPreservesExceptionType() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply( + Call.of(new CallerThrowsRemoteSystemException(), NON_DETERMINISTIC_RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeRemoteSystemException.class)) + .isEqualTo(1L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(0L); + + pipeline.run(); + } + + @Test + public void givenNestedExecutionException_thenPreservesExceptionType() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply( + Call.of( + new CallerThrowsNestedExecutionException(), NON_DETERMINISTIC_RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(1L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeTimeoutException.class)) + .isEqualTo(0L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeRemoteSystemException.class)) + .isEqualTo(0L); + + pipeline.run(); + } + + @Test + public void givenCallerThrowsGenericWrappingTimeout_thenPreservesExceptionType() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply( + Call.of( + new CallerThrowsGenericWrappingTimeout(), NON_DETERMINISTIC_RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeTimeoutException.class)) + .isEqualTo(1L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(0L); + + pipeline.run(); + } + + @Test + public void givenCallerThrowsGenericWrappingQuota_thenPreservesExceptionType() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply( + Call.of(new CallerThrowsGenericWrappingQuota(), NON_DETERMINISTIC_RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeQuotaException.class)).isEqualTo(1L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(0L); + + pipeline.run(); + } + + @Test + public void givenCallerThrowsGenericWrappingRemoteSystem_thenPreservesExceptionType() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply( + Call.of( + new CallerThrowsGenericWrappingRemoteSystem(), + NON_DETERMINISTIC_RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeRemoteSystemException.class)) + .isEqualTo(1L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(0L); + + pipeline.run(); + } + + @Test + public void + givenCallerThrowsUncheckedExecutionExceptionWrappingTimeout_thenPreservesExceptionType() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply( + Call.of( + new CallerThrowsUncheckedExecutionExceptionWrappingTimeout(), + NON_DETERMINISTIC_RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeTimeoutException.class)) + .isEqualTo(1L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(0L); + + pipeline.run(); + } + + @Test + public void + givenCallerThrowsUncheckedExecutionExceptionWrappingRemoteSystem_thenPreservesExceptionType() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply( + Call.of( + new CallerThrowsUncheckedExecutionExceptionWrappingRemoteSystem(), + NON_DETERMINISTIC_RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeRemoteSystemException.class)) + .isEqualTo(1L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(0L); + + pipeline.run(); + } + + @Test + public void givenCallerThrowsTripleNestedTimeout_thenPreservesExceptionType() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply( + Call.of(new CallerThrowsTripleNestedTimeout(), NON_DETERMINISTIC_RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeTimeoutException.class)) + .isEqualTo(1L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(0L); + + pipeline.run(); + } + @Test public void givenSetupThrowsUserCodeExecutionException_throwsError() { pipeline @@ -376,6 +551,25 @@ public Response call(Request request) throws UserCodeExecutionException { } } + private static class CallerThrowsRuntimeException implements Caller { + + @Override + public Response call(Request request) { + throw new RuntimeException("unexpected error"); + } + } + + private static class CallerThrowsCircularCause implements Caller { + + @Override + public Response call(Request request) { + Exception a = new Exception("a"); + Exception b = new Exception("b", a); + a.initCause(b); // a -> b -> a (circular reference) + throw new RuntimeException("boom", a); + } + } + private static class CallerThrowsTimeout implements Caller { @Override @@ -384,6 +578,74 @@ public Response call(Request request) throws UserCodeExecutionException { } } + private static class CallerThrowsRemoteSystemException implements Caller { + + @Override + public Response call(Request request) throws UserCodeExecutionException { + throw new UserCodeRemoteSystemException(""); + } + } + + private static class CallerThrowsNestedExecutionException implements Caller { + + @Override + public Response call(Request request) throws UserCodeExecutionException { + throw new UncheckedExecutionException(new UserCodeExecutionException("nested")); + } + } + + private static class CallerThrowsGenericWrappingTimeout implements Caller { + + @Override + public Response call(Request request) throws UserCodeExecutionException { + throw new UserCodeExecutionException("generic", new UserCodeTimeoutException("timeout")); + } + } + + private static class CallerThrowsGenericWrappingQuota implements Caller { + + @Override + public Response call(Request request) throws UserCodeExecutionException { + throw new UserCodeExecutionException("generic", new UserCodeQuotaException("quota")); + } + } + + private static class CallerThrowsGenericWrappingRemoteSystem + implements Caller { + + @Override + public Response call(Request request) throws UserCodeExecutionException { + throw new UserCodeExecutionException("generic", new UserCodeRemoteSystemException("remote")); + } + } + + private static class CallerThrowsUncheckedExecutionExceptionWrappingTimeout + implements Caller { + + @Override + public Response call(Request request) throws UserCodeExecutionException { + throw new UncheckedExecutionException(new UserCodeTimeoutException("timeout")); + } + } + + private static class CallerThrowsUncheckedExecutionExceptionWrappingRemoteSystem + implements Caller { + + @Override + public Response call(Request request) throws UserCodeExecutionException { + throw new UncheckedExecutionException(new UserCodeRemoteSystemException("remote")); + } + } + + private static class CallerThrowsTripleNestedTimeout implements Caller { + + @Override + public Response call(Request request) throws UserCodeExecutionException { + throw new UncheckedExecutionException( + new RuntimeException(new UserCodeTimeoutException("deep timeout"))); + } + } + private static class CallerInvokesQuotaException implements Caller { @Override diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 9c5306e143ec..94e9a0644d04 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -883,8 +883,15 @@ def __init__(self, fn, *args, **kwargs): # Ensure fn and side inputs are picklable for remote execution. try: self.fn = pickler.roundtrip(self.fn) - except RuntimeError as e: - raise RuntimeError('Unable to pickle fn %s: %s' % (self.fn, e)) + except (RuntimeError, TypeError, Exception) as e: + raise RuntimeError( + 'Unable to pickle fn %s: %s. ' + 'User code must be serializable (picklable) for distributed ' + 'execution. This usually happens when lambdas or closures capture ' + 'non-serializable objects like file handles, database connections, ' + 'or thread locks. Try: (1) using module-level functions instead of ' + 'lambdas, (2) initializing resources in setup() methods, ' + '(3) checking what your closure captures.' % (self.fn, e)) from e self.args = pickler.roundtrip(self.args) self.kwargs = pickler.roundtrip(self.kwargs) diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index 9a9bf6ff0a74..8c2acefccdb3 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -163,6 +163,25 @@ def test_do_with_side_input_as_keyword_arg(self): lambda x, addon: [x + addon], addon=pvalue.AsSingleton(side)) assert_that(result, equal_to([11, 12, 13])) + def test_callable_non_serializable_error_message(self): + class NonSerializable: + def __getstate__(self): + raise RuntimeError('nope') + + bad = NonSerializable() + + with self.assertRaises(RuntimeError) as context: + _ = beam.Map(lambda x: bad) + + message = str(context.exception) + self.assertIn('Unable to pickle fn', message) + self.assertIn( + 'User code must be serializable (picklable) for distributed execution.', + message) + self.assertIn('non-serializable objects like file handles', message) + self.assertIn( + 'Try: (1) using module-level functions instead of lambdas', message) + def test_do_with_do_fn_returning_string_raises_warning(self): ex_details = r'.*Returning a str from a ParDo or FlatMap is discouraged.'