Skip to content

Commit 5c2e186

Browse files
committed
implement While/Combine in TaskOption correctly
1 parent 56304db commit 5c2e186

File tree

4 files changed

+271
-31
lines changed

4 files changed

+271
-31
lines changed

src/FsToolkit.ErrorHandling.TaskResult/TaskOptionCE.fs

Lines changed: 169 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -212,18 +212,79 @@ type TaskOption<'T> = Task<'T option>
212212
type TaskOptionStateMachineData<'T> =
213213

214214
[<DefaultValue(false)>]
215-
val mutable Result: 'T option
215+
val mutable Result: 'T option voption
216216

217217
[<DefaultValue(false)>]
218218
val mutable MethodBuilder: AsyncTaskOptionMethodBuilder<'T>
219219

220+
member this.IsResultNone =
221+
match this.Result with
222+
| ValueNone -> false
223+
| ValueSome (None) -> true
224+
| ValueSome _ -> false
225+
226+
member this.IsTaskCompleted = this.MethodBuilder.Task.IsCompleted
227+
220228
and AsyncTaskOptionMethodBuilder<'TOverall> = AsyncTaskMethodBuilder<'TOverall option>
221229
and TaskOptionStateMachine<'TOverall> = ResumableStateMachine<TaskOptionStateMachineData<'TOverall>>
222230
and TaskOptionResumptionFunc<'TOverall> = ResumptionFunc<TaskOptionStateMachineData<'TOverall>>
223231
and TaskOptionResumptionDynamicInfo<'TOverall> = ResumptionDynamicInfo<TaskOptionStateMachineData<'TOverall>>
224232
and TaskOptionCode<'TOverall, 'T> = ResumableCode<TaskOptionStateMachineData<'TOverall>, 'T>
225233

226234

235+
module TaskOptionBuilderBase =
236+
237+
let rec WhileDynamic
238+
(
239+
sm: byref<TaskOptionStateMachine<_>>,
240+
condition: unit -> bool,
241+
body: TaskOptionCode<_, _>
242+
) : bool =
243+
if condition () then
244+
if body.Invoke(&sm) then
245+
if sm.Data.IsResultNone then
246+
// Set the result now to allow short-circuiting of the rest of the CE.
247+
// Run/RunDynamic will skip setting the result if it's already been set.
248+
// Combine/CombineDynamic will not continue if the result has been set.
249+
sm.Data.MethodBuilder.SetResult sm.Data.Result.Value
250+
true
251+
else
252+
WhileDynamic(&sm, condition, body)
253+
else
254+
let rf = sm.ResumptionDynamicInfo.ResumptionFunc
255+
256+
sm.ResumptionDynamicInfo.ResumptionFunc <-
257+
(TaskOptionResumptionFunc<_>(fun sm -> WhileBodyDynamicAux(&sm, condition, body, rf)))
258+
259+
false
260+
else
261+
true
262+
263+
and WhileBodyDynamicAux
264+
(
265+
sm: byref<TaskOptionStateMachine<_>>,
266+
condition: unit -> bool,
267+
body: TaskOptionCode<_, _>,
268+
rf: TaskOptionResumptionFunc<_>
269+
) : bool =
270+
if rf.Invoke(&sm) then
271+
if sm.Data.IsResultNone then
272+
// Set the result now to allow short-circuiting of the rest of the CE.
273+
// Run/RunDynamic will skip setting the result if it's already been set.
274+
// Combine/CombineDynamic will not continue if the result has been set.
275+
sm.Data.MethodBuilder.SetResult sm.Data.Result.Value
276+
true
277+
else
278+
WhileDynamic(&sm, condition, body)
279+
else
280+
let rf = sm.ResumptionDynamicInfo.ResumptionFunc
281+
282+
sm.ResumptionDynamicInfo.ResumptionFunc <-
283+
(TaskOptionResumptionFunc<_>(fun sm -> WhileBodyDynamicAux(&sm, condition, body, rf)))
284+
285+
false
286+
287+
227288
type TaskOptionBuilderBase() =
228289

