Skip to content

Commit 368ccee

Browse files
committed
Change Hashtbl config to be part of the transactional data
1 parent 7fefe99 commit 368ccee

File tree

3 files changed

+124
-89
lines changed

3 files changed

+124
-89
lines changed

src/kcas_data/hashtbl.ml

Lines changed: 115 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -72,16 +72,17 @@ type ('k, 'v) pending =
7272
new_buckets : ('k, 'v) Assoc.t Loc.t array Loc.t;
7373
}
7474

75-
type ('k, 'v) t = {
76-
pending : ('k, 'v) pending Loc.t;
75+
type ('k, 'v) r = {
76+
pending : ('k, 'v) pending;
7777
length : Accumulator.t;
78-
buckets : ('k, 'v) Assoc.t Loc.t array Loc.t;
78+
buckets : ('k, 'v) Assoc.t Loc.t array;
7979
hash : 'k -> int;
8080
equal : 'k -> 'k -> bool;
8181
min_buckets : int;
8282
max_buckets : int;
8383
}
8484

85+
type ('k, 'v) t = ('k, 'v) r Loc.t
8586
type 'k hashed_type = (module Stdlib.Hashtbl.HashedType with type t = 'k)
8687

8788
let lo_buckets = 1 lsl 5
@@ -112,6 +113,7 @@ let create ?hashed_type ?min_buckets ?max_buckets ?n_way () =
112113
| None -> min_buckets_default
113114
| Some c -> Int.max lo_buckets c |> Int.min hi_buckets |> Bits.ceil_pow_2
114115
in
116+
let t = Loc.make (Obj.magic ()) in
115117
let max_buckets =
116118
match max_buckets with
117119
| None -> Int.max min_buckets max_buckets_default
@@ -120,16 +122,19 @@ let create ?hashed_type ?min_buckets ?max_buckets ?n_way () =
120122
match hashed_type with
121123
| None -> (Stdlib.Hashtbl.seeded_hash (Random.bits ()), ( = ))
122124
| Some hashed_type -> HashedType.unpack hashed_type
123-
and pending = Loc.make Nothing
124-
and buckets = Loc.make [||]
125+
and pending = Nothing
126+
and buckets = Loc.make_array min_buckets []
125127
and length = Accumulator.make ?n_way 0 in
126-
Loc.set buckets @@ Loc.make_array min_buckets [];
127-
{ pending; length; buckets; hash; equal; min_buckets; max_buckets }
128+
Loc.set t { pending; length; buckets; hash; equal; min_buckets; max_buckets };
129+
t
130+
131+
let n_way_of t = Accumulator.n_way_of (Loc.get t).length
132+
let min_buckets_of t = (Loc.get t).min_buckets
133+
let max_buckets_of t = (Loc.get t).max_buckets
128134

129-
let n_way_of t = Accumulator.n_way_of t.length
130-
let min_buckets_of t = t.min_buckets
131-
let max_buckets_of t = t.max_buckets
132-
let hashed_type_of t = HashedType.pack t.hash t.equal
135+
let hashed_type_of t =
136+
let r = Loc.get t in
137+
HashedType.pack r.hash r.equal
133138

134139
let bucket_of hash key buckets =
135140
Array.unsafe_get buckets (hash key land (Array.length buckets - 1))
@@ -138,16 +143,16 @@ exception Done
138143

139144
module Xt = struct
140145
let find_opt ~xt t k =
141-
Xt.get ~xt t.buckets |> bucket_of t.hash k |> Xt.get ~xt
142-
|> Assoc.find_opt t.equal k
146+
let r = Xt.get ~xt t in
147+
r.buckets |> bucket_of r.hash k |> Xt.get ~xt |> Assoc.find_opt r.equal k
143148

144149
let find_all ~xt t k =
145-
Xt.get ~xt t.buckets |> bucket_of t.hash k |> Xt.get ~xt
146-
|> Assoc.find_all t.equal k
150+
let r = Xt.get ~xt t in
151+
r.buckets |> bucket_of r.hash k |> Xt.get ~xt |> Assoc.find_all r.equal k
147152

