1- module Kcas : sig
1+ module type Awaiter = sig
2+ type t
3+
4+ val signal : t -> unit
5+ end
6+
7+ module Make (Awaiter : Awaiter ) : sig
8+ module Awaiter : Awaiter
9+
210 type 'a loc
311
412 val make : 'a -> 'a loc
@@ -9,18 +17,36 @@ module Kcas : sig
917
1018 val atomically : cas list -> cmp list -> bool
1119end = struct
20+ module Awaiter = Awaiter
21+
1222 type 'a loc = 'a state Atomic .t
13- and 'a state = { before : 'a ; after : 'a ; casn : casn }
14- and cass = CASS : 'a loc * 'a state -> cass
23+
24+ and 'a state = {
25+ mutable before : 'a ;
26+ mutable after : 'a ;
27+ casn : casn ;
28+ awaiters : Awaiter .t list ;
29+ }
30+
31+ and cass =
32+ | CASS : {
33+ loc : 'a loc ;
34+ desired : 'a state ;
35+ mutable awaiters : Awaiter .t list ;
36+ }
37+ -> cass
38+
39+ and cmps = CMPS : { loc : 'a loc ; current : 'a state } -> cmps
1540 and casn = status Atomic. t
1641
1742 and status =
18- | Undetermined of { cass : cass list ; cmps : cass list }
43+ | Undetermined of { cass : cass list ; cmps : cmps list }
1944 | After
2045 | Before
2146
2247 let make after =
23- Atomic. make { before = after; after; casn = Atomic. make After }
48+ Atomic. make
49+ { before = after; after; casn = Atomic. make After ; awaiters = [] }
2450
2551 type cas = CAS : 'a loc * 'a * 'a -> cas
2652 type cmp = CMP : 'a loc * 'a -> cmp
@@ -29,31 +55,44 @@ end = struct
2955 match Atomic. get casn with
3056 | After -> true
3157 | Before -> false
32- | Undetermined { cmps; _ } as current ->
58+ | Undetermined undetermined as current ->
3359 let desired =
3460 if
3561 desired == After
36- && cmps
37- |> List. exists @@ fun (CASS ( loc , state ) ) ->
38- Atomic. get loc != state
62+ && undetermined. cmps
63+ |> List. exists @@ fun (CMPS cmps ) ->
64+ Atomic. get cmps. loc != cmps.current
3965 then Before
4066 else desired
4167 in
42- Atomic. compare_and_set casn current desired |> ignore;
68+ if Atomic. compare_and_set casn current desired then begin
69+ if desired == After then begin
70+ undetermined.cass
71+ |> List. iter @@ fun (CASS cass ) ->
72+ List. iter Awaiter. signal cass.awaiters;
73+ cass.desired.before < - cass.desired.after
74+ end
75+ else begin
76+ undetermined.cass
77+ |> List. iter @@ fun (CASS cass ) ->
78+ cass.desired.after < - cass.desired.before
79+ end
80+ end;
4381 Atomic. get casn == After
4482
4583 let rec gkmz casn = function
4684 | [] -> finish casn After
47- | CASS ( loc , desired ) :: continue as retry -> begin
48- let current = Atomic. get loc in
49- if desired == current then gkmz casn continue
85+ | CASS cass :: continue as retry -> begin
86+ let current = Atomic. get cass. loc in
87+ if cass. desired == current then gkmz casn continue
5088 else
5189 let current_value = get_from current in
52- if current_value != desired.before then finish casn Before
90+ if current_value != cass. desired.before then finish casn Before
5391 else
5492 match Atomic. get casn with
5593 | Undetermined _ ->
56- if Atomic. compare_and_set loc current desired then
94+ cass.awaiters < - current.awaiters;
95+ if Atomic. compare_and_set cass.loc current cass.desired then
5796 gkmz casn continue
5897 else gkmz casn retry
5998 | After -> true
@@ -73,15 +112,17 @@ end = struct
73112 let cass =
74113 logical_cas_list
75114 |> List. map @@ function
76- | CAS (loc , before , after ) -> CASS (loc, { before; after; casn })
115+ | CAS (loc , before , after ) ->
116+ let next = { before; after; casn; awaiters = [] } in
117+ CASS { loc; desired = next; awaiters = [] }
77118 in
78119 let cmps =
79120 logical_cmp_list
80121 |> List. map @@ function
81122 | CMP (loc , expected ) ->
82123 let current = Atomic. get loc in
83124 if get_from current != expected then raise Exit
84- else CASS ( loc, current)
125+ else CMPS { loc; current }
85126 in
86127 Atomic. set casn (Undetermined { cass; cmps });
87128 gkmz casn cass
@@ -93,8 +134,14 @@ end = struct
93134end
94135
95136let () =
96- let x = Kcas. make 82 in
97- let y = Kcas. make 40 in
98- assert (Kcas. atomically [ CAS (x, 82 , 42 ) ] [ CMP (y, 40 ) ]);
99- assert (Kcas. get x == 42 && Kcas. get y == 40 );
137+ let module Awaiter = struct
138+ type t = unit
139+
140+ let signal = ignore
141+ end in
142+ let module STM = Make (Awaiter ) in
143+ let x = STM. make 82 in
144+ let y = STM. make 40 in
145+ assert (STM. atomically [ CAS (x, 82 , 42 ) ] [ CMP (y, 40 ) ]);
146+ assert (STM. get x == 42 && STM. get y == 40 );
100147 ()
0 commit comments