diff --git a/cng/sha3.go b/cng/sha3.go index d7aa193..4d7a31a 100644 --- a/cng/sha3.go +++ b/cng/sha3.go @@ -58,17 +58,19 @@ func SumSHAKE256(data []byte, length int) []byte { return out } -// SupportsSHAKE128 returns true if the SHAKE128 extendable output function is -// supported. -func SupportsSHAKE128() bool { - _, err := loadHash(bcrypt.CSHAKE128_ALGORITHM, bcrypt.ALG_NONE_FLAG) - return err == nil -} - -// SupportsSHAKE256 returns true if the SHAKE256 extendable output function is -// supported. -func SupportsSHAKE256() bool { - _, err := loadHash(bcrypt.CSHAKE256_ALGORITHM, bcrypt.ALG_NONE_FLAG) +// SupportsSHAKE returns true if the SHAKE and CSHAKE extendable output functions +// with the given securityBits are supported. +func SupportsSHAKE(securityBits int) bool { + var id string + switch securityBits { + case 128: + id = bcrypt.CSHAKE128_ALGORITHM + case 256: + id = bcrypt.CSHAKE256_ALGORITHM + default: + return false + } + _, err := loadHash(id, bcrypt.ALG_NONE_FLAG) return err == nil } diff --git a/cng/sha3_test.go b/cng/sha3_test.go index c52f48b..5a14cf8 100644 --- a/cng/sha3_test.go +++ b/cng/sha3_test.go @@ -32,15 +32,15 @@ var testShakes = map[string]struct { } func skipCSHAKEIfNotSupported(t *testing.T, algo string) { + var supported bool switch algo { case "SHAKE128", "CSHAKE128": - if !cng.SupportsSHAKE128() { - t.Skip("skipping: not supported") - } + supported = cng.SupportsSHAKE(128) case "SHAKE256", "CSHAKE256": - if !cng.SupportsSHAKE256() { - t.Skip("skipping: not supported") - } + supported = cng.SupportsSHAKE(256) + } + if !supported { + t.Skip("skipping: not supported") } } @@ -109,14 +109,14 @@ func TestCSHAKEReset(t *testing.T) { func TestCSHAKEAccumulated(t *testing.T) { t.Run("CSHAKE128", func(t *testing.T) { - if !cng.SupportsSHAKE128() { + if !cng.SupportsSHAKE(128) { t.Skip("skipping: not supported") } testCSHAKEAccumulated(t, cng.NewCSHAKE128, (1600-256)/8, "bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252") }) t.Run("CSHAKE256", func(t *testing.T) { - if !cng.SupportsSHAKE256() { + if !cng.SupportsSHAKE(256) { t.Skip("skipping: not supported") } testCSHAKEAccumulated(t, cng.NewCSHAKE256, (1600-512)/8, @@ -155,7 +155,7 @@ func testCSHAKEAccumulated(t *testing.T, newCSHAKE func(N, S []byte) *cng.SHAKE, } func TestCSHAKELargeS(t *testing.T) { - if !cng.SupportsSHAKE128() { + if !cng.SupportsSHAKE(128) { t.Skip("skipping: not supported") } const s = (1<<32)/8 + 1000 // s * 8 > 2^32 @@ -173,13 +173,13 @@ func TestCSHAKELargeS(t *testing.T) { } } -func TestCSHAKESum(t *testing.T) { +func TestSHAKESum(t *testing.T) { const testString = "hello world" - t.Run("CSHAKE128", func(t *testing.T) { - if !cng.SupportsSHAKE128() { + t.Run("SHAKE128", func(t *testing.T) { + if !cng.SupportsSHAKE(128) { t.Skip("skipping: not supported") } - h := cng.NewCSHAKE128(nil, nil) + h := cng.NewSHAKE128() h.Write([]byte(testString[:5])) h.Write([]byte(testString[5:])) want := make([]byte, 32) @@ -189,11 +189,11 @@ func TestCSHAKESum(t *testing.T) { t.Errorf("got:%x want:%x", got, want) } }) - t.Run("CSHAKE256", func(t *testing.T) { - if !cng.SupportsSHAKE256() { + t.Run("SHAKE256", func(t *testing.T) { + if !cng.SupportsSHAKE(256) { t.Skip("skipping: not supported") } - h := cng.NewCSHAKE256(nil, nil) + h := cng.NewSHAKE256() h.Write([]byte(testString[:5])) h.Write([]byte(testString[5:])) want := make([]byte, 32)