148153
let mem ~xt t k =
149-
Xt.get ~xt t.buckets |> bucket_of t.hash k |> Xt.get ~xt
150-
|> Assoc.mem t.equal k
154+
let r = Xt.get ~xt t in
155+
r.buckets |> bucket_of r.hash k |> Xt.get ~xt |> Assoc.mem r.equal k
151156

152157
let get_or_alloc array_loc alloc =
153158
let tx ~xt =
@@ -167,15 +172,18 @@ module Xt = struct
167172
(* TODO: Implement pending operations such that multiple domains may be
168173
working to complete them in parallel by extending the [state] to an array
169174
of multiple partition [states]. *)
170-
let must_be_done_in_this_tx = Xt.is_in_log ~xt t.pending in
171-
match Xt.exchange ~xt t.pending Nothing with
172-
| Nothing -> ()
175+
let must_be_done_in_this_tx = Xt.is_in_log ~xt t in
176+
let r = Xt.get ~xt t in
177+
match r.pending with
178+
| Nothing -> r
173179
| Rehash { state; new_capacity; new_buckets } -> (
174180
let new_buckets =
175181
get_or_alloc new_buckets @@ fun () -> Loc.make_array new_capacity []
176182
in
177-
let old_buckets = Xt.exchange ~xt t.buckets new_buckets in
178-
let hash = t.hash and mask = new_capacity - 1 in
183+
let old_buckets = r.buckets in
184+
let r = { r with pending = Nothing; buckets = new_buckets } in
185+
Xt.set ~xt t r;
186+
let hash = r.hash and mask = new_capacity - 1 in
179187
let rehash_a_few_buckets ~xt =
180188
(* We process buckets in descending order as that is slightly faster
181189
with the transaction log. It also makes sure that we know when the
@@ -211,11 +219,14 @@ module Xt = struct
211219
at a time. This gives expected linear time, O(n). *)
212220
while true do
213221
Xt.commit { tx = rehash_a_few_buckets }
214-
done
215-
with Done -> ())
222+
done;
223+
r
224+
with Done -> r)
216225
| Snapshot { state; snapshot } -> (
217226
assert (not must_be_done_in_this_tx);
218-
let buckets = Xt.get ~xt t.buckets in
227+
let buckets = r.buckets in
228+
let r = { r with pending = Nothing } in
229+
Xt.set ~xt t r;
219230
(* Check state to ensure that buckets have not been updated. *)
220231
if Loc.fenceless_get state < 0 then Retry.invalid ();
221232
let snapshot =
@@ -233,11 +244,12 @@ module Xt = struct
233244
try
234245
while true do
235246
Xt.commit { tx = snapshot_a_few_buckets }
236-
done
237-
with Done -> ())
247+
done;
248+
r
249+
with Done -> r)
238250
| Filter_map { state; fn; raised; new_buckets } -> (
239251
assert (not must_be_done_in_this_tx);
240-
let old_buckets = Xt.get ~xt t.buckets in
252+
let old_buckets = r.buckets in
241253
(* Check state to ensure that buckets have not been updated. *)
242254
if Loc.fenceless_get state < 0 then Retry.invalid ();
243255
let new_capacity = Array.length old_buckets in
@@ -260,108 +272,122 @@ module Xt = struct
260272
while true do
261273
total_delta :=
262274
!total_delta + Xt.commit { tx = filter_map_a_few_buckets }
263-
done
275+
done;
276+
r
264277
with
265278
| Done ->
266-
Accumulator.Xt.add ~xt t.length !total_delta;
267-
Xt.set ~xt t.buckets new_buckets
268-
| exn -> Loc.compare_and_set raised Done exn |> ignore)
279+
Accumulator.Xt.add ~xt r.length !total_delta;
280+
let r = { r with pending = Nothing; buckets = new_buckets } in
281+
Xt.set ~xt t r;
282+
r
283+
| exn ->
284+
Loc.compare_and_set raised Done exn |> ignore;
285+
let r = { r with pending = Nothing } in
286+
Xt.set ~xt t r;
287+
r)
269288

270289
let make_rehash old_capacity new_capacity =
271290
let state = Loc.make old_capacity and new_buckets = Loc.make [||] in
272291
Rehash { state; new_capacity; new_buckets }
292+
[@@inline]
273293

