Skip to content

Commit f2f3c8a

Browse files
committed
internal/slices: Add 'CollectWithError'.
1 parent db22bee commit f2f3c8a

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

internal/slices/slices.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package slices
55

66
import (
7+
"iter"
78
"slices"
89
)
910

@@ -167,3 +168,23 @@ func Strings[S ~[]E, E stringable](s S) []string {
167168
return string(e)
168169
})
169170
}
171+
172+
// CollectWithError collects values from seq into a new slice and returns it.
173+
// The first non-nil error in seq is returned.
174+
// If seq is empty, the result is nil.
175+
func CollectWithError[E any](seq iter.Seq2[E, error]) ([]E, error) {
176+
return AppendSeqWithError([]E(nil), seq)
177+
}
178+
179+
// AppendSeqWithError appends the values from seq to the slice and returns the extended slice.
180+
// The first non-nil error in seq is returned.
181+
// If seq is empty, the result preserves the nilness of s.
182+
func AppendSeqWithError[S ~[]E, E any](s S, seq iter.Seq2[E, error]) (S, error) {
183+
for v, err := range seq {
184+
if err != nil {
185+
return nil, err
186+
}
187+
s = append(s, v)
188+
}
189+
return s, nil
190+
}

internal/slices/slices_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
package slices
55

66
import (
7+
"errors"
8+
"maps"
79
"strings"
810
"testing"
911

@@ -365,3 +367,46 @@ func TestRange(t *testing.T) {
365367
})
366368
}
367369
}
370+
371+
func TestCollectWithError(t *testing.T) {
372+
t.Parallel()
373+
374+
type testCase struct {
375+
input map[int]error
376+
wantErr bool
377+
}
378+
tests := map[string]testCase{
379+
"no error": {
380+
input: map[int]error{
381+
1: nil,
382+
2: nil,
383+
3: nil,
384+
},
385+
},
386+
"has error": {
387+
input: map[int]error{
388+
1: nil,
389+
2: errors.New("test error"),
390+
3: nil,
391+
},
392+
wantErr: true,
393+
},
394+
}
395+
396+
for name, test := range tests {
397+
t.Run(name, func(t *testing.T) {
398+
t.Parallel()
399+
400+
got, err := CollectWithError(maps.All(test.input))
401+
402+
if got, want := err != nil, test.wantErr; !cmp.Equal(got, want) {
403+
t.Errorf("CollectWithError() err %t, want %t", got, want)
404+
}
405+
if err == nil {
406+
if got, want := len(got), len(test.input); !cmp.Equal(got, want) {
407+
t.Errorf("CollectWithError() len %d, want %d", got, want)
408+
}
409+
}
410+
})
411+
}
412+
}

0 commit comments

Comments
 (0)