Skip to content

Commit 4a5d840

Browse files
committed
Add awaiters and cleanup
1 parent b2ca519 commit 4a5d840

File tree

1 file changed

+68
-21
lines changed

1 file changed

+68
-21
lines changed

doc/simplified.ml

Lines changed: 68 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
module Kcas : sig
1+
module type Awaiter = sig
2+
type t
3+
4+
val signal : t -> unit
5+
end
6+
7+
module Make (Awaiter : Awaiter) : sig
8+
module Awaiter : Awaiter
9+
210
type 'a loc
311

412
val make : 'a -> 'a loc
@@ -9,18 +17,36 @@ module Kcas : sig
917

1018
val atomically : cas list -> cmp list -> bool
1119
end = struct
20+
module Awaiter = Awaiter
21+
1222
type 'a loc = 'a state Atomic.t
13-
and 'a state = { before : 'a; after : 'a; casn : casn }
14-
and cass = CASS : 'a loc * 'a state -> cass
23+
24+
and 'a state = {
25+
mutable before : 'a;
26+
mutable after : 'a;
27+
casn : casn;
28+
awaiters : Awaiter.t list;
29+
}
30+
31+
and cass =
32+
| CASS : {
33+
loc : 'a loc;
34+
desired : 'a state;
35+
mutable awaiters : Awaiter.t list;
36+
}
37+
-> cass
38+
39+
and cmps = CMPS : { loc : 'a loc; current : 'a state } -> cmps
1540
and casn = status Atomic.t
1641

1742
and status =
18-
| Undetermined of { cass : cass list; cmps : cass list }
43+
| Undetermined of { cass : cass list; cmps : cmps list }
1944
| After
2045
| Before
2146

2247
let make after =
23-
Atomic.make { before = after; after; casn = Atomic.make After }
48+
Atomic.make
49+
{ before = after; after; casn = Atomic.make After; awaiters = [] }
2450

2551
type cas = CAS : 'a loc * 'a * 'a -> cas
2652
type cmp = CMP : 'a loc * 'a -> cmp
@@ -29,31 +55,44 @@ end = struct
2955
match Atomic.get casn with
3056
| After -> true
3157
| Before -> false
32-
| Undetermined { cmps; _ } as current ->
58+
| Undetermined undetermined as current ->
3359
let desired =
3460
if
3561
desired == After
36-
&& cmps
37-
|> List.exists @@ fun (CASS (loc, state)) ->
38-
Atomic.get loc != state
62+
&& undetermined.cmps
63+
|> List.exists @@ fun (CMPS cmps) ->
64+
Atomic.get cmps.loc != cmps.current
3965
then Before
4066
else desired
4167
in
42-
Atomic.compare_and_set casn current desired |> ignore;
68+
if Atomic.compare_and_set casn current desired then begin
69+
if desired == After then begin
70+
undetermined.cass
71+
|> List.iter @@ fun (CASS cass) ->
72+
List.iter Awaiter.signal cass.awaiters;
73+
cass.desired.before <- cass.desired.after
74+
end
75+
else begin
76+
undetermined.cass
77+
|> List.iter @@ fun (CASS cass) ->
78+
cass.desired.after <- cass.desired.before
79+
end
80+
end;
4381
Atomic.get casn == After
4482

4583
let rec gkmz casn = function
4684
| [] -> finish casn After
47-
| CASS (loc, desired) :: continue as retry -> begin
48-
let current = Atomic.get loc in
49-
if desired == current then gkmz casn continue
85+
| CASS cass :: continue as retry -> begin
86+
let current = Atomic.get cass.loc in
87+
if cass.desired == current then gkmz casn continue
5088
else
5189
let current_value = get_from current in
52-
if current_value != desired.before then finish casn Before
90+
if current_value != cass.desired.before then finish casn Before
5391
else
5492
match Atomic.get casn with
5593
| Undetermined _ ->
56-
if Atomic.compare_and_set loc current desired then
94+
cass.awaiters <- current.awaiters;
95+
if Atomic.compare_and_set cass.loc current cass.desired then
5796
gkmz casn continue
5897
else gkmz casn retry
5998
| After -> true
@@ -73,15 +112,17 @@ end = struct
73112
let cass =
74113
logical_cas_list
75114
|> List.map @@ function
76-
| CAS (loc, before, after) -> CASS (loc, { before; after; casn })
115+
| CAS (loc, before, after) ->
116+
let next = { before; after; casn; awaiters = [] } in
117+
CASS { loc; desired = next; awaiters = [] }
77118
in
78119
let cmps =
79120
logical_cmp_list
80121
|> List.map @@ function
81122
| CMP (loc, expected) ->
82123
let current = Atomic.get loc in
83124
if get_from current != expected then raise Exit
84-
else CASS (loc, current)
125+
else CMPS { loc; current }
85126
in
86127
Atomic.set casn (Undetermined { cass; cmps });
87128
gkmz casn cass
@@ -93,8 +134,14 @@ end = struct
93134
end
94135

95136
let () =
96-
let x = Kcas.make 82 in
97-
let y = Kcas.make 40 in
98-
assert (Kcas.atomically [ CAS (x, 82, 42) ] [ CMP (y, 40) ]);
99-
assert (Kcas.get x == 42 && Kcas.get y == 40);
137+
let module Awaiter = struct
138+
type t = unit
139+
140+
let signal = ignore
141+
end in
142+
let module STM = Make (Awaiter) in
143+
let x = STM.make 82 in
144+
let y = STM.make 40 in
145+
assert (STM.atomically [ CAS (x, 82, 42) ] [ CMP (y, 40) ]);
146+
assert (STM.get x == 42 && STM.get y == 40);
100147
()

0 commit comments

Comments
 (0)