229290
member inline _.Delay(generator: unit -> TaskOptionCode<'TOverall, 'T>) : TaskOptionCode<'TOverall, 'T> =
@@ -236,10 +297,42 @@ type TaskOptionBuilderBase() =
236297
member inline _.Return(value: 'T) : TaskOptionCode<'T, 'T> =
237298
TaskOptionCode<'T, _>
238299
(fun sm ->
239-
sm.Data.Result <- Some value
300+
sm.Data.Result <- ValueSome(Some value)
240301
true)
241302

242303

304+
305+
static member inline CombineDynamic
306+
(
307+
sm: byref<TaskOptionStateMachine<_>>,
308+
task1: TaskOptionCode<'TOverall, unit>,
309+
task2: TaskOptionCode<'TOverall, 'T>
310+
) : bool =
311+
let shouldContinue = task1.Invoke(&sm)
312+
313+
if sm.Data.IsTaskCompleted then
314+
true
315+
elif shouldContinue then
316+
task2.Invoke(&sm)
317+
else
318+
let rec resume (mf: TaskOptionResumptionFunc<_>) =
319+
TaskOptionResumptionFunc<_>
320+
(fun sm ->
321+
let shouldContinue = mf.Invoke(&sm)
322+
323+
if sm.Data.IsTaskCompleted then
324+
true
325+
elif shouldContinue then
326+
task2.Invoke(&sm)
327+
else
328+
sm.ResumptionDynamicInfo.ResumptionFunc <-
329+
(resume (sm.ResumptionDynamicInfo.ResumptionFunc))
330+
331+
false)
332+
333+
sm.ResumptionDynamicInfo.ResumptionFunc <- (resume (sm.ResumptionDynamicInfo.ResumptionFunc))
334+
false
335+
243336
/// Chains together a step with its following step.
244337
/// Note that this requires that the first step has no result.
245338
/// This prevents constructs like `task { return 1; return 2; }`.
@@ -248,15 +341,58 @@ type TaskOptionBuilderBase() =
248341
task1: TaskOptionCode<'TOverall, unit>,
249342
task2: TaskOptionCode<'TOverall, 'T>
250343
) : TaskOptionCode<'TOverall, 'T> =
251-
ResumableCode.Combine(task1, task2)
344+
345+
TaskOptionCode<'TOverall, 'T>
346+
(fun sm ->
347+
if __useResumableCode then
348+
//-- RESUMABLE CODE START
349+
// NOTE: The code for code1 may contain await points! Resuming may branch directly
350+
// into this code!
351+
// printfn "Combine Called Before Invoke --> "
352+
let __stack_fin = task1.Invoke(&sm)
353+
// printfn "Combine Called After Invoke --> %A " sm.Data.MethodBuilder.Task.Status
354+
355+
if sm.Data.IsTaskCompleted then true
356+
elif __stack_fin then task2.Invoke(&sm)
357+
else false
358+
else
359+
TaskOptionBuilderBase.CombineDynamic(&sm, task1, task2))
252360

253361
/// Builds a step that executes the body while the condition predicate is true.
254362
member inline _.While
255363
(
256364
[<InlineIfLambda>] condition: unit -> bool,
257365
body: TaskOptionCode<'TOverall, unit>
258366
) : TaskOptionCode<'TOverall, unit> =
259-
ResumableCode.While(condition, body)
367+
TaskOptionCode<'TOverall, unit>
368+
(fun sm ->
369+
if __useResumableCode then
370+
//-- RESUMABLE CODE START
371+
let mutable __stack_go = true
372+
373+
while __stack_go
374+
&& not sm.Data.IsResultNone
375+
&& condition () do
376+
// printfn "While -> %A" sm.Data.Result
377+
// NOTE: The body of the state machine code for 'while' may contain await points, so resuming
378+
// the code will branch directly into the expanded 'body', branching directly into the while loop
379+
let __stack_body_fin = body.Invoke(&sm)
380+
// printfn "While After Invoke --> %A" sm.Data.Result
381+
// If the body completed, we go back around the loop (__stack_go = true)
382+
// If the body yielded, we yield (__stack_go = false)
383+
__stack_go <- __stack_body_fin
384+
385+
if sm.Data.IsResultNone then
386+
// Set the result now to allow short-circuiting of the rest of the CE.
387+
// Run/RunDynamic will skip setting the result if it's already been set.
388+
// Combine/CombineDynamic will not continue if the result has been set.
389+
sm.Data.MethodBuilder.SetResult sm.Data.Result.Value
390+
391+
__stack_go
392+
//-- RESUMABLE CODE END
393+
else
394+
TaskOptionBuilderBase.WhileDynamic(&sm, condition, body))
395+
260396

