diff --git a/std/conversion/conversion.go b/std/conversion/conversion.go index 158420d5c..46566265e 100644 --- a/std/conversion/conversion.go +++ b/std/conversion/conversion.go @@ -81,7 +81,7 @@ func BytesToNative(api frontend.API, b []uints.U8, opts ...Option) (frontend.Var } // check that the input was in range of the field modulus. Omit if cfg.allowOverflow is set. if !cfg.allowOverflow { - assertBytesLeq(api, b, api.Compiler().Field()) + assertBytesLt(api, b, api.Compiler().Field()) } return res, nil } @@ -220,7 +220,7 @@ func NativeToBytes(api frontend.API, v frontend.Variable, opts ...Option) ([]uin // check if we don't care about the uniqueness (in case later when composing // back to native element the check is done there). if !cfg.allowOverflow { - assertBytesLeq(api, resU8, api.Compiler().Field()) + assertBytesLt(api, resU8, api.Compiler().Field()) } return resU8, nil } @@ -337,3 +337,30 @@ func assertBytesLeq(api frontend.API, b []uints.U8, bound *big.Int) error { } return nil } + +func assertBytesLt(api frontend.API, b []uints.U8, bound *big.Int) error { + bapi, err := uints.NewBytes(api) + if err != nil { + return err + } + mBytes := bound.Bytes() + if len(b) < len(mBytes) { + return nil + } + for i := 0; i < len(b)-len(mBytes); i++ { + api.AssertIsEqual(bapi.ValueUnchecked(b[i]), 0) + } + bb := b[len(b)-len(mBytes):] + rchecker := rangecheck.New(api) + var eq_i frontend.Variable = 1 + for i := range mBytes { + diff := api.Sub(mBytes[i], bapi.Value(bb[i])) + nbBits := bits.Len8(mBytes[i]) + rchecker.Check(api.Mul(eq_i, diff), nbBits) + isEq := api.IsZero(diff) + eq_i = api.Mul(eq_i, isEq) + } + // when lengths are comparable, disallow equality + api.AssertIsEqual(eq_i, 0) + return nil +} diff --git a/std/conversion/conversion_test.go b/std/conversion/conversion_test.go index 45b9ad726..41aab5b19 100644 --- a/std/conversion/conversion_test.go +++ b/std/conversion/conversion_test.go @@ -407,3 +407,17 @@ func TestAssertBytesLeq(t *testing.T) { tc(assert, []byte{253, 253, 253}, []byte{254, 252}, true) tc(assert, []byte{253, 253, 253}, []byte{0, 254, 252}, true) } + +func TestBytesToNative_EqualModulus(t *testing.T) { + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + m := fr_bn254.Modulus() + sbytes := m.Bytes() + // Expect invalid since strict < modulus must hold + assert.CheckCircuit( + &BytesToNativeCircuit{In: make([]uints.U8, len(sbytes))}, + test.WithInvalidAssignment(&BytesToNativeCircuit{In: uints.NewU8Array(sbytes), Expected: big.NewInt(0)}), + test.WithCurves(ecc.BN254), + ) + }, "equal-modulus") +}