274294
let reset ~xt t =
275-
perform_pending ~xt t;
276-
Xt.set ~xt t.buckets [||];
277-
Accumulator.Xt.set ~xt t.length 0;
278-
Xt.set ~xt t.pending @@ make_rehash 0 t.min_buckets
295+
let r = perform_pending ~xt t in
296+
Accumulator.Xt.set ~xt r.length 0;
297+
Xt.set ~xt t
298+
{ r with pending = make_rehash 0 r.min_buckets; buckets = [||] }
279299

280300
let clear ~xt t = reset ~xt t
281301

282302
let remove ~xt t k =
283-
perform_pending ~xt t;
284-
let buckets = Xt.get ~xt t.buckets in
303+
let r = perform_pending ~xt t in
304+
let buckets = r.buckets in
285305
let mask = Array.length buckets - 1 in
286-
let bucket = Array.unsafe_get buckets (t.hash k land mask) in
306+
let bucket = Array.unsafe_get buckets (r.hash k land mask) in
287307
let change = ref `None in
288308
Xt.unsafe_modify ~xt bucket (fun kvs ->
289-
let kvs' = Assoc.remove t.equal change k kvs in
309+
let kvs' = Assoc.remove r.equal change k kvs in
290310
if !change != `None then kvs' else kvs);
291311
if !change == `Removed then (
292-
Accumulator.Xt.decr ~xt t.length;
293-
if t.min_buckets <= mask && Random.bits () land mask = 0 then
312+
Accumulator.Xt.decr ~xt r.length;
313+
if r.min_buckets <= mask && Random.bits () land mask = 0 then
294314
let capacity = mask + 1 in
295-
let length = Accumulator.Xt.get ~xt t.length in
315+
let length = Accumulator.Xt.get ~xt r.length in
296316
if length * 4 < capacity then
297-
Xt.set ~xt t.pending @@ make_rehash capacity (capacity asr 1))
317+
Xt.set ~xt t
318+
{ r with pending = make_rehash capacity (capacity asr 1) })
298319

299320
let add ~xt t k v =
300-
perform_pending ~xt t;
301-
let buckets = Xt.get ~xt t.buckets in
321+
let r = perform_pending ~xt t in
322+
let buckets = r.buckets in
302323
let mask = Array.length buckets - 1 in
303-
let bucket = Array.unsafe_get buckets (t.hash k land mask) in
324+
let bucket = Array.unsafe_get buckets (r.hash k land mask) in
304325
Xt.unsafe_modify ~xt bucket (List.cons (k, v));
305-
Accumulator.Xt.incr ~xt t.length;
306-
if mask + 1 < t.max_buckets && Random.bits () land mask = 0 then
326+
Accumulator.Xt.incr ~xt r.length;
327+
if mask + 1 < r.max_buckets && Random.bits () land mask = 0 then
307328
let capacity = mask + 1 in
308-
let length = Accumulator.Xt.get ~xt t.length in
329+
let length = Accumulator.Xt.get ~xt r.length in
309330
if capacity < length then
310-
Xt.set ~xt t.pending @@ make_rehash capacity (capacity * 2)
331+
Xt.set ~xt t { r with pending = make_rehash capacity (capacity * 2) }
311332

312333
let replace ~xt t k v =
313-
perform_pending ~xt t;
314-
let buckets = Xt.get ~xt t.buckets in
334+
let r = perform_pending ~xt t in
335+
let buckets = r.buckets in
315336
let mask = Array.length buckets - 1 in
316-
let bucket = Array.unsafe_get buckets (t.hash k land mask) in
337+
let bucket = Array.unsafe_get buckets (r.hash k land mask) in
317338
let change = ref `None in
318339
Xt.unsafe_modify ~xt bucket (fun kvs ->
319-
let kvs' = Assoc.replace t.equal change k v kvs in
340+
let kvs' = Assoc.replace r.equal change k v kvs in
320341
if !change != `None then kvs' else kvs);
321342
if !change == `Added then (
322-
Accumulator.Xt.incr ~xt t.length;
323-
if mask + 1 < t.max_buckets && Random.bits () land mask = 0 then
343+
Accumulator.Xt.incr ~xt r.length;
344+
if mask + 1 < r.max_buckets && Random.bits () land mask = 0 then
324345
let capacity = mask + 1 in
325-
let length = Accumulator.Xt.get ~xt t.length in
346+
let length = Accumulator.Xt.get ~xt r.length in
326347
if capacity < length then
327-
Xt.set ~xt t.pending @@ make_rehash capacity (capacity * 2))
348+
Xt.set ~xt t { r with pending = make_rehash capacity (capacity * 2) })
328349

