Skip to content

Commit 9e0b9fd

Browse files
committed
fix TaskBuilder.TryFinally -- fixes #348
Fixes the uses of `use` keyword in `TaskBuilder` and other constructs that use `TryFinally`. Before, `TryFinally` called compensation (disposing) immediately. Now the compensation (disposing) is called only after the body runs. fixes #348
1 parent 7525861 commit 9e0b9fd

File tree

2 files changed

+105
-53
lines changed

2 files changed

+105
-53
lines changed

src/FSharpx.Extras/ComputationExpressions/Monad.fs

Lines changed: 101 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,95 +1024,145 @@ module Task =
10241024
List.foldBack cons s (returnM [])
10251025

10261026
let inline mapM f x = sequence (List.map f x)
1027-
1027+
1028+
10281029
type TaskBuilder(?continuationOptions, ?scheduler, ?cancellationToken) =
10291030
let contOptions = defaultArg continuationOptions TaskContinuationOptions.None
10301031
let scheduler = defaultArg scheduler TaskScheduler.Default
10311032
let cancellationToken = defaultArg cancellationToken CancellationToken.None
10321033

10331034
member this.Return x = returnM x
10341035

1035-
member this.Zero() = returnM ()
1036+
member this.Bind(m, f) = bindWithOptions cancellationToken contOptions scheduler f m
1037+
1038+
member this.Zero() : Task<unit> = this.Return ()
10361039

