Skip to content

Commit c43baca

Browse files
committed
Optimize loops using GADTs
1 parent a7c9893 commit c43baca

File tree

1 file changed

+110
-102
lines changed

1 file changed

+110
-102
lines changed

src/kcas/kcas.ml

Lines changed: 110 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -219,57 +219,56 @@ end
219219

220220
let[@inline] isnt_int x = not (Obj.is_int (Obj.repr x))
221221

222-
let rec release_after which = function
222+
let[@inline] rec release_after_rec which = function
223223
| T Leaf -> true
224-
| T (Node node_r) ->
225-
if is_node node_r.lt then release_after which node_r.lt |> ignore;
226-
let state = node_r.state in
227-
if is_cas which state then begin
228-
state.which <- W After;
229-
if isnt_int state.before then state.before <- Obj.magic ();
230-
resume_awaiters node_r.awaiters
231-
end;
232-
release_after which node_r.gt
233-
234-
let rec release_before which = function
224+
| T (Node node_r) -> release_after which (Node node_r)
225+
226+
and release_after which (Node node_r : [< `Node ] tdt) =
227+
release_after_rec which node_r.lt |> ignore;
228+
let state = node_r.state in
229+
if is_cas which state then begin
230+
state.which <- W After;
231+
if isnt_int state.before then state.before <- Obj.magic ();
232+
resume_awaiters node_r.awaiters
233+
end;
234+
release_after_rec which node_r.gt
235+
236+
let[@inline] rec release_before_rec which = function
235237
| T Leaf -> false
236-
| T (Node node_r) ->
237-
if is_node node_r.lt then release_before which node_r.lt |> ignore;
238-
let state = node_r.state in
239-
if is_cas which state then begin
240-
state.which <- W Before;
241-
if isnt_int state.after then state.after <- Obj.magic ();
242-
resume_awaiters node_r.awaiters
243-
end;
244-
release_before which node_r.gt
238+
| T (Node node_r) -> release_before which (Node node_r)
239+
240+
and release_before which (Node node_r : [< `Node ] tdt) =
241+
release_before_rec which node_r.lt |> ignore;
242+
let state = node_r.state in
243+
if is_cas which state then begin
244+
state.which <- W Before;
245+
if isnt_int state.after then state.after <- Obj.magic ();
246+
resume_awaiters node_r.awaiters
247+
end;
248+
release_before_rec which node_r.gt
245249

246250
let release which tree status =
247251
if status == After then release_after which tree
248252
else release_before which tree
249253

250-
let rec verify which = function
254+
let[@inline] rec verify_rec which = function
251255
| T Leaf -> After
252-
| T (Node node_r) ->
253-
if is_node node_r.lt then
254-
let status = verify which node_r.lt in
255-
if status == After then
256-
(* Fenceless is safe as [finish] has a fence after. *)
257-
if
258-
is_cmp which node_r.state
259-
&& fenceless_get (as_atomic node_r.loc) != node_r.state
260-
then Before
261-
else verify which node_r.gt
262-
else status
263-
else if
264-
(* Fenceless is safe as [finish] has a fence after. *)
265-
is_cmp which node_r.state
266-
&& fenceless_get (as_atomic node_r.loc) != node_r.state
267-
then Before
268-
else verify which node_r.gt
256+
| T (Node node_r) -> verify which (Node node_r)
257+
258+
and verify which (Node node_r : [< `Node ] tdt) =
259+
let status = verify_rec which node_r.lt in
260+
if status == After then
261+
(* Fenceless is safe as [finish] has a fence after. *)
262+
if
263+
is_cmp which node_r.state
264+
&& fenceless_get (as_atomic node_r.loc) != node_r.state
265+
then Before
266+
else verify_rec which node_r.gt
267+
else status
269268

270269
let finish which root status =
271270
if Atomic.compare_and_set (root_as_atomic which) (R root) (R status) then
272-
release which (T root) status
271+
release which root status
273272
else
274273
(* Fenceless is safe as we have a fence above. *)
275274
fenceless_get (root_as_atomic which) == R After
@@ -278,25 +277,26 @@ let a_cmp = 1
278277
let a_cas = 2
279278
let a_cmp_followed_by_a_cas = 4
280279

281-
let rec determine which status = function
280+
let[@inline] next_status a_cas_or_a_cmp status =
281+
let a_cmp_followed_by_a_cas = a_cas_or_a_cmp * 2 land (status * 4) in
282+
status lor a_cas_or_a_cmp lor a_cmp_followed_by_a_cas
283+
284+
let[@inline] rec determine_rec which status = function
282285
| T Leaf -> status
283-
| T (Node node_r) ->
284-
let status =
285-
if is_node node_r.lt then determine which status node_r.lt else status
286-
in
287-
if status < 0 then status
288-
else determine_eq Backoff.default which status (Node node_r)
286+
| T (Node node_r) -> determine which status (Node node_r)
287+
288+
and determine which status (Node node_r : [< `Node ] tdt) =
289+
let status = determine_rec which status node_r.lt in
290+
if status < 0 then status
291+
else determine_eq Backoff.default which status (Node node_r)
289292