329-
let length ~xt t = Accumulator.Xt.get ~xt t.length
350+
let length ~xt t = Accumulator.Xt.get ~xt (Xt.get ~xt t).length
351+
let swap = Xt.swap
330352
end
331353

332354
let find_opt t k =
333-
Loc.get t.buckets |> bucket_of t.hash k |> Loc.fenceless_get
355+
let t = Loc.get t in
356+
t.buckets |> bucket_of t.hash k |> Loc.fenceless_get
334357
|> Assoc.find_opt t.equal k
335358

336359
let find_all t k =
337-
Loc.get t.buckets |> bucket_of t.hash k |> Loc.fenceless_get
360+
let t = Loc.get t in
361+
t.buckets |> bucket_of t.hash k |> Loc.fenceless_get
338362
|> Assoc.find_all t.equal k
339363

340364
let find t k = match find_opt t k with None -> raise Not_found | Some v -> v
341365

342366
let mem t k =
343-
Loc.get t.buckets |> bucket_of t.hash k |> Loc.fenceless_get
344-
|> Assoc.mem t.equal k
367+
let t = Loc.get t in
368+
t.buckets |> bucket_of t.hash k |> Loc.fenceless_get |> Assoc.mem t.equal k
345369

346370
let clear t = Kcas.Xt.commit { tx = Xt.clear t }
347371
let reset t = Kcas.Xt.commit { tx = Xt.reset t }
348372
let remove t k = Kcas.Xt.commit { tx = Xt.remove t k }
349373
let add t k v = Kcas.Xt.commit { tx = Xt.add t k v }
350374
let replace t k v = Kcas.Xt.commit { tx = Xt.replace t k v }
351-
let length t = Accumulator.get t.length
375+
let length t = Accumulator.get (Loc.get t).length
376+
let swap t1 t2 = Kcas.Xt.commit { tx = Xt.swap t1 t2 }
352377

353-
let snapshot ?length t =
378+
let snapshot ?length ?record t =
354379
let state = Loc.make 0 and snapshot = Loc.make [||] in
355380
let pending = Snapshot { state; snapshot } in
356381
let tx ~xt =
357-
Xt.perform_pending ~xt t;
382+
let r = Xt.perform_pending ~xt t in
358383
length
359-
|> Option.iter (fun length -> length := Accumulator.Xt.get ~xt t.length);
360-
Loc.set state (Array.length (Kcas.Xt.get ~xt t.buckets));
361-
Kcas.Xt.set ~xt t.pending pending
384+
|> Option.iter (fun length -> length := Accumulator.Xt.get ~xt r.length);
385+
record |> Option.iter (fun record -> record := r);
386+
Loc.set state (Array.length r.buckets);
387+
Kcas.Xt.set ~xt t { r with pending }
362388
in
363389
Kcas.Xt.commit { tx };
364-
Kcas.Xt.commit { tx = Xt.perform_pending t };
390+
Kcas.Xt.commit { tx = Xt.perform_pending t } |> ignore;
365391
Loc.fenceless_get snapshot
366392

367393
let to_seq t =
@@ -384,29 +410,33 @@ let of_seq ?hashed_type ?min_buckets ?max_buckets ?n_way xs =
384410
t
385411

