@@ -232,27 +232,47 @@ fn check_2(
232232
233233#[ cfg( test) ]
234234mod tests {
235+ use proptest:: sample:: size_range;
235236 use rand_core:: OsRng ;
237+ use test_strategy:: proptest;
236238
237239 use super :: { super :: elgamal:: SecretKey , * } ;
238240
239241 const VECTOR_SIZE : usize = 3 ;
240242
241- #[ test]
242- fn zk_unit_vector_test ( ) {
243+ fn is_unit_vector ( vector : & [ Scalar ] ) -> bool {
244+ let ones = vector. iter ( ) . filter ( |s| s == & & Scalar :: one ( ) ) . count ( ) ;
245+ let zeros = vector. iter ( ) . filter ( |s| s == & & Scalar :: zero ( ) ) . count ( ) ;
246+ ones == 1 && zeros == vector. len ( ) - 1
247+ }
248+
249+ #[ proptest]
250+ fn zk_unit_vector_test (
251+ secret_key : SecretKey , secret_commitment_key : SecretKey ,
252+ #[ strategy( 2 ..10_usize ) ] unit_vector_size : usize ,
253+ #[ strategy( 0 ..#unit_vector_size) ] unit_vector_index : usize ,
254+ ) {
243255 let mut rng = OsRng ;
244256
245- let secret_key = SecretKey :: random ( & mut rng) ;
246- let secret_commitment_key = SecretKey :: random ( & mut rng) ;
247257 let public_key = secret_key. public_key ( ) ;
248258 let commitment_key = secret_commitment_key. public_key ( ) ;
249259
250- let unit_vector = [ Scalar :: one ( ) , Scalar :: zero ( ) , Scalar :: zero ( ) ] ;
251- let encryption_randomness = vec ! [
252- Scalar :: random( & mut rng) ,
253- Scalar :: random( & mut rng) ,
254- Scalar :: random( & mut rng) ,
255- ] ;
260+ let unit_vector: Vec < _ > = ( 0 ..unit_vector_size)
261+ . map ( |i| {
262+ if i == unit_vector_index {
263+ Scalar :: one ( )
264+ } else {
265+ Scalar :: zero ( )
266+ }
267+ } )
268+ . collect ( ) ;
269+
270+ assert ! ( is_unit_vector( & unit_vector) ) ;
271+
272+ let encryption_randomness: Vec < _ > = unit_vector
273+ . iter ( )
274+ . map ( |_| Scalar :: random ( & mut rng) )
275+ . collect ( ) ;
256276
257277 let ciphertexts: Vec < _ > = encryption_randomness
258278 . iter ( )
@@ -277,31 +297,35 @@ mod tests {
277297 ) ) ;
278298 }
279299
280- #[ test]
281- fn not_a_unit_vector_test ( ) {
300+ #[ proptest]
301+ fn not_a_unit_vector_test (
302+ secret_key : SecretKey , secret_commitment_key : SecretKey ,
303+ #[ any( size_range( 2 ..10_usize ) . lift( ) ) ] random_vector : Vec < Scalar > ,
304+ ) {
282305 let mut rng = OsRng ;
283306
284- let secret_key = SecretKey :: random ( & mut rng) ;
285- let secret_commitment_key = SecretKey :: random ( & mut rng) ;
307+ // make sure the `random_vector` is not a unit vector
308+ // if it is early return
309+ if is_unit_vector ( & random_vector) {
310+ return Ok ( ( ) ) ;
311+ }
312+
286313 let public_key = secret_key. public_key ( ) ;
287314 let commitment_key = secret_commitment_key. public_key ( ) ;
288315
289- // Encrypt not a unit vector
290- let unit_vector = [ Scalar :: from ( 2 ) , Scalar :: zero ( ) , Scalar :: zero ( ) ] ;
291- let encryption_randomness = vec ! [
292- Scalar :: random( & mut rng) ,
293- Scalar :: random( & mut rng) ,
294- Scalar :: random( & mut rng) ,
295- ] ;
316+ let encryption_randomness: Vec < _ > = random_vector
317+ . iter ( )
318+ . map ( |_| Scalar :: random ( & mut rng) )
319+ . collect ( ) ;
296320
297321 let ciphertexts: Vec < _ > = encryption_randomness
298322 . iter ( )
299- . zip ( unit_vector . iter ( ) )
323+ . zip ( random_vector . iter ( ) )
300324 . map ( |( r, v) | encrypt ( v, & public_key, r) )
301325 . collect ( ) ;
302326
303327 let proof = generate_unit_vector_proof (
304- & unit_vector ,
328+ & random_vector ,
305329 encryption_randomness,
306330 ciphertexts. clone ( ) ,
307331 & public_key,
0 commit comments