Skip to content

Commit 5413604

Browse files
committed
Update some files
1 parent 603ca63 commit 5413604

File tree

3 files changed

+61
-6
lines changed

3 files changed

+61
-6
lines changed

dataset/mnist/mnist_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package mnist_test
22

33
import (
4+
"bytes"
5+
"compress/gzip"
46
"fmt"
7+
"os"
58
"testing"
69

710
"github.com/itsubaki/neu/dataset/mnist"
@@ -122,3 +125,49 @@ func TestMust(t *testing.T) {
122125
mnist.Must(nil, nil, fmt.Errorf("something went wrong"))
123126
t.Fail()
124127
}
128+
129+
func TestLoadImage(t *testing.T) {
130+
invalid := []byte{0x00, 0x08, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00}
131+
file := "invalid.gz"
132+
133+
buf := new(bytes.Buffer)
134+
w := gzip.NewWriter(buf)
135+
if _, err := w.Write(invalid); err != nil {
136+
t.Fatalf("write gzip data: %v", err)
137+
}
138+
w.Close()
139+
140+
if err := os.WriteFile(file, buf.Bytes(), 0644); err != nil {
141+
t.Fatalf("write invalid file: %v", err)
142+
}
143+
defer os.Remove(file)
144+
145+
if _, err := mnist.LoadImage(file); err != nil {
146+
return
147+
}
148+
149+
t.Fatal("unexpected")
150+
}
151+
152+
func TestLoadLabel(t *testing.T) {
153+
invalid := []byte{0x00, 0x08, 0x01, 0x00, 0x00, 0x00}
154+
file := "invalid.gz"
155+
156+
buf := new(bytes.Buffer)
157+
w := gzip.NewWriter(buf)
158+
if _, err := w.Write(invalid); err != nil {
159+
t.Fatalf("write gzip data: %v", err)
160+
}
161+
w.Close()
162+
163+
if err := os.WriteFile(file, buf.Bytes(), 0644); err != nil {
164+
t.Fatalf("write invalid file: %v", err)
165+
}
166+
defer os.Remove(file)
167+
168+
if _, err := mnist.LoadLabel(file); err != nil {
169+
return
170+
}
171+
172+
t.Fatal("unexpected")
173+
}

math/rand/crypto.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"fmt"
66
)
77

8+
var RandRead = rand.Read
9+
810
func Must[T any](a T, err error) T {
911
if err != nil {
1012
panic(err)
@@ -19,7 +21,7 @@ func MustRead() [32]byte {
1921

2022
func Read() ([32]byte, error) {
2123
var p [32]byte
22-
if _, err := rand.Read(p[:]); err != nil {
24+
if _, err := RandRead(p[:]); err != nil {
2325
return [32]byte{}, fmt.Errorf("read: %v", err)
2426
}
2527

math/rand/crypto_test.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,31 @@ package rand_test
22

33
import (
44
crand "crypto/rand"
5+
"errors"
56
"fmt"
67
randv2 "math/rand/v2"
7-
"strings"
88
"testing"
99

1010
"github.com/itsubaki/neu/math/rand"
1111
)
1212

13+
var ErrSomtingWentWrong = errors.New("something went wrong")
14+
1315
func ExampleRead() {
14-
reader := crand.Reader
1516
defer func() {
16-
crand.Reader = reader
17+
rand.RandRead = crand.Read
1718
}()
1819

19-
crand.Reader = strings.NewReader("io.Reader stream to be read\n")
20+
rand.RandRead = func(b []byte) (int, error) {
21+
return 0, ErrSomtingWentWrong
22+
}
23+
2024
if _, err := rand.Read(); err != nil {
2125
fmt.Println(err)
2226
}
2327

2428
// Output:
25-
// read: unexpected EOF
29+
// read: something went wrong
2630
}
2731

2832
func TestMustRead(t *testing.T) {

0 commit comments

Comments
 (0)