386412
let rebuild ?hashed_type ?min_buckets ?max_buckets ?n_way t =
413+
let record = ref (Obj.magic ()) and length = ref 0 in
414+
let snapshot = snapshot ~length ~record t in
415+
let r = !record in
387416
let min_buckets =
388417
match min_buckets with
389-
| None -> min_buckets_of t
418+
| None -> r.min_buckets
390419
| Some c -> Int.max lo_buckets c |> Int.min hi_buckets |> Bits.ceil_pow_2
391420
in
392421
let max_buckets =
393422
match max_buckets with
394-
| None -> Int.max min_buckets (max_buckets_of t)
423+
| None -> Int.max min_buckets r.max_buckets
395424
| Some c -> Int.max min_buckets c |> Int.min hi_buckets |> Bits.ceil_pow_2
396-
and n_way = match n_way with None -> n_way_of t | Some n -> n
397-
and length = ref 0 in
398-
let snapshot = snapshot ~length t in
425+
and n_way =
426+
match n_way with None -> Accumulator.n_way_of r.length | Some n -> n
427+
in
399428
let is_same_hashed_type =
400429
match hashed_type with
401430
| None -> true
402-
| Some hashed_type -> HashedType.is_same_as t.hash t.equal hashed_type
431+
| Some hashed_type -> HashedType.is_same_as r.hash r.equal hashed_type
403432
and length = !length in
404433
if is_same_hashed_type && min_buckets <= length && length <= max_buckets then (
405-
let pending = Loc.make Nothing
406-
and buckets = Loc.make [||]
434+
let t = Loc.make (Obj.magic ()) in
435+
let pending = Nothing
436+
and buckets = Array.map Loc.make snapshot
407437
and length = Accumulator.make ~n_way length in
408-
Loc.set buckets @@ Array.map Loc.make snapshot;
409-
{ t with pending; length; buckets; min_buckets; max_buckets })
438+
Loc.set t { r with pending; length; buckets; min_buckets; max_buckets };
439+
t)
410440
else
411441
let t = create ?hashed_type ~min_buckets ~max_buckets ~n_way () in
412442
snapshot
@@ -427,12 +457,12 @@ let filter_map_inplace fn t =
427457
and new_buckets = Loc.make [||] in
428458
let pending = Filter_map { state; fn; raised; new_buckets } in
429459
let tx ~xt =
430-
Xt.perform_pending ~xt t;
431-
Loc.set state (Array.length (Kcas.Xt.get ~xt t.buckets));
432-
Kcas.Xt.set ~xt t.pending pending
460+
let r = Xt.perform_pending ~xt t in
461+
Loc.set state (Array.length r.buckets);
462+
Kcas.Xt.set ~xt t { r with pending }
433463
in
434464
Kcas.Xt.commit { tx };
435-
Kcas.Xt.commit { tx = Xt.perform_pending t };
465+
Kcas.Xt.commit { tx = Xt.perform_pending t } |> ignore;
436466
match Loc.fenceless_get raised with Done -> () | exn -> raise exn
437467

438468
let stats t =

src/kcas_data/hashtbl_intf.ml

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,9 @@ module type Ops = sig
1515
val clear : ('x, ('k, 'v) t -> unit) fn
1616
(** [clear] is a synonym for {!reset}. *)
1717

18+
val swap : ('x, ('k, 'v) t -> ('k, 'v) t -> unit) fn
19+
(** [swap t1 t2] exchanges the entire state of the hash tables [t1] and [t2]: their bindings together with their configuration (hashed type and bucket limits). *)
20+
1821
val remove : ('x, ('k, 'v) t -> 'k -> unit) fn
1922
(** [remove t k] removes the most recent existing binding of key [k], if any,
2023
from the hash table [t] thereby revealing the earlier binding of [k], if

test/kcas_data/hashtbl_test.ml

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -55,10 +55,12 @@ let () =
5555
assert (
5656
Hashtbl.to_seq t |> List.of_seq = [ ("key", 3); ("key", 2); ("key", 1) ]);
5757
let u = Hashtbl.to_seq t |> Hashtbl.of_seq in
58-
assert (Hashtbl.find u "key" = 1);
59-
assert (Hashtbl.find t "key" = 3);
60-
Hashtbl.filter_map_inplace (fun _ v -> if v = 1 then None else Some (-v)) t;
61-
assert (Hashtbl.find_all t "key" = [ -3; -2 ]);
58+
Hashtbl.swap t u;
59+
assert (Hashtbl.find t "key" = 1);
60+
assert (Hashtbl.find u "key" = 3);
61+
Hashtbl.filter_map_inplace (fun _ v -> if v = 1 then None else Some (-v)) u;
62+
assert (Hashtbl.find_all u "key" = [ -3; -2 ]);
63+
Hashtbl.swap u t;
6264
assert (Hashtbl.length t = 2);
6365
(match
6466
Hashtbl.filter_map_inplace

0 commit comments

Comments
 (0)