Skip to content

Commit 0e4a51c

Browse files
authored
Merge pull request #82 from microsoft/shakesup
Deduplicate SupportsSHAKE
2 parents b49854c + 0fc0aaf commit 0e4a51c

File tree

2 files changed

+29
-27
lines changed

2 files changed

+29
-27
lines changed

cng/sha3.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,19 @@ func SumSHAKE256(data []byte, length int) []byte {
5858
return out
5959
}
6060

61-
// SupportsSHAKE128 returns true if the SHAKE128 extendable output function is
62-
// supported.
63-
func SupportsSHAKE128() bool {
64-
_, err := loadHash(bcrypt.CSHAKE128_ALGORITHM, bcrypt.ALG_NONE_FLAG)
65-
return err == nil
66-
}
67-
68-
// SupportsSHAKE256 returns true if the SHAKE256 extendable output function is
69-
// supported.
70-
func SupportsSHAKE256() bool {
71-
_, err := loadHash(bcrypt.CSHAKE256_ALGORITHM, bcrypt.ALG_NONE_FLAG)
61+
// SupportsSHAKE returns true if the SHAKE and CSHAKE extendable output functions
62+
// with the given securityBits are supported.
63+
func SupportsSHAKE(securityBits int) bool {
64+
var id string
65+
switch securityBits {
66+
case 128:
67+
id = bcrypt.CSHAKE128_ALGORITHM
68+
case 256:
69+
id = bcrypt.CSHAKE256_ALGORITHM
70+
default:
71+
return false
72+
}
73+
_, err := loadHash(id, bcrypt.ALG_NONE_FLAG)
7274
return err == nil
7375
}
7476

cng/sha3_test.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@ var testShakes = map[string]struct {
3232
}
3333

3434
func skipCSHAKEIfNotSupported(t *testing.T, algo string) {
35+
var supported bool
3536
switch algo {
3637
case "SHAKE128", "CSHAKE128":
37-
if !cng.SupportsSHAKE128() {
38-
t.Skip("skipping: not supported")
39-
}
38+
supported = cng.SupportsSHAKE(128)
4039
case "SHAKE256", "CSHAKE256":
41-
if !cng.SupportsSHAKE256() {
42-
t.Skip("skipping: not supported")
43-
}
40+
supported = cng.SupportsSHAKE(256)
41+
}
42+
if !supported {
43+
t.Skip("skipping: not supported")
4444
}
4545
}
4646

@@ -109,14 +109,14 @@ func TestCSHAKEReset(t *testing.T) {
109109

110110
func TestCSHAKEAccumulated(t *testing.T) {
111111
t.Run("CSHAKE128", func(t *testing.T) {
112-
if !cng.SupportsSHAKE128() {
112+
if !cng.SupportsSHAKE(128) {
113113
t.Skip("skipping: not supported")
114114
}
115115
testCSHAKEAccumulated(t, cng.NewCSHAKE128, (1600-256)/8,
116116
"bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252")
117117
})
118118
t.Run("CSHAKE256", func(t *testing.T) {
119-
if !cng.SupportsSHAKE256() {
119+
if !cng.SupportsSHAKE(256) {
120120
t.Skip("skipping: not supported")
121121
}
122122
testCSHAKEAccumulated(t, cng.NewCSHAKE256, (1600-512)/8,
@@ -155,7 +155,7 @@ func testCSHAKEAccumulated(t *testing.T, newCSHAKE func(N, S []byte) *cng.SHAKE,
155155
}
156156

157157
func TestCSHAKELargeS(t *testing.T) {
158-
if !cng.SupportsSHAKE128() {
158+
if !cng.SupportsSHAKE(128) {
159159
t.Skip("skipping: not supported")
160160
}
161161
const s = (1<<32)/8 + 1000 // s * 8 > 2^32
@@ -173,13 +173,13 @@ func TestCSHAKELargeS(t *testing.T) {
173173
}
174174
}
175175

176-
func TestCSHAKESum(t *testing.T) {
176+
func TestSHAKESum(t *testing.T) {
177177
const testString = "hello world"
178-
t.Run("CSHAKE128", func(t *testing.T) {
179-
if !cng.SupportsSHAKE128() {
178+
t.Run("SHAKE128", func(t *testing.T) {
179+
if !cng.SupportsSHAKE(128) {
180180
t.Skip("skipping: not supported")
181181
}
182-
h := cng.NewCSHAKE128(nil, nil)
182+
h := cng.NewSHAKE128()
183183
h.Write([]byte(testString[:5]))
184184
h.Write([]byte(testString[5:]))
185185
want := make([]byte, 32)
@@ -189,11 +189,11 @@ func TestCSHAKESum(t *testing.T) {
189189
t.Errorf("got:%x want:%x", got, want)
190190
}
191191
})
192-
t.Run("CSHAKE256", func(t *testing.T) {
193-
if !cng.SupportsSHAKE256() {
192+
t.Run("SHAKE256", func(t *testing.T) {
193+
if !cng.SupportsSHAKE(256) {
194194
t.Skip("skipping: not supported")
195195
}
196-
h := cng.NewCSHAKE256(nil, nil)
196+
h := cng.NewSHAKE256()
197197
h.Write([]byte(testString[:5]))
198198
h.Write([]byte(testString[5:]))
199199
want := make([]byte, 32)

0 commit comments

Comments
 (0)