Skip to content

Commit 01b37bd

Browse files
committed
Add Xt.compare_and_set
1 parent 5919d85 commit 01b37bd

File tree

3 files changed

+102
-40
lines changed

3 files changed

+102
-40
lines changed

src/kcas/kcas.ml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,9 @@ module Xt = struct
796796
unsafe_update ~xt loc (fun actual ->
797797
if actual == before then after else actual)
798798

799+
let compare_and_set ~xt loc before after =
800+
compare_and_swap ~xt loc before after == before
801+
799802
let exchange ~xt loc after = unsafe_update ~xt loc (fun _ -> after)
800803
let fetch_and_add ~xt loc n = unsafe_update ~xt loc (( + ) n)
801804
let incr ~xt loc = unsafe_update ~xt loc inc |> ignore

src/kcas/kcas.mli

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,10 @@ module Xt : sig
441441
val swap : xt:'x t -> 'a Loc.t -> 'a Loc.t -> unit
442442
(** [swap ~xt l1 l2] is equivalent to [set ~xt l1 @@ exchange ~xt l2 @@ get ~xt l1]. *)
443443

444+
val compare_and_set : xt:'x t -> 'a Loc.t -> 'a -> 'a -> bool
445+
(** [compare_and_set ~xt r before after] is equivalent to
446+
[compare_and_swap ~xt r before after == before]. *)
447+
444448
val compare_and_swap : xt:'x t -> 'a Loc.t -> 'a -> 'a -> 'a
445449
(** [compare_and_swap ~xt r before after] is equivalent to
446450

test/kcas/test.ml

Lines changed: 95 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -42,48 +42,101 @@ let run_domains = function
4242
List.iter Domain.join others
4343

4444
let test_non_linearizable () =
45-
let barrier = Barrier.make 2
46-
and n_iter = 100 * Util.iter_factor
47-
and test_finished = ref false in
45+
[ Mode.obstruction_free; Mode.lock_free ]
46+
|> List.iter @@ fun mode ->
47+
let barrier = Barrier.make 2
48+
and n_iter = 100 * Util.iter_factor
49+
and test_finished = ref false in
50+
51+
let a = Loc.make ~mode 0 and b = Loc.make ~mode 0 in
52+
53+
let cass1a = [ Op.make_cmp b 0; Op.make_cas a 0 1 ]
54+
and cass1b = [ Op.make_cmp b 0; Op.make_cas a 1 0 ]
55+
and cass2a = [ Op.make_cas b 0 1; Op.make_cmp a 0 ]
56+
and cass2b = [ Op.make_cas b 1 0; Op.make_cmp a 0 ] in
57+
58+
let atomically cs =
59+
if Random.bool () then
60+
try Op.atomically ~mode:Mode.obstruction_free cs
61+
with Mode.Interference -> false
62+
else Op.atomically cs
63+
in
64+
65+
let thread1 () =
66+
Barrier.await barrier;
67+
while not !test_finished do
68+
if atomically cass1a then
69+
while not (atomically cass1b) do
70+
if is_single then Domain.cpu_relax ();
71+
assert (Loc.get a == 1 && Loc.get b == 0)
72+
done
73+
else if is_single then Domain.cpu_relax ()
74+
done
75+
and thread2 () =
76+
Barrier.await barrier;
77+
for _ = 1 to n_iter do
78+
if atomically cass2a then
79+
while not (atomically cass2b) do
80+
if is_single then Domain.cpu_relax ();
81+
assert (Loc.get a == 0 && Loc.get b == 1)
82+
done
83+
else if is_single then Domain.cpu_relax ()
84+
done;
85+
test_finished := true
86+
in
87+
88+
run_domains [ thread2; thread1 ]
4889

49-
let a = Loc.make 0 and b = Loc.make 0 in
50-
51-
let cass1a = [ Op.make_cmp b 0; Op.make_cas a 0 1 ]
52-
and cass1b = [ Op.make_cmp b 0; Op.make_cas a 1 0 ]
53-
and cass2a = [ Op.make_cas b 0 1; Op.make_cmp a 0 ]
54-
and cass2b = [ Op.make_cas b 1 0; Op.make_cmp a 0 ] in
55-
56-
let atomically cs =
57-
if Random.bool () then
58-
try Op.atomically ~mode:Mode.obstruction_free cs
59-
with Mode.Interference -> false
60-
else Op.atomically cs
61-
in
62-
63-
let thread1 () =
64-
Barrier.await barrier;
65-
while not !test_finished do
66-
if atomically cass1a then
67-
while not (atomically cass1b) do
68-
if is_single then Domain.cpu_relax ();
69-
assert (Loc.get a == 1 && Loc.get b == 0)
70-
done
71-
else if is_single then Domain.cpu_relax ()
72-
done
73-
and thread2 () =
74-
Barrier.await barrier;
75-
for _ = 1 to n_iter do
76-
if atomically cass2a then
77-
while not (atomically cass2b) do
78-
if is_single then Domain.cpu_relax ();
79-
assert (Loc.get a == 0 && Loc.get b == 1)
80-
done
81-
else if is_single then Domain.cpu_relax ()
82-
done;
83-
test_finished := true
84-
in
90+
(* *)
8591

86-
run_domains [ thread2; thread1 ]
92+
let test_non_linearizable_xt () =
93+
[ Mode.obstruction_free; Mode.lock_free ]
94+
|> List.iter @@ fun mode ->
95+
let barrier = Barrier.make 2
96+
and n_iter = 100 * Util.iter_factor
97+
and test_finished = ref false in
98+
99+
let a = Loc.make ~mode 0 and b = Loc.make ~mode 0 in
100+
101+
let cass1a ~xt =
102+
(Xt.get ~xt b == 0 && Xt.compare_and_set ~xt a 0 1) || Retry.invalid ()
103+
and cass1b ~xt =
104+
(Xt.get ~xt b == 0 && Xt.compare_and_set ~xt a 1 0) || Retry.invalid ()
105+
and cass2a ~xt =
106+
(Xt.compare_and_set ~xt b 0 1 && Xt.get ~xt a == 0) || Retry.invalid ()
107+
and cass2b ~xt =
108+
(Xt.compare_and_set ~xt b 1 0 && Xt.get ~xt a == 0) || Retry.invalid ()
109+
in
110+
111+
let atomically tx =
112+
if Random.bool () then Xt.commit ~mode:Mode.obstruction_free tx
113+
else Xt.commit tx
114+
in
115+
116+
let thread1 () =
117+
Barrier.await barrier;
118+
while not !test_finished do
119+
if atomically { tx = cass1a } then
120+
while not (atomically { tx = cass1b }) do
121+
if is_single then Domain.cpu_relax ();
122+
assert (Loc.get a == 1 && Loc.get b == 0)
123+
done
124+
else if is_single then Domain.cpu_relax ()
125+
done
126+
and thread2 () =
127+
Barrier.await barrier;
128+
for _ = 1 to n_iter do
129+
if atomically { tx = cass2a } then
130+
while not (atomically { tx = cass2b }) do
131+
if is_single then Domain.cpu_relax ();
132+
assert (Loc.get a == 0 && Loc.get b == 1)
133+
done
134+
else if is_single then Domain.cpu_relax ()
135+
done;
136+
test_finished := true
137+
in
138+
139+
run_domains [ thread2; thread1 ]
87140

88141
(* *)
89142

@@ -649,6 +702,8 @@ let () =
649702
[
650703
( "non linearizable",
651704
[ Alcotest.test_case "" `Quick test_non_linearizable ] );
705+
( "non linearizable xt",
706+
[ Alcotest.test_case "" `Quick test_non_linearizable_xt ] );
652707
("set", [ Alcotest.test_case "" `Quick test_set ]);
653708
("casn", [ Alcotest.test_case "" `Quick test_casn ]);
654709
("read casn", [ Alcotest.test_case "" `Quick test_read_casn ]);

0 commit comments

Comments
 (0)