diff --git a/src/Hedgehog.Experimental/Gen.fs b/src/Hedgehog.Experimental/Gen.fs index 0246b12..a85ba7b 100644 --- a/src/Hedgehog.Experimental/Gen.fs +++ b/src/Hedgehog.Experimental/Gen.fs @@ -57,35 +57,33 @@ module AutoGenConfig = /// The type is expected to have static methods that return Gen<'a>. /// These methods can have parameters which are required to be of type Gen<_>. let addGenerators<'a> (config: AutoGenConfig) = - let isGen (t: Type) = t.IsGenericType && t.GetGenericTypeDefinition() = typedefof> - - // Ensure that all the parameters are of type Gen<_>, and return the unwrapped types. - let unwrapGenParameters (methodInfo: MethodInfo) : Type[] = - methodInfo.GetParameters() - |> Array.map (fun param -> - if isGen param.ParameterType then - param.ParameterType.GetGenericArguments()[0] - else - failwithf "Method %s.%s has a parameter '%s' which is not of type Gen<...>" - methodInfo.DeclaringType.Name - methodInfo.Name - param.Name) + let isGen (t: Type) = + t.IsGenericType && t.GetGenericTypeDefinition() = typedefof> + + let tryUnwrapGenParameters (methodInfo: MethodInfo) : Option = + methodInfo.GetParameters() + |> Array.fold (fun acc param -> + match acc, isGen param.ParameterType with + | Some types, true -> + Some (Array.append types [| param.ParameterType.GetGenericArguments().[0] |]) + | _ -> None + ) (Some [||]) - // find all the static methods that return Gen<'a> - let methods = typeof<'a>.GetMethods(BindingFlags.Static ||| BindingFlags.Public) - |> Seq.filter (fun m -> isGen m.ReturnType) - - // Register these methods as generator factories - methods - |> Seq.fold (fun cfg methodInfo -> - let targetType = methodInfo.ReturnType.GetGenericArguments()[0] - let typeArray = unwrapGenParameters methodInfo - let factory: Type[] -> obj[] -> obj = fun types gens -> - let methodToCall = if Array.isEmpty types then methodInfo else methodInfo.MakeGenericMethod(types) - methodToCall.Invoke(null, gens) - cfg |> mapGenerators (GeneratorCollection.addGenerator targetType typeArray factory) - ) config + |> Seq.choose (fun methodInfo -> + match isGen methodInfo.ReturnType, tryUnwrapGenParameters methodInfo with + | true, Some typeArray -> + let targetType = methodInfo.ReturnType.GetGenericArguments().[0] + let factory: Type[] -> obj[] -> obj = fun types gens -> + let methodToCall = + if Array.isEmpty types then methodInfo + else methodInfo.MakeGenericMethod(types) + methodToCall.Invoke(null, gens) + Some (targetType, typeArray, factory) + | _ -> None) + |> Seq.fold (fun cfg (targetType, typeArray, factory) -> + cfg |> mapGenerators (GeneratorCollection.addGenerator targetType typeArray factory)) + config module GenX =