1037-
member this.ReturnFrom (a: Task<'T>) = a
1040+
member this.ReturnFrom (a: Task<'a>) = a
10381041

1039-
member this.Bind(m, f) = bindWithOptions cancellationToken contOptions scheduler f m
1042+
member this.Run (body : unit -> Task<'a>) = body()
10401043

1041-
member this.Combine(comp1, comp2) =
1042-
this.Bind(comp1, comp2)
1044+
member this.Delay (body : unit -> Task<'a>) : unit -> Task<'a> = fun () -> this.Bind(this.Return(), body)
10431045

1044-
member this.While(guard, m) =
1045-
if not(guard()) then this.Zero() else
1046-
this.Bind(m(), fun () -> this.While(guard, m))
1047-
1048-
member this.TryWith(body:unit -> Task<_>, catchFn:exn -> Task<_>) =
1049-
try
1050-
body()
1051-
.ContinueWith(fun (t:Task<_>) ->
1052-
match t.IsFaulted with
1053-
| false -> returnM(t.Result)
1054-
| true -> catchFn(t.Exception.GetBaseException()))
1055-
.Unwrap()
1046+
member this.Combine(t1:Task<unit>, t2 : unit -> Task<'b>) : Task<'b> = this.Bind(t1, t2)
1047+
1048+
member this.While(guard, body : unit -> Task<unit>) : Task<unit> =
1049+
if not(guard())
1050+
then this.Zero()
1051+
else this.Bind(body(), fun () -> this.While(guard, body))
1052+
1053+
member this.TryWith(body : unit -> Task<'a>, catchFn:exn -> Task<'a>) : Task<'a> =
1054+
let continuation (t:Task<'a>) : Task<'a> =
1055+
if t.IsFaulted
1056+
then catchFn(t.Exception.GetBaseException())
1057+
else this.Return(t.Result)
1058+
1059+
try body().ContinueWith(continuation).Unwrap()
10561060
with e -> catchFn(e)
10571061

1058-
member this.TryFinally(m, compensation) =
1059-
try this.ReturnFrom m
1060-
finally compensation()
1062+
member this.TryFinally(body : unit -> Task<'a>, compensation) : Task<'a> =
1063+
let wrapOk (x:'a) : Task<'a> =
1064+
compensation()
1065+
this.Return x
10611066

1062-
member this.Using(res: #IDisposable, body: #IDisposable -> Task<_>) =
1063-
this.TryFinally(body res, fun () -> match res with null -> () | disp -> disp.Dispose())
1067+
let wrapCrash (e:exn) : Task<'a> =
1068+
compensation()
1069+
raise e
10641070

1065-
member this.For(sequence: seq<_>, body) =
1066-
this.Using(sequence.GetEnumerator(),
1067-
fun enum -> this.While(enum.MoveNext, fun () -> body enum.Current))
1071+
this.Bind(this.TryWith(body, wrapCrash), wrapOk)
1072+
1073+
member this.Using(res:#IDisposable, body : #IDisposable -> Task<'a>) : Task<'a> =
1074+
let compensation() =
1075+
match res with
1076+
| null -> ()
1077+
| disp -> disp.Dispose()
10681078

1069-
member this.Delay (f: unit -> Task<'T>) = f
1079+
this.TryFinally((fun () -> body res), compensation)
1080+
1081+
member this.For(sequence:seq<'a>, body : 'a -> Task<unit>) : Task<unit> =
1082+
this.Using( sequence.GetEnumerator()
1083+
, fun enum -> this.While( enum.MoveNext
1084+
, fun () -> body enum.Current
1085+
)
1086+
)
10701087

1071-
member this.Run (f: unit -> Task<'T>) = f()
10721088

10731089
let task = TaskBuilder()
10741090

1091+
type TokenToTask<'a> = CancellationToken -> Task<'a>
10751092
type TaskBuilderWithToken(?continuationOptions, ?scheduler) =
10761093
let contOptions = defaultArg continuationOptions TaskContinuationOptions.None
10771094
let scheduler = defaultArg scheduler TaskScheduler.Default
10781095

10791096
let lift (t: Task<_>) = fun (_: CancellationToken) -> t
1080-
let bind (t: CancellationToken -> Task<'T>) (f: 'T -> (CancellationToken -> Task<'U>)) =
1097+
1098+
let bind (t:TokenToTask<'a>) (f : 'a -> TokenToTask<'b>) =
10811099
fun (token: CancellationToken) ->
1082-
(t token).ContinueWith((fun (x: Task<_>) -> f x.Result token), token, contOptions, scheduler).Unwrap()
1100+
(t token).ContinueWith( fun (x: Task<_>) -> f x.Result token
1101+
, token
1102+
, contOptions
1103+
, scheduler
1104+
)
1105+
.Unwrap()
10831106

10841107
member this.Return x = lift (returnM x)
10851108

1109+
member this.Bind(t, f) = bind t f
1110+
1111+
member this.Bind(t, f) = bind (lift t) f
1112+
10861113
member this.ReturnFrom t = lift t
10871114

1088-
member this.ReturnFrom (t: CancellationToken -> Task<'T>) = t
1115+
member this.ReturnFrom (t:TokenToTask<'a>) = t
10891116

1090-
member this.Zero() = this.Return ()
1117+
member this.Zero() : TokenToTask<unit> = this.Return ()
10911118

1092-
member this.Bind(t, f) = bind t f
1119+
member this.Run (body : unit -> TokenToTask<'a>) = body()
10931120

1094-
member this.Bind(t, f) = bind (lift t) f
1121+
member this.Delay (body : unit -> TokenToTask<'a>) : unit -> TokenToTask<'a> = fun () -> this.Bind(this.Return(), body)
10951122

1096-
member this.Combine(t1, t2) = bind t1 (konst t2)
1123+
member this.Combine(t1 : TokenToTask<unit>, t2 : unit -> TokenToTask<'b>) : TokenToTask<'b> = this.Bind(t1, t2)
10971124

1098-
member this.While(guard, m) =
1099-
if not(guard()) then
1100-
this.Zero()
1101-
else
1102-
bind m (fun () -> this.While(guard, m))
1125+
member this.While(guard, body : unit -> TokenToTask<unit>) : TokenToTask<unit> =
1126+
if not(guard())
1127+
then this.Zero()
1128+
else this.Bind(body(), fun () -> this.While(guard, body))
11031129

1104-
member this.TryFinally(t : CancellationToken -> Task<'T>, compensation) =
1105-
try t
1106-
finally compensation()
1130+
member this.TryWith(body : unit -> TokenToTask<'a>, catchFn : exn -> TokenToTask<'a>) : TokenToTask<'a> = fun token ->
1131+
let continuation (t:Task<'a>) : Task<'a> =
1132+
if t.IsFaulted
1133+
then catchFn(t.Exception.GetBaseException())
1134+
else this.Return(t.Result)
1135+
<| token
11071136

1108-
member this.Using(res: #IDisposable, body: #IDisposable -> (CancellationToken -> Task<'T>)) =
1109-
this.TryFinally(body res, fun () -> match res with null -> () | disp -> disp.Dispose())
1137+
try (body() token).ContinueWith(continuation).Unwrap()
1138+
with e -> catchFn(e) token
1139+
1140+
member this.TryFinally(body : unit -> TokenToTask<'a>, compensation) : TokenToTask<'a> =
1141+
let wrapOk (x:'a) : TokenToTask<'a> =
1142+
compensation()
1143+
this.Return x
1144+
1145+
let wrapCrash (e:exn) : TokenToTask<'a> =
1146+
compensation()
1147+
raise e
1148+
1149+
this.Bind(this.TryWith(body, wrapCrash), wrapOk)
1150+
1151+
member this.Using(res:#IDisposable, body : #IDisposable -> TokenToTask<'a>) : TokenToTask<'a> =
1152+
let compensation() =
1153+
match res with
1154+
| null -> ()
1155+
| disp -> disp.Dispose()
1156+
1157+
this.TryFinally((fun () -> body res), compensation)
1158+
1159+
member this.For(sequence:seq<'a>, body : 'a -> TokenToTask<unit>) : TokenToTask<unit> =
1160+
this.Using( sequence.GetEnumerator()
1161+
, fun enum -> this.While( enum.MoveNext
1162+
, fun () -> body enum.Current
1163+
)
1164+
)
11101165

1111-
member this.For(sequence: seq<'T>, body) =
1112-
this.Using(sequence.GetEnumerator(),
1113-
fun enum -> this.While(enum.MoveNext, fun token -> body enum.Current token))
1114-
1115-
member this.Delay f = this.Bind(this.Return (), f)
11161166

11171167
/// Converts a Task into Task<unit>
11181168
let ToTaskUnit (t:Task) =
@@ -1155,4 +1205,4 @@ module Task =
11551205
}
11561206
tasks
11571207
|> Seq.map throttleTask
1158-
|> Parallel
1208+
|> Parallel

tests/FSharpx.Tests/TaskTests.fs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ let ``exception in task``() =
9191
failwith "error"
9292
}
9393
match Task.run t with
94-
| Task.Error e -> Assert.AreEqual("error", e.Message)
94+
| Task.Error e -> Assert.AreEqual("error", e.InnerException.Message)
9595
| _ -> Assert.Fail "task should have errored"
9696

9797
[<Test>]
@@ -197,5 +197,7 @@ type TaskGen =
197197
let ``run delay law``() =
198198
Arb.register<TaskGen>() |> ignore
199199
let task = Task.TaskBuilder(continuationOptions = TaskContinuationOptions.ExecuteSynchronously)
200-
fsCheck "run delay law" (fun a -> (task.Run << task.Delay << konst) a = a)
200+
let delay = konst >> task.Delay >> task.Run
201+
let run (transform : _ -> Task<_>) t = (transform t).Result
201202

203+
fsCheck "run delay law" (fun t -> run id t = run delay t)

0 commit comments

Comments
 (0)