261397
/// Wraps a step in a try/with. This catches exceptions both in the evaluation of the function
262398
/// to retrieve the step, and in the continuation of the step (if any).
@@ -349,8 +485,6 @@ type TaskOptionBuilderBase() =
349485
member inline this.Source(taskOption: ValueTask<'T option>) : TaskOption<'T> = task { return! taskOption }
350486

351487

352-
353-
354488
type TaskOptionBuilder() =
355489

356490
inherit TaskOptionBuilderBase()
@@ -376,8 +510,11 @@ type TaskOptionBuilder() =
376510
sm.ResumptionDynamicInfo.ResumptionData <- null
377511
let step = info.ResumptionFunc.Invoke(&sm)
378512

513+
// If the `sm.Data.MethodBuilder` has already been set somewhere else (like While/WhileDynamic), we shouldn't continue
514+
if sm.Data.IsTaskCompleted then ()
515+
379516
if step then
380-
sm.Data.MethodBuilder.SetResult(sm.Data.Result)
517+
sm.Data.MethodBuilder.SetResult(sm.Data.Result.Value)
381518
else
382519
let mutable awaiter =
383520
sm.ResumptionDynamicInfo.ResumptionData :?> ICriticalNotifyCompletion
@@ -412,8 +549,8 @@ type TaskOptionBuilder() =
412549
try
413550
let __stack_code_fin = code.Invoke(&sm)
414551

415-
if __stack_code_fin then
416-
sm.Data.MethodBuilder.SetResult(sm.Data.Result)
552+
if __stack_code_fin && not sm.Data.IsTaskCompleted then
553+
sm.Data.MethodBuilder.SetResult(sm.Data.Result.Value)
417554
with
418555
| exn -> __stack_exn <- exn
419556
// Run SetException outside the stack unwind, see https://github.com/dotnet/roslyn/issues/26567
@@ -459,8 +596,8 @@ type BackgroundTaskOptionBuilder() =
459596
try
460597
let __stack_code_fin = code.Invoke(&sm)
461598

462-
if __stack_code_fin then
463-
sm.Data.MethodBuilder.SetResult(sm.Data.Result)
599+
if __stack_code_fin && not sm.Data.IsTaskCompleted then
600+
sm.Data.MethodBuilder.SetResult(sm.Data.Result.Value)
464601
with
465602
| exn -> sm.Data.MethodBuilder.SetException exn
466603
//-- RESUMABLE CODE END
@@ -513,7 +650,7 @@ module TaskOptionCEExtensionsLowPriority =
513650
[<NoEagerConstraintApplication>]
514651
static member inline BindDynamic< ^TaskLike, 'TResult1, 'TResult2, ^Awaiter, 'TOverall when ^TaskLike: (member GetAwaiter :
515652
unit -> ^Awaiter) and ^Awaiter :> ICriticalNotifyCompletion and ^Awaiter: (member get_IsCompleted :
516-
unit -> bool) and ^Awaiter: (member GetResult : unit -> 'TResult1)>
653+
unit -> bool) and ^Awaiter: (member GetResult : unit -> 'TResult1 option)>
517654
(
518655
sm: byref<_>,
519656
task: ^TaskLike,
@@ -527,9 +664,13 @@ module TaskOptionCEExtensionsLowPriority =
527664
(TaskOptionResumptionFunc<'TOverall>
528665
(fun sm ->
529666
let result =
530-
(^Awaiter: (member GetResult : unit -> 'TResult1) (awaiter))
667+
(^Awaiter: (member GetResult : unit -> 'TResult1 option) (awaiter))
531668

532-
(continuation result).Invoke(&sm)))
669+
match result with
670+
| Some result -> (continuation result).Invoke(&sm)
671+
| None ->
672+
sm.Data.Result <- ValueSome None
673+
true))
533674

534675
// shortcut to continue immediately
535676
if (^Awaiter: (member get_IsCompleted : unit -> bool) (awaiter)) then
@@ -542,7 +683,7 @@ module TaskOptionCEExtensionsLowPriority =
542683
[<NoEagerConstraintApplication>]
543684
member inline _.Bind< ^TaskLike, 'TResult1, 'TResult2, ^Awaiter, 'TOverall when ^TaskLike: (member GetAwaiter :
544685
unit -> ^Awaiter) and ^Awaiter :> ICriticalNotifyCompletion and ^Awaiter: (member get_IsCompleted :
545-
unit -> bool) and ^Awaiter: (member GetResult : unit -> 'TResult1)>
686+
unit -> bool) and ^Awaiter: (member GetResult : unit -> 'TResult1 option)>
546687
(
547688
task: ^TaskLike,
548689
continuation: ('TResult1 -> TaskOptionCode<'TOverall, 'TResult2>)
@@ -566,9 +707,13 @@ module TaskOptionCEExtensionsLowPriority =
566707

567708
if __stack_fin then
568709
let result =
569-
(^Awaiter: (member GetResult : unit -> 'TResult1) (awaiter))
710+
(^Awaiter: (member GetResult : unit -> 'TResult1 option) (awaiter))
570711

571-
(continuation result).Invoke(&sm)
712+
match result with
713+
| Some result -> (continuation result).Invoke(&sm)
714+
| None ->
715+
sm.Data.Result <- ValueSome None
716+
true
572717
else
573718
sm.Data.MethodBuilder.AwaitUnsafeOnCompleted(&awaiter, &sm)
574719
false
@@ -583,7 +728,7 @@ module TaskOptionCEExtensionsLowPriority =
583728

584729
[<NoEagerConstraintApplication>]
585730
member inline this.ReturnFrom< ^TaskLike, ^Awaiter, 'T when ^TaskLike: (member GetAwaiter : unit -> ^Awaiter) and ^Awaiter :> ICriticalNotifyCompletion and ^Awaiter: (member get_IsCompleted :
586-
unit -> bool) and ^Awaiter: (member GetResult : unit -> 'T)>
731+
unit -> bool) and ^Awaiter: (member GetResult : unit -> 'T option)>
587732
(task: ^TaskLike)
588733
: TaskOptionCode<'T, 'T> =
589734

@@ -626,7 +771,9 @@ module TaskOptionCEExtensionsHighPriority =
626771

627772
match result with
628773
| Some result -> (continuation result).Invoke(&sm)
629-
| None -> true))
774+
| None ->
775+
sm.Data.Result <- ValueSome None
776+
true))
630777

631778
// shortcut to continue immediately
632779
if awaiter.IsCompleted then
@@ -662,7 +809,9 @@ module TaskOptionCEExtensionsHighPriority =
662809

663810
match result with
664811
| Some result -> (continuation result).Invoke(&sm)
665-
| None -> true
812+
| None ->
813+
sm.Data.Result <- ValueSome None
814+
true
666815

667816
else
668817
sm.Data.MethodBuilder.AwaitUnsafeOnCompleted(&awaiter, &sm)

src/FsToolkit.ErrorHandling.TaskResult/TaskResultCE.fs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,18 @@ module TaskResultCE =
134134

135135
let taskResult = TaskResultBuilder()
136136

137+
[<AutoOpen>]
138+
// Having members as extensions gives them lower priority in
139+
// overload resolution and allows skipping more type annotations.
140+
module TaskResultCEExtensionsLower =
141+
142+
type TaskResultBuilder with
143+
member inline this.Source(t: ^TaskLike) : Task<Result<'T, 'Error>> =
144+
task {
145+
let! r = t
146+
return Ok r
147+
}
148+
137149
// Having members as extensions gives them lower priority in
138150
// overload resolution between Task<_> and Task<Result<_,_>>.
139151
[<AutoOpen>]
@@ -305,10 +317,7 @@ type TaskResultBuilderBase() =
305317
TaskResultCode<'T, 'Error, _>
306318
(fun sm ->
307319
// printfn "Return Called --> "
308-
309-
match sm.Data.Result with
310-
| Ok _ -> sm.Data.Result <- Ok value
311-
| Error e -> ()
320+
sm.Data.Result <- Ok value
312321

313322
true)
314323

@@ -381,19 +390,19 @@ type TaskResultBuilderBase() =
381390
if __useResumableCode then
382391
//-- RESUMABLE CODE START
383392
let mutable __stack_go = true
384-
let mutable errored = false
385393

386-
while __stack_go && not errored && condition () do
394+
while __stack_go
395+
&& not sm.Data.IsResultError
396+
&& condition () do
387397
// NOTE: The body of the state machine code for 'while' may contain await points, so resuming
388398
// the code will branch directly into the expanded 'body', branching directly into the while loop
389399
let __stack_body_fin = body.Invoke(&sm)
390400
// printfn "While After Invoke --> %A" sm.Data.Result
391401
// If the body completed, we go back around the loop (__stack_go = true)
392402
// If the body yielded, we yield (__stack_go = false)
393403
__stack_go <- __stack_body_fin
394-
errored <- sm.Data.IsResultError
395404

396-
if errored then
405+
if sm.Data.IsResultError then
397406
// Set the result now to allow short-circuiting of the rest of the CE.
398407
// Run/RunDynamic will skip setting the result if it's already been set.
399408
// Combine/CombineDynamic will not continue if the result has been set.

0 commit comments

Comments
 (0)