Skip to content

Commit 7266341

Browse files
committed
rework SupportsSHAKE
1 parent 9856e1e commit 7266341

File tree

2 files changed

+42
-29
lines changed

2 files changed

+42
-29
lines changed

cng/sha3.go

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,26 @@ func SumSHAKE256(data []byte, length int) []byte {
5858
return out
5959
}
6060

61-
// SupportsSHAKE128 returns true if the SHAKE128 extendable output function is
62-
// supported.
63-
func SupportsSHAKE128() bool {
64-
_, err := loadHash(bcrypt.CSHAKE128_ALGORITHM, bcrypt.ALG_NONE_FLAG)
65-
return err == nil
66-
}
67-
68-
// SupportsSHAKE256 returns true if the SHAKE256 extendable output function is
69-
// supported.
70-
func SupportsSHAKE256() bool {
71-
_, err := loadHash(bcrypt.CSHAKE256_ALGORITHM, bcrypt.ALG_NONE_FLAG)
61+
// SupportsSHAKE returns true if the SHAKE extendable output function with the
62+
// given securityBits is supported.
63+
func SupportsSHAKE(securityBits int) bool {
64+
// CNG implements SHAKE using CSHAKE with empty N and S.
65+
return SupportsCSHAKE(securityBits)
66+
}
67+
68+
// SupportsCSHAKE returns true if the CSHAKE extendable output function with the
69+
// given securityBits is supported.
70+
func SupportsCSHAKE(securityBits int) bool {
71+
var id string
72+
switch securityBits {
73+
case 128:
74+
id = bcrypt.CSHAKE128_ALGORITHM
75+
case 256:
76+
id = bcrypt.CSHAKE256_ALGORITHM
77+
default:
78+
return false
79+
}
80+
_, err := loadHash(id, bcrypt.ALG_NONE_FLAG)
7281
return err == nil
7382
}
7483

cng/sha3_test.go

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,19 @@ var testShakes = map[string]struct {
3232
}
3333

3434
func skipCSHAKEIfNotSupported(t *testing.T, algo string) {
35+
var supported bool
3536
switch algo {
36-
case "SHAKE128", "CSHAKE128":
37-
if !cng.SupportsSHAKE128() {
38-
t.Skip("skipping: not supported")
39-
}
40-
case "SHAKE256", "CSHAKE256":
41-
if !cng.SupportsSHAKE256() {
42-
t.Skip("skipping: not supported")
43-
}
37+
case "SHAKE128":
38+
supported = cng.SupportsSHAKE(128)
39+
case "SHAKE256":
40+
supported = cng.SupportsSHAKE(256)
41+
case "CSHAKE128":
42+
supported = cng.SupportsCSHAKE(128)
43+
case "CSHAKE256":
44+
supported = cng.SupportsCSHAKE(256)
45+
}
46+
if !supported {
47+
t.Skip("skipping: not supported")
4448
}
4549
}
4650

@@ -109,14 +113,14 @@ func TestCSHAKEReset(t *testing.T) {
109113

110114
func TestCSHAKEAccumulated(t *testing.T) {
111115
t.Run("CSHAKE128", func(t *testing.T) {
112-
if !cng.SupportsSHAKE128() {
116+
if !cng.SupportsSHAKE(128) {
113117
t.Skip("skipping: not supported")
114118
}
115119
testCSHAKEAccumulated(t, cng.NewCSHAKE128, (1600-256)/8,
116120
"bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252")
117121
})
118122
t.Run("CSHAKE256", func(t *testing.T) {
119-
if !cng.SupportsSHAKE256() {
123+
if !cng.SupportsSHAKE(256) {
120124
t.Skip("skipping: not supported")
121125
}
122126
testCSHAKEAccumulated(t, cng.NewCSHAKE256, (1600-512)/8,
@@ -155,7 +159,7 @@ func testCSHAKEAccumulated(t *testing.T, newCSHAKE func(N, S []byte) *cng.SHAKE,
155159
}
156160

157161
func TestCSHAKELargeS(t *testing.T) {
158-
if !cng.SupportsSHAKE128() {
162+
if !cng.SupportsSHAKE(128) {
159163
t.Skip("skipping: not supported")
160164
}
161165
const s = (1<<32)/8 + 1000 // s * 8 > 2^32
@@ -173,13 +177,13 @@ func TestCSHAKELargeS(t *testing.T) {
173177
}
174178
}
175179

176-
func TestCSHAKESum(t *testing.T) {
180+
func TestSHAKESum(t *testing.T) {
177181
const testString = "hello world"
178-
t.Run("CSHAKE128", func(t *testing.T) {
179-
if !cng.SupportsSHAKE128() {
182+
t.Run("SHAKE128", func(t *testing.T) {
183+
if !cng.SupportsSHAKE(128) {
180184
t.Skip("skipping: not supported")
181185
}
182-
h := cng.NewCSHAKE128(nil, nil)
186+
h := cng.NewSHAKE128()
183187
h.Write([]byte(testString[:5]))
184188
h.Write([]byte(testString[5:]))
185189
want := make([]byte, 32)
@@ -189,11 +193,11 @@ func TestCSHAKESum(t *testing.T) {
189193
t.Errorf("got:%x want:%x", got, want)
190194
}
191195
})
192-
t.Run("CSHAKE256", func(t *testing.T) {
193-
if !cng.SupportsSHAKE256() {
196+
t.Run("SHAKE256", func(t *testing.T) {
197+
if !cng.SupportsSHAKE(256) {
194198
t.Skip("skipping: not supported")
195199
}
196-
h := cng.NewCSHAKE256(nil, nil)
200+
h := cng.NewSHAKE256()
197201
h.Write([]byte(testString[:5]))
198202
h.Write([]byte(testString[5:]))
199203
want := make([]byte, 32)

0 commit comments

Comments
 (0)