290293
and determine_eq backoff which status (Node node_r as eq : [< `Node ] tdt) =
291294
let current = atomic_get (as_atomic node_r.loc) in
292295
let state = node_r.state in
293296
if state == current then begin
294297
let a_cas_or_a_cmp = 1 + Bool.to_int (is_cas which state) in
295-
let a_cmp_followed_by_a_cas = a_cas_or_a_cmp * 2 land (status * 4) in
296298
if is_determined which then raise_notrace Exit;
297-
determine which
298-
(status lor a_cas_or_a_cmp lor a_cmp_followed_by_a_cas)
299-
node_r.gt
299+
determine_rec which (next_status a_cas_or_a_cmp status) node_r.gt
300300
end
301301
else
302302
let matches_expected () =
@@ -322,8 +322,7 @@ and determine_eq backoff which status (Node node_r as eq : [< `Node ] tdt) =
322322
awaiters. *)
323323
if current.awaiters != [] then node_r.awaiters <- current.awaiters;
324324
if Atomic.compare_and_set (as_atomic node_r.loc) current state then
325-
let a_cmp_followed_by_a_cas = a_cas * 2 land (status * 4) in
326-
determine which (status lor a_cas lor a_cmp_followed_by_a_cas) node_r.gt
325+
determine_rec which (next_status a_cas status) node_r.gt
327326
else determine_eq (Backoff.once backoff) which status eq
328327
end
329328
else -1
@@ -334,10 +333,10 @@ and is_after = function
334333
match fenceless_get (root_as_atomic which) with
335334
| R (Node node_r) -> begin
336335
let root = Node node_r in
337-
match determine which 0 (T root) with
336+
match determine which 0 root with
338337
| status ->
339338
finish which root
340-
(if a_cmp_followed_by_a_cas < status then verify which (T root)
339+
(if a_cmp_followed_by_a_cas < status then verify which root
341340
else if 0 <= status then After
342341
else Before)
343342
| exception Exit ->
@@ -652,12 +651,14 @@ module Xt = struct
652651
(* Fenceless is safe inside transactions as each log update has a fence. *)
653652
if before != eval (fenceless_get (as_atomic loc)) then Retry.invalid ()
654653

655-
let rec validate_all which = function
654+
let[@inline] rec validate_all_rec which = function
656655
| T Leaf -> ()
657-
| T (Node node_r) ->
658-
if is_node node_r.lt then validate_all which node_r.lt;
659-
validate_one which node_r.loc node_r.state;
660-
validate_all which node_r.gt
656+
| T (Node node_r) -> validate_all which (Node node_r)
657+
658+
and validate_all which (Node node_r : [< `Node ] tdt) =
659+
validate_all_rec which node_r.lt;
660+
validate_one which node_r.loc node_r.state;
661+
validate_all_rec which node_r.gt
661662

662663
let[@inline] maybe_validate_log xt =
663664
let c0 = xt.validate_counter in
@@ -666,7 +667,7 @@ module Xt = struct
666667
(* Validate whenever counter reaches next power of 2. *)
667668
if c0 land c1 = 0 then begin
668669
Timeout.check (timeout_as_atomic xt);
669-
validate_all xt.which xt.tree
670+
validate_all_rec xt.which xt.tree
670671
end
671672

672673
let[@inline] update_new loc f xt lt gt =
@@ -838,35 +839,38 @@ module Xt = struct
838839

839840
let[@inline] call ~xt { tx } = tx ~xt
840841

841-
let rec add_awaiters awaiter which = function
842+
let[@inline] rec add_awaiters_rec awaiter which = function
842843
| T Leaf -> T Leaf
843-
| T (Node node_r) as stop -> begin
844-
match
845-
if is_node node_r.lt then add_awaiters awaiter which node_r.lt
846-
else node_r.lt
847-
with
848-
| T Leaf ->
849-
if
850-
add_awaiter node_r.loc
851-
(let state = node_r.state in
852-
if is_cmp which state then eval state else state.before)
853-
awaiter
854-
then add_awaiters awaiter which node_r.gt
855-
else stop
856-
| T (Node _) as stop -> stop
857-
end
844+
| T (Node node_r) -> add_awaiters awaiter which (Node node_r)
858845

859-
let rec remove_awaiters awaiter which stop = function
860-
| T Leaf -> ()
861-
| T (Node node_r) as current ->
862-
if is_node node_r.lt then remove_awaiters awaiter which stop node_r.lt;
863-
if current != stop then begin
846+
and add_awaiters awaiter which (Node node_r as stop : [< `Node ] tdt) =
847+
match add_awaiters_rec awaiter which node_r.lt with
848+
| T Leaf ->
849+
if
850+
add_awaiter node_r.loc
851+
(let state = node_r.state in
852+
if is_cmp which state then eval state else state.before)
853+
awaiter
854+
then add_awaiters_rec awaiter which node_r.gt
855+
else T stop
856+
| T (Node _) as stop -> stop
857+
858+
let[@inline] rec remove_awaiters_rec awaiter which stop = function
859+
| T Leaf -> T Leaf
860+
| T (Node node_r) -> remove_awaiters awaiter which stop (Node node_r)
861+
862+
and remove_awaiters awaiter which stop (Node node_r as at : [< `Node ] tdt) =
863+
match remove_awaiters_rec awaiter which stop node_r.lt with
864+
| T Leaf ->
865+
if T at != stop then begin
864866
remove_awaiter Backoff.default node_r.loc
865867
(let state = node_r.state in
866868
if is_cmp which state then eval state else state.before)
867869
awaiter;
868-
remove_awaiters awaiter which stop node_r.gt
870+
remove_awaiters_rec awaiter which stop node_r.gt
869871
end
872+
else stop
873+
| T (Node _) as stop -> stop
870874

871875
let initial_validate_period = 16
872876

@@ -908,10 +912,10 @@ module Xt = struct
908912
fenceless_set (root_as_atomic xt.which) (R root);
909913
(* The end result is a cyclic data structure, which is why we cannot
910914
initialize the [which] atomic directly. *)
911-
match determine xt.which 0 (T root) with
915+
match determine xt.which 0 root with
912916
| status ->
913917
if a_cmp_followed_by_a_cas < status then begin
914-
if finish xt.which root (verify xt.which (T root)) then
918+
if finish xt.which root (verify xt.which root) then
915919
success xt result
916920
else begin
917921
(* We switch to [Mode.lock_free] as there was
@@ -937,25 +941,29 @@ module Xt = struct
937941
Timeout.check (timeout_as_atomic xt);
938942
commit (Backoff.once backoff) mode (reset_quick xt) tx
939943
| exception Retry.Later -> begin
940-
if xt.tree == T Leaf then invalid_retry ();
941-
let t = Domain_local_await.prepare_for_await () in
942-
let alive = Timeout.await (timeout_as_atomic xt) t.release in
943-
match add_awaiters t.release xt.which xt.tree with
944-
| T Leaf -> begin
945-
match t.await () with
946-
| () ->
947-
remove_awaiters t.release xt.which (T Leaf) xt.tree;
944+
match xt.tree with
945+
| T Leaf -> invalid_retry ()
946+
| T (Node node_r) -> begin
947+
let root = Node node_r in
948+
let t = Domain_local_await.prepare_for_await () in
949+
let alive = Timeout.await (timeout_as_atomic xt) t.release in
950+
match add_awaiters t.release xt.which root with
951+
| T Leaf -> begin
952+
match t.await () with
953+
| () ->
954+
remove_awaiters t.release xt.which (T Leaf) root |> ignore;
955+
Timeout.unawait (timeout_as_atomic xt) alive;
956+
commit (Backoff.reset backoff) mode (reset_quick xt) tx
957+
| exception cancellation_exn ->
958+
remove_awaiters t.release xt.which (T Leaf) root |> ignore;
959+
Timeout.cancel_alive alive;
960+
raise cancellation_exn
961+
end
962+
| T (Node _) as stop ->
963+
remove_awaiters t.release xt.which stop root |> ignore;
948964
Timeout.unawait (timeout_as_atomic xt) alive;
949-
commit (Backoff.reset backoff) mode (reset_quick xt) tx
950-
| exception cancellation_exn ->
951-
remove_awaiters t.release xt.which (T Leaf) xt.tree;
952-
Timeout.cancel_alive alive;
953-
raise cancellation_exn
965+
commit (Backoff.once backoff) mode (reset_quick xt) tx
954966
end
955-
| T (Node _) as stop ->
956-
remove_awaiters t.release xt.which stop xt.tree;
957-
Timeout.unawait (timeout_as_atomic xt) alive;
958-
commit (Backoff.once backoff) mode (reset_quick xt) tx
959967
end
960968
| exception exn ->
961969
Timeout.cancel (timeout_as_atomic xt);

0 commit comments

Comments
 (0)