diff --git a/src/Function.cs b/src/Function.cs index 9308110..99b01e9 100644 --- a/src/Function.cs +++ b/src/Function.cs @@ -554,6 +554,22 @@ public static Function FromCallbackReturns true if the type signature of the function is valid or false if not. public bool CheckTypeSignature(Type? returnType = null, params Type[] parameters) { + // Check if the return type is Result, Result, ResultWithBacktrace or ResultWithBacktrace + if (returnType != null && returnType.IsResult()) + { + // Try to get the type the result wraps (may be null if it's one of the non-generic result types) + var wrappedReturnType = returnType.IsGenericType ? returnType.GetGenericArguments()[0] : null; + + // Check that the result does not attempt to wrap another result (e.g. Result>) + if (wrappedReturnType != null && wrappedReturnType.IsResult()) + { + return false; + } + + // Type check with the wrapped value instead of the result + return CheckTypeSignature(wrappedReturnType, parameters); + } + // Check if the func returns no values if that's expected if (Results.Count == 0 && returnType != null) { @@ -2007,10 +2023,10 @@ private unsafe TR InvokeWithReturn(ReadOnlySpan arguments, IReturnTyp try { - Invoke(arguments, output); + var trap = Invoke(arguments, output); // Note: null suppression is safe because `Invoke` checks that `store` is not null - return factory.Create(store!, output); + return factory.Create(store!, trap, output); } finally { @@ -2037,7 +2053,11 @@ private unsafe void InvokeWithoutReturn(ReadOnlySpan arguments) { try { - Invoke(arguments, stackalloc Value[0]); + var trap = Invoke(arguments, stackalloc Value[0]); + if (trap != IntPtr.Zero) + { + throw TrapException.FromOwnedTrap(trap); + } } finally { @@ -2103,19 +2123,23 @@ private unsafe void InvokeWithoutReturn(ReadOnlySpan arguments) throw new ArgumentNullException(nameof(store)); } - var context = store.Context; - // Convert arguments (ValueBox) into a form wasm can consume (Value) Span args = stackalloc Value[Parameters.Count]; - for (int i = 0; i < arguments.Length; ++i) + for (var i = 0; i < arguments.Length; ++i) + { args[i] = arguments[i].ToValue(Parameters[i]); + } // Make some space to store the return results Span resultsSpan = stackalloc Value[Results.Count]; try { - Invoke(args, resultsSpan); + var trap = Invoke(args, resultsSpan); + if (trap != IntPtr.Zero) + { + throw TrapException.FromOwnedTrap(trap); + } if (Results.Count == 0) { @@ -2161,9 +2185,9 @@ private unsafe void InvokeWithoutReturn(ReadOnlySpan arguments) /// The arguments to pass to the function, wrapped as `Value` /// Output span to store the results in, must be the correct length /// - /// Returns null if the function has no return value. + /// Returns the trap ptr or zero /// - private unsafe void Invoke(ReadOnlySpan arguments, Span resultsOut) + private unsafe IntPtr Invoke(ReadOnlySpan arguments, Span resultsOut) { if (IsNull) { @@ -2190,10 +2214,7 @@ private unsafe void Invoke(ReadOnlySpan arguments, Span resultsOut throw WasmtimeException.FromOwnedError(error); } - if (trap != IntPtr.Zero) - { - throw TrapException.FromOwnedTrap(trap); - } + return trap; } Extern IExternal.AsExtern() diff --git a/src/Result.cs b/src/Result.cs new file mode 100644 index 0000000..050c182 --- /dev/null +++ b/src/Result.cs @@ -0,0 +1,287 @@ +using System; + +namespace Wasmtime +{ + /// + /// Indicates what type of result this is + /// + public enum ResultType + { + /// + /// Excecution succeeded + /// + Ok = 0, + + /// + /// Result contains a trap + /// + Trap = 1, + } + + /// + /// A result from a function call which may represent a Value or a Trap. If a trap happens the full backtrace is captured. + /// + public readonly struct ResultWithBacktrace + { + /// + /// Indicates what type of result this contains + /// + public ResultType Type { get; } + + private readonly TrapException? _trap; + + internal ResultWithBacktrace(IntPtr trap) + { + Type = ResultType.Trap; + _trap = TrapException.FromOwnedTrap(trap); + } + + /// + /// Get the trap associated with this result + /// + /// Thrown if this Type != Types.Trap + public TrapException Trap + { + get + { + if (Type != ResultType.Trap) + { + throw new InvalidOperationException($"Cannot get 'Trap' from '{Type}' type result"); + } + + return _trap!; + } + } + } + + /// + /// A result from a function call which may represent a Value or a Trap + /// + public readonly struct Result + { + /// + /// Indicates what type of result this contains + /// + public ResultType Type { get; } + + private readonly TrapCode _trap; + + internal Result(IntPtr trap) + { + Type = ResultType.Trap; + _trap = TrapException.GetTrapCode(trap); + TrapException.Native.wasm_trap_delete(trap); + } + + /// + /// Get the trap associated with this result + /// + /// Thrown if this Type != Types.Trap + public TrapCode TrapCode + { + get + { + if (Type != ResultType.Trap) + { + throw new InvalidOperationException($"Cannot get 'Trap' from '{Type}' type result"); + } + + return _trap; + } + } + } + + /// + /// A result from a function call which may represent a Value or a Trap. If a trap happens the full backtrace is captured. + /// + /// Type of the return value contained in this result + public readonly struct ResultWithBacktrace + { + /// + /// Indicates what type of result this contains + /// + public ResultType Type { get; } + + private readonly T? _value; + private readonly TrapException? _trap; + + internal ResultWithBacktrace(T value) + { + Type = ResultType.Ok; + _value = value; + _trap = null; + } + + internal ResultWithBacktrace(IntPtr trap) + { + Type = ResultType.Trap; + _value = default; + _trap = TrapException.FromOwnedTrap(trap); + } + + /// + /// Convert this result into a value, throw if it is a Trap + /// + /// + /// + /// Thrown if Type == Trap + /// Thrown if Type property contains an unknown value + public static explicit operator T?(ResultWithBacktrace value) + { + switch (value.Type) + { + case ResultType.Ok: + return value._value; + + case ResultType.Trap: + throw value._trap!; + + default: + throw new ArgumentOutOfRangeException(nameof(value), $"Unknown Result Type property value '{value.Type}'"); + } + } + + /// + /// Get the value associated with this result + /// + /// Thrown if this Type != Types.Value + public T? Value + { + get + { + if (Type != ResultType.Ok) + { + throw new InvalidOperationException($"Cannot get 'Value' from '{Type}' type result"); + } + + return _value; + } + } + + /// + /// Get the trap associated with this result + /// + /// Thrown if this Type != Types.Trap + public TrapException Trap + { + get + { + if (Type != ResultType.Trap) + { + throw new InvalidOperationException($"Cannot get 'Trap' from '{Type}' type result"); + } + + return _trap!; + } + } + } + + /// + /// A result from a function call which may represent a Value or a Trap + /// + /// Type of the return value contained in this result + public readonly struct Result + { + /// + /// Indicates what type of result this contains + /// + public ResultType Type { get; } + + private readonly T? _value; + private readonly TrapCode _trap; + + internal Result(T value) + { + Type = ResultType.Ok; + + _value = value; + _trap = default; + } + + internal Result(IntPtr trap) + { + Type = ResultType.Trap; + + _value = default; + + _trap = TrapException.GetTrapCode(trap); + TrapException.Native.wasm_trap_delete(trap); + + } + + /// + /// Convert this result into a value, throw if it is a Trap + /// + /// + /// The value contained within this result if Type == ResultType.Ok + /// Thrown if Type == Trap + /// Thrown if Type property contains an unknoown value + public static explicit operator T?(Result value) + { + switch (value.Type) + { + case ResultType.Ok: + return value._value; + + case ResultType.Trap: + throw new TrapException($"{value._trap} trap", null, value._trap); + + default: + throw new ArgumentOutOfRangeException(nameof(value), $"Unknown Result Type property value '{value.Type}'"); + } + } + + /// + /// Get the value associated with this result + /// + /// Thrown if this Type != Types.Value + public T? Value + { + get + { + if (Type != ResultType.Ok) + { + throw new InvalidOperationException($"Cannot get 'Value' from '{Type}' type result"); + } + + return _value; + } + } + + /// + /// Get the trap associated with this result + /// + /// Thrown if this Type != Types.Trap + public TrapCode TrapCode + { + get + { + if (Type != ResultType.Trap) + { + throw new InvalidOperationException($"Cannot get 'Trap' from '{Type}' type result"); + } + + return _trap; + } + } + } + + internal static class TypeExtensions + { + internal static bool IsResult(this Type type) + { + if (type == typeof(Result) || type == typeof(ResultWithBacktrace)) + { + return true; + } + + if (!type.IsGenericType) + { + return false; + } + + var gtd = type.GetGenericTypeDefinition(); + return typeof(Result<>) == gtd || typeof(ResultWithBacktrace<>) == gtd; + } + } +} diff --git a/src/ReturnTypeFactory.cs b/src/ReturnTypeFactory.cs index 04d26b9..0b71b01 100644 --- a/src/ReturnTypeFactory.cs +++ b/src/ReturnTypeFactory.cs @@ -8,26 +8,54 @@ namespace Wasmtime { interface IReturnTypeFactory { - TReturn Create(IStore store, Span values); + TReturn Create(IStore store, IntPtr trap, Span values); static IReturnTypeFactory Create() { - var types = GetTupleTypes().ToList(); - - if (types.Count == 1) + // First, check if the value is one of the 4 result wrappers + if (typeof(TReturn).IsResult()) { - return new NonTupleTypeFactory(); + if (typeof(TReturn).IsGenericType) + { + var wrapperType = typeof(TReturn).GetGenericTypeDefinition(); + var wrappedType = typeof(TReturn).GetGenericArguments()[0]; + + var factoryType = wrapperType == typeof(Result<>) + ? typeof(ResultTFactory<>) + : typeof(ResultWithBacktraceTFactory<>); + + return (IReturnTypeFactory)Activator.CreateInstance(factoryType.MakeGenericType(wrappedType))!; + } + else + { + var wrapperType = typeof(TReturn); + + var factoryType = wrapperType == typeof(Result) + ? typeof(ResultFactory) + : typeof(ResultWithBacktraceFactory); + + return (IReturnTypeFactory)Activator.CreateInstance(factoryType)!; + } } + else + { + var types = GetTupleTypes().ToList(); - // All of the factories take parameters: - // Add TupleType to the start of the list - types.Insert(0, typeof(TReturn)); + if (types.Count == 1) + { + return new NonTupleTypeFactory(); + } - Type factoryType = GetFactoryType(types.Count - 1); - return (IReturnTypeFactory)Activator.CreateInstance(factoryType.MakeGenericType(types.ToArray()))!; + // All of the factories take parameters: + // Add TupleType to the start of the list + types.Insert(0, typeof(TReturn)); + + Type factoryType = GetTupleFactoryType(types.Count - 1); + return (IReturnTypeFactory)Activator.CreateInstance(factoryType.MakeGenericType(types.ToArray()))!; + } } - protected static Type GetFactoryType(int arity) + private static Type GetTupleFactoryType(int arity) { return arity switch { @@ -52,14 +80,83 @@ private static IReadOnlyList GetTupleTypes() return new[] { typeof(TReturn) }; } } + } - protected static MethodInfo GetCreateMethodInfo(int arity) + internal class ResultFactory + : IReturnTypeFactory + { + public Result Create(IStore store, IntPtr trap, Span values) { - return typeof(ValueTuple) - .GetMethods(BindingFlags.Public | BindingFlags.Static) - .Where(a => a.Name == "Create") - .Where(a => a.ContainsGenericParameters && a.IsGenericMethod) - .First(a => a.GetGenericArguments().Length == arity); + if (trap != IntPtr.Zero) + { + return new Result(trap); + } + else + { + return new Result(); + } + } + } + + internal class ResultTFactory + : IReturnTypeFactory> + { + private readonly IReturnTypeFactory _factory; + + public ResultTFactory() + { + _factory = IReturnTypeFactory.Create(); + } + + public Result Create(IStore store, IntPtr trap, Span values) + { + if (trap != IntPtr.Zero) + { + return new Result(trap); + } + else + { + return new Result(_factory.Create(store, trap, values)); + } + } + } + + internal class ResultWithBacktraceFactory + : IReturnTypeFactory + { + public ResultWithBacktrace Create(IStore store, IntPtr trap, Span values) + { + if (trap != IntPtr.Zero) + { + return new ResultWithBacktrace(trap); + } + else + { + return new ResultWithBacktrace(); + } + } + } + + internal class ResultWithBacktraceTFactory + : IReturnTypeFactory> + { + private readonly IReturnTypeFactory _factory; + + public ResultWithBacktraceTFactory() + { + _factory = IReturnTypeFactory.Create(); + } + + public ResultWithBacktrace Create(IStore store, IntPtr trap, Span values) + { + if (trap != IntPtr.Zero) + { + return new ResultWithBacktrace(trap); + } + else + { + return new ResultWithBacktrace(_factory.Create(store, trap, values)); + } } } @@ -73,8 +170,13 @@ public NonTupleTypeFactory() converter = ValueBox.Converter(); } - public TReturn Create(IStore store, Span values) + public TReturn Create(IStore store, IntPtr trap, Span values) { + if (trap != IntPtr.Zero) + { + throw TrapException.FromOwnedTrap(trap); + } + return converter.Unbox(store, values[0].ToValueBox()); } } @@ -90,12 +192,21 @@ protected BaseTupleFactory() // Get all the generic arguments of TFunc. All of the Parameters, followed by the return type var args = typeof(TFunc).GetGenericArguments(); - Factory = (TFunc)IReturnTypeFactory.GetCreateMethodInfo(args.Length - 1) + Factory = (TFunc)GetCreateMethodInfo(args.Length - 1) .MakeGenericMethod(args[..^1]) .CreateDelegate(typeof(TFunc)); } - public abstract TReturn Create(IStore store, Span values); + protected static MethodInfo GetCreateMethodInfo(int arity) + { + return typeof(ValueTuple) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .Where(a => a.Name == "Create") + .Where(a => a.ContainsGenericParameters && a.IsGenericMethod) + .First(a => a.GetGenericArguments().Length == arity); + } + + public abstract TReturn Create(IStore store, IntPtr trap, Span values); } internal class TupleFactory2 @@ -110,8 +221,13 @@ public TupleFactory2() converterB = ValueBox.Converter(); } - public override TReturn Create(IStore store, Span values) + public override TReturn Create(IStore store, IntPtr trap, Span values) { + if (trap != IntPtr.Zero) + { + throw TrapException.FromOwnedTrap(trap); + } + return Factory( converterA.Unbox(store, values[0].ToValueBox()), converterB.Unbox(store, values[1].ToValueBox()) @@ -133,8 +249,13 @@ public TupleFactory3() converterC = ValueBox.Converter(); } - public override TReturn Create(IStore store, Span values) + public override TReturn Create(IStore store, IntPtr trap, Span values) { + if (trap != IntPtr.Zero) + { + throw TrapException.FromOwnedTrap(trap); + } + return Factory( converterA.Unbox(store, values[0].ToValueBox()), converterB.Unbox(store, values[1].ToValueBox()), @@ -159,8 +280,13 @@ public TupleFactory4() converterD = ValueBox.Converter(); } - public override TReturn Create(IStore store, Span values) + public override TReturn Create(IStore store, IntPtr trap, Span values) { + if (trap != IntPtr.Zero) + { + throw TrapException.FromOwnedTrap(trap); + } + return Factory( converterA.Unbox(store, values[0].ToValueBox()), converterB.Unbox(store, values[1].ToValueBox()), @@ -188,8 +314,13 @@ public TupleFactory5() converterE = ValueBox.Converter(); } - public override TReturn Create(IStore store, Span values) + public override TReturn Create(IStore store, IntPtr trap, Span values) { + if (trap != IntPtr.Zero) + { + throw TrapException.FromOwnedTrap(trap); + } + return Factory( converterA.Unbox(store, values[0].ToValueBox()), converterB.Unbox(store, values[1].ToValueBox()), @@ -220,8 +351,13 @@ public TupleFactory6() converterF = ValueBox.Converter(); } - public override TReturn Create(IStore store, Span values) + public override TReturn Create(IStore store, IntPtr trap, Span values) { + if (trap != IntPtr.Zero) + { + throw TrapException.FromOwnedTrap(trap); + } + return Factory( converterA.Unbox(store, values[0].ToValueBox()), converterB.Unbox(store, values[1].ToValueBox()), @@ -255,8 +391,13 @@ public TupleFactory7() converterG = ValueBox.Converter(); } - public override TReturn Create(IStore store, Span values) + public override TReturn Create(IStore store, IntPtr trap, Span values) { + if (trap != IntPtr.Zero) + { + throw TrapException.FromOwnedTrap(trap); + } + return Factory( converterA.Unbox(store, values[0].ToValueBox()), converterB.Unbox(store, values[1].ToValueBox()), diff --git a/src/TrapException.cs b/src/TrapException.cs index ac910e7..abaa3d6 100644 --- a/src/TrapException.cs +++ b/src/TrapException.cs @@ -47,9 +47,9 @@ public class TrapFrame { unsafe internal TrapFrame(IntPtr frame) { - FunctionOffset = (int)Native.wasm_frame_func_offset(frame); + FunctionOffset = Native.wasm_frame_func_offset(frame); FunctionName = null; - ModuleOffset = (int)Native.wasm_frame_module_offset(frame); + ModuleOffset = Native.wasm_frame_module_offset(frame); ModuleName = null; var bytes = Native.wasmtime_frame_func_name(frame); @@ -68,7 +68,7 @@ unsafe internal TrapFrame(IntPtr frame) /// /// Gets the frame's byte offset from the start of the function. /// - public int FunctionOffset { get; private set; } + public nuint FunctionOffset { get; private set; } /// /// Gets the frame's function name. @@ -78,7 +78,7 @@ unsafe internal TrapFrame(IntPtr frame) /// /// Gets the frame's module offset from the start of the module. /// - public int ModuleOffset { get; private set; } + public nuint ModuleOffset { get; private set; } /// /// Gets the frame's module name. @@ -128,14 +128,20 @@ public TrapException(string message, Exception inner) : base(message, inner) { } /// public int? ExitCode { get; private set; } + /// + /// Indentifies which type of trap this is. + /// + public TrapCode Type { get; private set; } + /// protected TrapException(SerializationInfo info, StreamingContext context) : base(info, context) { } - private TrapException(string message, IReadOnlyList frames) : base(message) + internal TrapException(string message, IReadOnlyList? frames, TrapCode type) : base(message) { + Type = type; Frames = frames; } - private static TrapCode GetTrapCode(IntPtr trap) + internal static TrapCode GetTrapCode(IntPtr trap) { if (Native.wasmtime_trap_code(trap, out TrapCode code)) { @@ -176,7 +182,7 @@ internal static TrapException FromOwnedTrap(IntPtr trap) Native.wasm_trap_delete(trap); - var trappedException = new TrapException(message, trapFrames); + var trappedException = new TrapException(message, trapFrames, trapCode); if (trappedExit) { trappedException.ExitCode = exitStatus; @@ -186,7 +192,7 @@ internal static TrapException FromOwnedTrap(IntPtr trap) } } - private static class Native + internal static class Native { [StructLayout(LayoutKind.Sequential)] public unsafe struct FrameArray : IDisposable diff --git a/tests/Modules/Trap.wat b/tests/Modules/Trap.wat index f3015bc..6310eed 100644 --- a/tests/Modules/Trap.wat +++ b/tests/Modules/Trap.wat @@ -1,5 +1,8 @@ (module (export "run" (func $run)) + (export "run_div_zero" (func $run_div_zero)) + (export "run_div_zero_with_result" (func $run_div_zero_with_result)) + (func $run (call $first) ) @@ -12,4 +15,15 @@ (func $third unreachable ) + + (func $run_div_zero_with_result (result i32) + (i32.const 1) + (i32.const 0) + (i32.div_s) + ) + + (func $run_div_zero + (call $run_div_zero_with_result) + (drop) + ) ) diff --git a/tests/TrapTests.cs b/tests/TrapTests.cs index 5785ef6..271fb5c 100644 --- a/tests/TrapTests.cs +++ b/tests/TrapTests.cs @@ -46,6 +46,59 @@ public void ItIncludesAStackTrace() .WithMessage("wasm trap: wasm `unreachable` instruction executed*"); } + [Fact] + public void ItReturnsATrapCodeResult() + { + + var instance = Linker.Instantiate(Store, Fixture.Module); + var run = instance.GetFunction("run_div_zero"); + var result = run(); + + result.Type.Should().Be(ResultType.Trap); + result.TrapCode.Should().Be(TrapCode.IntegerDivisionByZero); + } + + [Fact] + public void ItReturnsATrapCodeAndBacktraceResult() + { + + var instance = Linker.Instantiate(Store, Fixture.Module); + var run = instance.GetFunction("run_div_zero"); + var result = run(); + + result.Type.Should().Be(ResultType.Trap); + result.Trap.Type.Should().Be(TrapCode.IntegerDivisionByZero); + result.Trap.Frames.Count.Should().Be(2); + result.Trap.Frames[0].FunctionName.Should().Be("run_div_zero_with_result"); + result.Trap.Frames[1].FunctionName.Should().Be("run_div_zero"); + } + + [Fact] + public void ItReturnsATrapCodeGenericResult() + { + + var instance = Linker.Instantiate(Store, Fixture.Module); + var run = instance.GetFunction>("run_div_zero_with_result"); + var result = run(); + + result.Type.Should().Be(ResultType.Trap); + result.TrapCode.Should().Be(TrapCode.IntegerDivisionByZero); + } + + [Fact] + public void ItReturnsATrapCodeAndBacktraceGenericResult() + { + + var instance = Linker.Instantiate(Store, Fixture.Module); + var run = instance.GetFunction>("run_div_zero_with_result"); + var result = run(); + + result.Type.Should().Be(ResultType.Trap); + result.Trap.Type.Should().Be(TrapCode.IntegerDivisionByZero); + result.Trap.Frames.Count.Should().Be(1); + result.Trap.Frames[0].FunctionName.Should().Be("run_div_zero_with_result"); + } + public void Dispose() { Store.Dispose();