Skip to content

Commit 3045cfb

Browse files
flrdv authored and deckarep committed
less pointers and more memory-efficient Clear()
`threadUnsafeSet` is already a map, which is a reference type, so passing it by value still lets us mutate the original object while reducing the number of pointers. That in turn reduces GC pressure, (potentially) avoids some useless operations, and (definitely) makes the code easier to read and understand. The other point of this change is to slightly "upgrade" the `Clear()` method. The old implementation simply constructed a new `threadUnsafeSet` and mutated the original pointer to point at the new set. This incurs an extra allocation (I would expect a `Clear()` method to be allocation-free) and resets the already-allocated map's size to the default one. So I replaced it with clearing the map in place using a `for key := range d { delete(d, key) }` loop, which the compiler implicitly replaces with the `mapclear()` function. This is actually a fairly expensive operation, but it performs no allocations, which may well make it cheaper than the alternative. In any case, a user can always construct a new instance of the set themselves, but cannot manually clear the underlying map when they really need to - for example, to reduce the average amount of memory used, which is affected significantly by allocating a new map internally.
1 parent 543b3d7 commit 3045cfb

File tree

1 file changed

+73
-67
lines changed

1 file changed

+73
-67
lines changed

threadunsafe.go

Lines changed: 73 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -35,138 +35,143 @@ import (
3535
type threadUnsafeSet[T comparable] map[T]struct{}
3636

3737
// Assert that the concrete type threadUnsafeSet adheres to the Set interface.
38-
var _ Set[string] = (*threadUnsafeSet[string])(nil)
38+
var _ Set[string] = (threadUnsafeSet[string])(nil)
3939

4040
func newThreadUnsafeSet[T comparable]() threadUnsafeSet[T] {
4141
return make(threadUnsafeSet[T])
4242
}
4343

44-
func (s *threadUnsafeSet[T]) Add(v T) bool {
45-
prevLen := len(*s)
46-
(*s)[v] = struct{}{}
47-
return prevLen != len(*s)
44+
func (s threadUnsafeSet[T]) Add(v T) bool {
45+
prevLen := len(s)
46+
s[v] = struct{}{}
47+
return prevLen != len(s)
4848
}
4949

5050
// private version of Add which doesn't return a value
51-
func (s *threadUnsafeSet[T]) add(v T) {
52-
(*s)[v] = struct{}{}
51+
func (s threadUnsafeSet[T]) add(v T) {
52+
s[v] = struct{}{}
5353
}
5454

55-
func (s *threadUnsafeSet[T]) Cardinality() int {
56-
return len(*s)
55+
func (s threadUnsafeSet[T]) Cardinality() int {
56+
return len(s)
5757
}
5858

59-
func (s *threadUnsafeSet[T]) Clear() {
60-
*s = newThreadUnsafeSet[T]()
59+
func (s threadUnsafeSet[T]) Clear() {
60+
// Loops of this form are recognised by the compiler and replaced with a
61+
// call to the runtime's mapclear() function, defined in
62+
// https://github.com/golang/go/blob/29bbca5c2c1ad41b2a9747890d183b6dd3a4ace4/src/runtime/map.go#L993
63+
for key := range s {
64+
delete(s, key)
65+
}
6166
}
6267

63-
func (s *threadUnsafeSet[T]) Clone() Set[T] {
68+
func (s threadUnsafeSet[T]) Clone() Set[T] {
6469
clonedSet := make(threadUnsafeSet[T], s.Cardinality())
65-
for elem := range *s {
70+
for elem := range s {
6671
clonedSet.add(elem)
6772
}
68-
return &clonedSet
73+
return clonedSet
6974
}
7075

71-
func (s *threadUnsafeSet[T]) Contains(v ...T) bool {
76+
func (s threadUnsafeSet[T]) Contains(v ...T) bool {
7277
for _, val := range v {
73-
if _, ok := (*s)[val]; !ok {
78+
if _, ok := s[val]; !ok {
7479
return false
7580
}
7681
}
7782
return true
7883
}
7984

8085
// private version of Contains for a single element v
81-
func (s *threadUnsafeSet[T]) contains(v T) bool {
82-
_, ok := (*s)[v]
86+
func (s threadUnsafeSet[T]) contains(v T) (ok bool) {
87+
_, ok = s[v]
8388
return ok
8489
}
8590

86-
func (s *threadUnsafeSet[T]) Difference(other Set[T]) Set[T] {
87-
o := other.(*threadUnsafeSet[T])
91+
func (s threadUnsafeSet[T]) Difference(other Set[T]) Set[T] {
92+
o := other.(threadUnsafeSet[T])
8893

8994
diff := newThreadUnsafeSet[T]()
90-
for elem := range *s {
95+
for elem := range s {
9196
if !o.contains(elem) {
9297
diff.add(elem)
9398
}
9499
}
95-
return &diff
100+
return diff
96101
}
97102

98-
func (s *threadUnsafeSet[T]) Each(cb func(T) bool) {
99-
for elem := range *s {
103+
func (s threadUnsafeSet[T]) Each(cb func(T) bool) {
104+
for elem := range s {
100105
if cb(elem) {
101106
break
102107
}
103108
}
104109
}
105110

106-
func (s *threadUnsafeSet[T]) Equal(other Set[T]) bool {
107-
o := other.(*threadUnsafeSet[T])
111+
func (s threadUnsafeSet[T]) Equal(other Set[T]) bool {
112+
o := other.(threadUnsafeSet[T])
108113

109114
if s.Cardinality() != other.Cardinality() {
110115
return false
111116
}
112-
for elem := range *s {
117+
for elem := range s {
113118
if !o.contains(elem) {
114119
return false
115120
}
116121
}
117122
return true
118123
}
119124

120-
func (s *threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] {
121-
o := other.(*threadUnsafeSet[T])
125+
func (s threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] {
126+
o := other.(threadUnsafeSet[T])
122127

123128
intersection := newThreadUnsafeSet[T]()
124129
// loop over smaller set
125130
if s.Cardinality() < other.Cardinality() {
126-
for elem := range *s {
131+
for elem := range s {
127132
if o.contains(elem) {
128133
intersection.add(elem)
129134
}
130135
}
131136
} else {
132-
for elem := range *o {
137+
for elem := range o {
133138
if s.contains(elem) {
134139
intersection.add(elem)
135140
}
136141
}
137142
}
138-
return &intersection
143+
return intersection
139144
}
140145

141-
func (s *threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool {
146+
func (s threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool {
142147
return s.Cardinality() < other.Cardinality() && s.IsSubset(other)
143148
}
144149

145-
func (s *threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool {
150+
func (s threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool {
146151
return s.Cardinality() > other.Cardinality() && s.IsSuperset(other)
147152
}
148153

149-
func (s *threadUnsafeSet[T]) IsSubset(other Set[T]) bool {
150-
o := other.(*threadUnsafeSet[T])
154+
func (s threadUnsafeSet[T]) IsSubset(other Set[T]) bool {
155+
o := other.(threadUnsafeSet[T])
151156
if s.Cardinality() > other.Cardinality() {
152157
return false
153158
}
154-
for elem := range *s {
159+
for elem := range s {
155160
if !o.contains(elem) {
156161
return false
157162
}
158163
}
159164
return true
160165
}
161166

162-
func (s *threadUnsafeSet[T]) IsSuperset(other Set[T]) bool {
167+
func (s threadUnsafeSet[T]) IsSuperset(other Set[T]) bool {
163168
return other.IsSubset(s)
164169
}
165170

166-
func (s *threadUnsafeSet[T]) Iter() <-chan T {
171+
func (s threadUnsafeSet[T]) Iter() <-chan T {
167172
ch := make(chan T)
168173
go func() {
169-
for elem := range *s {
174+
for elem := range s {
170175
ch <- elem
171176
}
172177
close(ch)
@@ -175,12 +180,12 @@ func (s *threadUnsafeSet[T]) Iter() <-chan T {
175180
return ch
176181
}
177182

178-
func (s *threadUnsafeSet[T]) Iterator() *Iterator[T] {
183+
func (s threadUnsafeSet[T]) Iterator() *Iterator[T] {
179184
iterator, ch, stopCh := newIterator[T]()
180185

181186
go func() {
182187
L:
183-
for elem := range *s {
188+
for elem := range s {
184189
select {
185190
case <-stopCh:
186191
break L
@@ -193,77 +198,78 @@ func (s *threadUnsafeSet[T]) Iterator() *Iterator[T] {
193198
return iterator
194199
}
195200

196-
// TODO: how can we make this properly , return T but can't return nil.
197-
func (s *threadUnsafeSet[T]) Pop() (v T, ok bool) {
198-
for item := range *s {
199-
delete(*s, item)
201+
// Pop removes and returns an item from the set if it is not empty, or the
202+
// zero value of T (with ok == false) if the set is already empty.
203+
func (s threadUnsafeSet[T]) Pop() (v T, ok bool) {
204+
for item := range s {
205+
delete(s, item)
200206
return item, true
201207
}
202-
return
208+
return v, false
203209
}
204210

205-
func (s *threadUnsafeSet[T]) Remove(v T) {
206-
delete(*s, v)
211+
func (s threadUnsafeSet[T]) Remove(v T) {
212+
delete(s, v)
207213
}
208214

209-
func (s *threadUnsafeSet[T]) String() string {
210-
items := make([]string, 0, len(*s))
215+
func (s threadUnsafeSet[T]) String() string {
216+
items := make([]string, 0, len(s))
211217

212-
for elem := range *s {
218+
for elem := range s {
213219
items = append(items, fmt.Sprintf("%v", elem))
214220
}
215221
return fmt.Sprintf("Set{%s}", strings.Join(items, ", "))
216222
}
217223

218-
func (s *threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] {
219-
o := other.(*threadUnsafeSet[T])
224+
func (s threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] {
225+
o := other.(threadUnsafeSet[T])
220226

221227
sd := newThreadUnsafeSet[T]()
222-
for elem := range *s {
228+
for elem := range s {
223229
if !o.contains(elem) {
224230
sd.add(elem)
225231
}
226232
}
227-
for elem := range *o {
233+
for elem := range o {
228234
if !s.contains(elem) {
229235
sd.add(elem)
230236
}
231237
}
232-
return &sd
238+
return sd
233239
}
234240

235-
func (s *threadUnsafeSet[T]) ToSlice() []T {
241+
func (s threadUnsafeSet[T]) ToSlice() []T {
236242
keys := make([]T, 0, s.Cardinality())
237-
for elem := range *s {
243+
for elem := range s {
238244
keys = append(keys, elem)
239245
}
240246

241247
return keys
242248
}
243249

244-
func (s *threadUnsafeSet[T]) Union(other Set[T]) Set[T] {
245-
o := other.(*threadUnsafeSet[T])
250+
func (s threadUnsafeSet[T]) Union(other Set[T]) Set[T] {
251+
o := other.(threadUnsafeSet[T])
246252

247253
n := s.Cardinality()
248254
if o.Cardinality() > n {
249255
n = o.Cardinality()
250256
}
251257
unionedSet := make(threadUnsafeSet[T], n)
252258

253-
for elem := range *s {
259+
for elem := range s {
254260
unionedSet.add(elem)
255261
}
256-
for elem := range *o {
262+
for elem := range o {
257263
unionedSet.add(elem)
258264
}
259-
return &unionedSet
265+
return unionedSet
260266
}
261267

262268
// MarshalJSON creates a JSON array from the set, it marshals all elements
263-
func (s *threadUnsafeSet[T]) MarshalJSON() ([]byte, error) {
269+
func (s threadUnsafeSet[T]) MarshalJSON() ([]byte, error) {
264270
items := make([]string, 0, s.Cardinality())
265271

266-
for elem := range *s {
272+
for elem := range s {
267273
b, err := json.Marshal(elem)
268274
if err != nil {
269275
return nil, err
@@ -277,7 +283,7 @@ func (s *threadUnsafeSet[T]) MarshalJSON() ([]byte, error) {
277283

278284
// UnmarshalJSON recreates a set from a JSON array, it only decodes
279285
// primitive types. Numbers are decoded as json.Number.
280-
func (s *threadUnsafeSet[T]) UnmarshalJSON(b []byte) error {
286+
func (s threadUnsafeSet[T]) UnmarshalJSON(b []byte) error {
281287
var i []any
282288

283289
d := json.NewDecoder(bytes.NewReader(b))

0 commit comments

Comments
 (0)