Skip to content

Commit 845a2be

Browse files
authored
Add support for custom hash functions via context type (#13)
* feat: context support for hash_set * chore: add tests * chore: feature update to readme * fix: add context example to readme * fix: self ref and test context rename * fix: improve max_load_percentage handling * chore: improve sizeOf tests * fix: example in readme
1 parent fc2aa43 commit 845a2be

File tree

3 files changed

+299
-7
lines changed

3 files changed

+299
-7
lines changed

README.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ This implementation gives credit and acknowledgement to the [Zig language](https
5959
* pop
6060
* Fully documented and robustly tested - in progress
6161
* Performance aware to minimize unnecessary allocs/iteration internally
62+
* Custom hash function support
6263
* "string" support - coming soon
6364
* Benchmarks - coming soon
6465
#
@@ -116,6 +117,37 @@ Output of `A | B` - the union of A and B (order is not guaranteed)
116117

117118
#
118119

120+
#### Custom Hash Function
121+
122+
To use a custom hash function, you can use the following types:
123+
124+
- `HashSetUnmanagedWithContext`
125+
- `HashSetManagedWithContext`
126+
127+
Example:
128+
129+
```zig
130+
const SimpleHasher = struct {
131+
const Self = @This();
132+
pub fn hash(_: Self, key: u32) u64 {
133+
return @as(u64, key) *% 0x517cc1b727220a95;
134+
}
135+
pub fn eql(_: Self, a: u32, b: u32) bool {
136+
return a == b;
137+
}
138+
};
139+
140+
const ctx = SimpleHasher{};
141+
var set = HashSetUnmanagedWithContext(u32, SimpleHasher, 75).initContext(ctx);
142+
defer set.deinit(testing.allocator);
143+
144+
_ = try set.add(testing.allocator, 123);
145+
try expect(set.contains(123));
146+
try expect(!set.contains(456));
147+
```
148+
149+
#
150+
119151
#### Installation of Module
120152

121153
To add this module, update your application's build.zig.zon file by adding the `.ziglang-set` dependency definition.

src/hash_set/managed.zig

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,39 @@ const std = @import("std");
2323
const mem = std.mem;
2424
const Allocator = mem.Allocator;
2525
const SetUnmanaged = @import("unmanaged.zig").HashSetUnmanaged;
26+
const SetUnmanagedWithContext = @import("unmanaged.zig").HashSetUnmanagedWithContext;
2627

2728
/// fn HashSetManaged(E) creates a set based on element type E.
2829
/// This implementation is backed by the std.AutoHashMap implementation
2930
/// where a Value is not needed and considered to be void and
3031
/// a Key is considered to be a Set element of type E.
3132
/// The Set comes complete with the common set operations expected
3233
/// in a comprehensive set-based data-structure.
34+
/// Note that max_load_percentage is passed as undefined, because the underlying
35+
/// std.AutoHashMap/std.StringHashMap defaults are used.
3336
pub fn HashSetManaged(comptime E: type) type {
37+
return HashSetManagedWithContext(E, void, undefined);
38+
}
39+
40+
/// HashSetManagedWithContext creates a set based on element type E with custom hashing behavior.
41+
/// This variant allows specifying:
42+
/// - A Context type that implements hash() and eql() functions for custom element hashing
43+
/// - A max_load_percentage (1-100) that controls hash table resizing
44+
/// If Context is void, then max_load_percentage is ignored.
45+
///
46+
/// The Context type must provide:
47+
/// fn hash(self: Context, key: E) u64
48+
/// fn eql(self: Context, a: E, b: E) bool
49+
pub fn HashSetManagedWithContext(comptime E: type, comptime Context: type, comptime max_load_percentage: u8) type {
3450
return struct {
3551
allocator: std.mem.Allocator,
3652

3753
map: Map,
54+
context: if (Context == void) void else Context = if (Context == void) {} else undefined,
55+
max_load_percentage: if (Context == void) void else u8 = if (Context == void) {} else max_load_percentage,
3856

3957
/// The type of the internal hash map
40-
pub const Map = SetUnmanaged(E); //selectMap(E);
41-
//pub const Map = std.AutoHashMap(E, void);
42-
/// The integer type used to store the size of the map, borrowed from map
58+
pub const Map = SetUnmanagedWithContext(E, Context, max_load_percentage);
4359
pub const Size = Map.Size;
4460
/// The iterator type returned by iterator(), key-only for sets
4561
pub const Iterator = Map.Iterator;
@@ -51,6 +67,17 @@ pub fn HashSetManaged(comptime E: type) type {
5167
return .{
5268
.allocator = allocator,
5369
.map = Map.init(),
70+
.context = if (Context == void) {} else undefined,
71+
.max_load_percentage = if (Context == void) {} else max_load_percentage,
72+
};
73+
}
74+
75+
pub fn initContext(allocator: std.mem.Allocator, context: Context) Self {
76+
return .{
77+
.allocator = allocator,
78+
.map = Map.initContext(context),
79+
.context = context,
80+
.max_load_percentage = max_load_percentage,
5481
};
5582
}
5683

@@ -835,9 +862,37 @@ test "in-place methods" {
835862
test "sizeOf" {
836863
const unmanagedSize = @sizeOf(SetUnmanaged(u32));
837864
const managedSize = @sizeOf(HashSetManaged(u32));
865+
const managedWithVoidContextSize = @sizeOf(HashSetManagedWithContext(u32, void, undefined));
866+
const managedWithContextSize = @sizeOf(HashSetManagedWithContext(u32, TestContext, 75));
838867

839868
// The managed must be only 16 bytes larger, the cost of the internal allocator
840869
// otherwise we've added some CRAP!
841870
const expectedDiff = 16;
842871
try expectEqual(expectedDiff, managedSize - unmanagedSize);
872+
873+
// The managed with void context must be the same size as the managed.
874+
// The managed with context must be larger than the plain managed set: TestContext has no
875+
// fields, so the growth comes from the u8 max_load_percentage stored in both the wrapper and its inner map, plus alignment padding.
876+
const expectedContextDiff = 16;
877+
try expectEqual(expectedDiff, managedWithVoidContextSize - unmanagedSize);
878+
try expectEqual(expectedContextDiff, managedWithContextSize - managedSize);
879+
}
880+
881+
const TestContext = struct {
882+
const Self = @This();
883+
pub fn hash(_: Self, key: u32) u64 {
884+
return @as(u64, key) *% 0x517cc1b727220a95;
885+
}
886+
pub fn eql(_: Self, a: u32, b: u32) bool {
887+
return a == b;
888+
}
889+
};
890+
891+
test "custom hash function" {
892+
const context = TestContext{};
893+
var set = HashSetManagedWithContext(u32, TestContext, 75).initContext(testing.allocator, context);
894+
defer set.deinit();
895+
896+
_ = try set.add(123);
897+
try expect(set.contains(123));
843898
}

src/hash_set/unmanaged.zig

Lines changed: 209 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
///
2121
const std = @import("std");
2222
const mem = std.mem;
23+
const math = std.math;
2324
const Allocator = mem.Allocator;
2425

2526
/// comptime selection of the map type for string vs everything else.
@@ -33,27 +34,60 @@ fn selectMap(comptime E: type) type {
3334
}
3435
}
3536

37+
/// Select a context-aware hash map type
38+
fn selectMapWithContext(comptime E: type, comptime Context: type, comptime max_load_percentage: u8) type {
39+
return std.HashMapUnmanaged(E, void, Context, max_load_percentage);
40+
}
41+
3642
/// HashSetUnmanaged is an implementation of a Set where there is no internal
3743
/// allocator and all allocating methods require a first argument allocator.
3844
/// This is a more compact Set built on top of the HashMapUnmanaged
3945
/// data structure.
46+
/// Note that max_load_percentage defaults to undefined, because the underlying
47+
/// std.AutoHashMap/std.StringHashMap defaults are used.
4048
pub fn HashSetUnmanaged(comptime E: type) type {
49+
return HashSetUnmanagedWithContext(E, void, undefined);
50+
}
51+
52+
/// HashSetUnmanagedWithContext creates a set based on element type E with custom hashing behavior.
53+
/// This variant allows specifying:
54+
/// - A Context type that implements hash() and eql() functions for custom element hashing
55+
/// - A max_load_percentage (1-100) that controls hash table resizing
56+
/// If Context is void, then max_load_percentage is ignored.
57+
///
58+
/// The Context type must provide:
59+
/// fn hash(self: Context, key: E) u64
60+
/// fn eql(self: Context, a: E, b: E) bool
61+
pub fn HashSetUnmanagedWithContext(comptime E: type, comptime Context: type, comptime max_load_percentage: u8) type {
4162
return struct {
4263
/// The type of the internal hash map
43-
pub const Map = selectMap(E);
64+
pub const Map = if (Context == void) selectMap(E) else selectMapWithContext(E, Context, max_load_percentage);
4465

4566
unmanaged: Map,
67+
context: if (Context == void) void else Context = if (Context == void) {} else undefined,
68+
max_load_percentage: if (Context == void) void else u8 = if (Context == void) {} else max_load_percentage,
4669

4770
pub const Size = Map.Size;
4871
/// The iterator type returned by iterator(), key-only for sets
4972
pub const Iterator = Map.KeyIterator;
5073

5174
const Self = @This();
5275

53-
/// Initialzies a Set with the given Allocator
76+
/// Initialize a default set without context
5477
pub fn init() Self {
5578
return .{
5679
.unmanaged = Map{},
80+
.context = if (Context == void) {} else undefined,
81+
.max_load_percentage = if (Context == void) {} else max_load_percentage,
82+
};
83+
}
84+
85+
/// Initialize with a custom context
86+
pub fn initContext(context: Context) Self {
87+
return .{
88+
.unmanaged = Map{},
89+
.context = context,
90+
.max_load_percentage = max_load_percentage,
5791
};
5892
}
5993

@@ -765,6 +799,177 @@ test "sizeOf matches" {
765799
// No bloat guarantee, after all we're just building on top of what's good.
766800
// "What's good Miley!?!?""
767801
const expectedByteSize = 24;
768-
try expectEqual(expectedByteSize, @sizeOf(std.hash_map.AutoHashMapUnmanaged(u32, void)));
769-
try expectEqual(expectedByteSize, @sizeOf(HashSetUnmanaged(u32)));
802+
const autoHashMapSize = @sizeOf(std.hash_map.AutoHashMapUnmanaged(u32, void));
803+
const hashSetSize = @sizeOf(HashSetUnmanaged(u32));
804+
try expectEqual(expectedByteSize, autoHashMapSize);
805+
try expectEqual(expectedByteSize, hashSetSize);
806+
807+
// The unmanaged with void context must be the same size as the unmanaged.
808+
// The unmanaged with context must be larger than the plain unmanaged set: the empty Context
809+
// struct adds no bytes, so the growth comes from the added u8 max_load_percentage field and its alignment padding.
810+
const expectedContextDiff = 8;
811+
const hashSetWithVoidContextSize = @sizeOf(HashSetUnmanagedWithContext(u32, void, undefined));
812+
const hashSetWithContextSize = @sizeOf(HashSetUnmanagedWithContext(u32, TestContext, 75));
813+
try expectEqual(0, hashSetWithVoidContextSize - hashSetSize);
814+
try expectEqual(expectedContextDiff, hashSetWithContextSize - hashSetSize);
815+
}
816+
817+
const TestContext = struct {
818+
const Self = @This();
819+
pub fn hash(_: Self, key: u32) u64 {
820+
return @as(u64, key) *% 0x517cc1b727220a95;
821+
}
822+
pub fn eql(_: Self, a: u32, b: u32) bool {
823+
return a == b;
824+
}
825+
};
826+
827+
test "custom hash function comprehensive" {
828+
const context = TestContext{};
829+
var set = HashSetUnmanagedWithContext(u32, TestContext, 75).initContext(context);
830+
defer set.deinit(testing.allocator);
831+
832+
// Test basic operations
833+
_ = try set.add(testing.allocator, 123);
834+
_ = try set.add(testing.allocator, 456);
835+
try expect(set.contains(123));
836+
try expect(set.contains(456));
837+
try expect(!set.contains(789));
838+
try expectEqual(set.cardinality(), 2);
839+
840+
// Test clone with custom context
841+
var cloned = try set.clone(testing.allocator);
842+
defer cloned.deinit(testing.allocator);
843+
try expect(cloned.contains(123));
844+
try expect(set.eql(cloned));
845+
846+
// Test set operations with custom context
847+
var other = HashSetUnmanagedWithContext(u32, TestContext, 75).initContext(context);
848+
defer other.deinit(testing.allocator);
849+
_ = try other.add(testing.allocator, 456);
850+
_ = try other.add(testing.allocator, 789);
851+
852+
// Test union
853+
var union_set = try set.unionOf(testing.allocator, other);
854+
defer union_set.deinit(testing.allocator);
855+
try expectEqual(union_set.cardinality(), 3);
856+
try expect(union_set.containsAllSlice(&.{ 123, 456, 789 }));
857+
858+
// Test intersection
859+
var intersection = try set.intersectionOf(testing.allocator, other);
860+
defer intersection.deinit(testing.allocator);
861+
try expectEqual(intersection.cardinality(), 1);
862+
try expect(intersection.contains(456));
863+
864+
// Test difference
865+
var difference = try set.differenceOf(testing.allocator, other);
866+
defer difference.deinit(testing.allocator);
867+
try expectEqual(difference.cardinality(), 1);
868+
try expect(difference.contains(123));
869+
870+
// Test symmetric difference
871+
var sym_diff = try set.symmetricDifferenceOf(testing.allocator, other);
872+
defer sym_diff.deinit(testing.allocator);
873+
try expectEqual(sym_diff.cardinality(), 2);
874+
try expect(sym_diff.containsAllSlice(&.{ 123, 789 }));
875+
876+
// Test in-place operations
877+
try set.unionUpdate(testing.allocator, other);
878+
try expectEqual(set.cardinality(), 3);
879+
try expect(set.containsAllSlice(&.{ 123, 456, 789 }));
880+
}
881+
882+
test "custom hash function with different load factors" {
883+
const context = TestContext{};
884+
885+
// Test with low load factor
886+
var low_load = HashSetUnmanagedWithContext(u32, TestContext, 25).initContext(context);
887+
defer low_load.deinit(testing.allocator);
888+
889+
// Test with high load factor
890+
var high_load = HashSetUnmanagedWithContext(u32, TestContext, 90).initContext(context);
891+
defer high_load.deinit(testing.allocator);
892+
893+
// Add same elements to both
894+
for (0..100) |i| {
895+
_ = try low_load.add(testing.allocator, @intCast(i));
896+
_ = try high_load.add(testing.allocator, @intCast(i));
897+
}
898+
899+
// Verify functionality is identical despite different load factors
900+
try expectEqual(low_load.cardinality(), high_load.cardinality());
901+
try expect(low_load.capacity() != high_load.capacity()); // Should be different due to load factors
902+
903+
// Verify both sets contain the same elements
904+
for (0..100) |i| {
905+
const val: u32 = @intCast(i);
906+
try expect(low_load.contains(val) and high_load.contains(val));
907+
}
908+
}
909+
910+
test "custom hash function error cases" {
911+
const context = TestContext{};
912+
var set = HashSetUnmanagedWithContext(u32, TestContext, 75).initContext(context);
913+
defer set.deinit(testing.allocator);
914+
915+
// Test allocation failures
916+
var failing_allocator = std.testing.FailingAllocator.init(testing.allocator, .{ .fail_index = 0 });
917+
try std.testing.expectError(error.OutOfMemory, set.add(failing_allocator.allocator(), 123));
918+
}
919+
920+
// String context for testing string usage with custom hash function
921+
const StringContext = struct {
922+
pub fn hash(self: @This(), str: []const u8) u64 {
923+
_ = self;
924+
// Simple FNV-1a hash
925+
var h: u64 = 0xcbf29ce484222325;
926+
for (str) |b| {
927+
h = (h ^ b) *% 0x100000001b3;
928+
}
929+
return h;
930+
}
931+
932+
pub fn eql(self: @This(), a: []const u8, b: []const u8) bool {
933+
_ = self;
934+
return std.mem.eql(u8, a, b);
935+
}
936+
};
937+
938+
test "custom hash function string usage" {
939+
const context = StringContext{};
940+
var A = HashSetUnmanagedWithContext([]const u8, StringContext, 75).initContext(context);
941+
defer A.deinit(testing.allocator);
942+
943+
var B = HashSetUnmanagedWithContext([]const u8, StringContext, 75).initContext(context);
944+
defer B.deinit(testing.allocator);
945+
946+
_ = try A.add(testing.allocator, "Hello");
947+
_ = try B.add(testing.allocator, "World");
948+
949+
var C = try A.unionOf(testing.allocator, B);
950+
defer C.deinit(testing.allocator);
951+
try expectEqual(2, C.cardinality());
952+
try expect(C.containsAllSlice(&.{ "Hello", "World" }));
953+
954+
// Test string-specific behavior
955+
try expect(A.contains("Hello"));
956+
try expect(!A.contains("hello")); // Case sensitive
957+
try expect(!A.contains("Hell")); // Prefix doesn't match
958+
try expect(!A.contains("Hello ")); // Trailing space matters
959+
960+
// Test with longer strings
961+
_ = try A.add(testing.allocator, "This is a longer string to test hash collisions");
962+
_ = try A.add(testing.allocator, "This is another longer string to test hash collisions");
963+
try expectEqual(3, A.cardinality());
964+
965+
// Test with empty string
966+
_ = try A.add(testing.allocator, "");
967+
try expect(A.contains(""));
968+
try expectEqual(4, A.cardinality());
969+
970+
// Test with strings containing special characters
971+
_ = try A.add(testing.allocator, "Hello\n");
972+
_ = try A.add(testing.allocator, "Hello\r");
973+
_ = try A.add(testing.allocator, "Hello\t");
974+
try expectEqual(7, A.cardinality());
770975
}

0 commit comments

Comments
 (0)