@@ -25,101 +25,106 @@ static void fillWithRandom(Memref<i256> *input, const i256 &kPrime) {
2525 std::mt19937_64 rng (std::random_device{}()); // NOLINT(whitespace/braces)
2626 std::uniform_int_distribution<uint64_t > dist (0 , UINT64_MAX);
2727 for (int i = 0 ; i < NUM_COEFFS; i++) {
28- *input->pget (0 , i ) = i256::randomLT (kPrime , rng, dist);
28+ *input->pget (i, 0 ) = i256::randomLT (kPrime , rng, dist);
2929 }
3030}
3131
32- extern " C" void _mlir_ciface_input_generation (Memref<i256> *output);
33- extern " C" void _mlir_ciface_ntt (Memref<i256> *output, Memref<i256> *input);
34- extern " C" void _mlir_ciface_intt (Memref<i256> *output, Memref<i256> *input);
32+ extern " C" void _mlir_ciface_ntt (Memref<i256> *buffer);
33+ extern " C" void _mlir_ciface_intt (Memref<i256> *buffer);
3534
36- extern " C" void _mlir_ciface_ntt_mont (Memref<i256> *output,
37- Memref<i256> *input);
38- extern " C" void _mlir_ciface_intt_mont (Memref<i256> *output,
39- Memref<i256> *input);
35+ extern " C" void _mlir_ciface_ntt_mont (Memref<i256> *buffer);
36+ extern " C" void _mlir_ciface_intt_mont (Memref<i256> *buffer);
4037
4138void BM_ntt_benchmark (::benchmark::State &state) {
42- Memref<i256> input (1 , NUM_COEFFS);
43- _mlir_ciface_input_generation (&input);
39+ Memref<i256> input (NUM_COEFFS, 1 );
4440 fillWithRandom (&input, kPrime );
4541
46- Memref<i256> ntt (1 , NUM_COEFFS );
42+ Memref<i256> ntt (NUM_COEFFS, 1 );
4743 for (auto _ : state) {
48- _mlir_ciface_ntt (&ntt, &input);
44+ state.PauseTiming ();
45+ memcpy (ntt.pget (0 , 0 ), input.pget (0 , 0 ), sizeof (i256) * NUM_COEFFS);
46+ state.ResumeTiming ();
47+ _mlir_ciface_ntt (&ntt);
4948 }
5049
51- Memref<i256> intt (1 , NUM_COEFFS);
52- _mlir_ciface_intt (&intt, &ntt);
50+ _mlir_ciface_intt (&ntt);
5351
5452 for (int i = 0 ; i < NUM_COEFFS; i++) {
5553 for (int j = 0 ; j < 4 ; j++) {
56- EXPECT_EQ (intt .pget (0 , i )->limbs [j], input.pget (0 , i )->limbs [j]);
54+ EXPECT_EQ (ntt .pget (i, 0 )->limbs [j], input.pget (i, 0 )->limbs [j]);
5755 }
5856 }
5957}
6058
6159BENCHMARK (BM_ntt_benchmark)->Unit(::benchmark::kMillisecond );
6260
6361void BM_intt_benchmark (::benchmark::State &state) {
64- Memref<i256> input (1 , NUM_COEFFS);
65- _mlir_ciface_input_generation (&input);
62+ Memref<i256> input (NUM_COEFFS, 1 );
6663 fillWithRandom (&input, kPrime );
6764
68- Memref<i256> ntt (1 , NUM_COEFFS);
69- _mlir_ciface_ntt (&ntt, &input);
65+ Memref<i256> ntt (NUM_COEFFS, 1 );
66+ memcpy (ntt.pget (0 , 0 ), input.pget (0 , 0 ), sizeof (i256) * NUM_COEFFS);
67+ _mlir_ciface_ntt (&ntt);
7068
71- Memref<i256> intt (1 , NUM_COEFFS );
69+ Memref<i256> intt (NUM_COEFFS, 1 );
7270 for (auto _ : state) {
73- _mlir_ciface_intt (&intt, &ntt);
71+ state.PauseTiming ();
72+ memcpy (intt.pget (0 , 0 ), ntt.pget (0 , 0 ), sizeof (i256) * NUM_COEFFS);
73+ state.ResumeTiming ();
74+ _mlir_ciface_intt (&ntt);
7475 }
7576
7677 for (int i = 0 ; i < NUM_COEFFS; i++) {
7778 for (int j = 0 ; j < 4 ; j++) {
78- EXPECT_EQ (intt .pget (0 , i )->limbs [j], input.pget (0 , i )->limbs [j]);
79+ EXPECT_EQ (ntt .pget (i, 0 )->limbs [j], input.pget (i, 0 )->limbs [j]);
7980 }
8081 }
8182}
8283
8384BENCHMARK (BM_intt_benchmark)->Unit(::benchmark::kMillisecond );
8485
8586void BM_ntt_mont_benchmark (::benchmark::State &state) {
86- Memref<i256> input (1 , NUM_COEFFS);
87- _mlir_ciface_input_generation (&input);
87+ Memref<i256> input (NUM_COEFFS, 1 );
8888 fillWithRandom (&input, kPrime );
8989
90- Memref<i256> ntt (1 , NUM_COEFFS );
90+ Memref<i256> ntt (NUM_COEFFS, 1 );
9191 for (auto _ : state) {
92- _mlir_ciface_ntt_mont (&ntt, &input);
92+ state.PauseTiming ();
93+ memcpy (ntt.pget (0 , 0 ), input.pget (0 , 0 ), sizeof (i256) * NUM_COEFFS);
94+ state.ResumeTiming ();
95+ _mlir_ciface_ntt_mont (&ntt);
9396 }
9497
95- Memref<i256> intt (1 , NUM_COEFFS);
96- _mlir_ciface_intt_mont (&intt, &ntt);
98+ _mlir_ciface_intt_mont (&ntt);
9799
98100 for (int i = 0 ; i < NUM_COEFFS; i++) {
99101 for (int j = 0 ; j < 4 ; j++) {
100- EXPECT_EQ (intt .pget (0 , i )->limbs [j], input.pget (0 , i )->limbs [j]);
102+ EXPECT_EQ (ntt .pget (i, 0 )->limbs [j], input.pget (i, 0 )->limbs [j]);
101103 }
102104 }
103105}
104106
105107BENCHMARK (BM_ntt_mont_benchmark)->Unit(::benchmark::kMillisecond );
106108
107109void BM_intt_mont_benchmark (::benchmark::State &state) {
108- Memref<i256> input (1 , NUM_COEFFS);
109- _mlir_ciface_input_generation (&input);
110+ Memref<i256> input (NUM_COEFFS, 1 );
110111 fillWithRandom (&input, kPrime );
111112
112- Memref<i256> ntt (1 , NUM_COEFFS);
113- _mlir_ciface_ntt_mont (&ntt, &input);
113+ Memref<i256> ntt (NUM_COEFFS, 1 );
114+ memcpy (ntt.pget (0 , 0 ), input.pget (0 , 0 ), sizeof (i256) * NUM_COEFFS);
115+ _mlir_ciface_ntt_mont (&ntt);
114116
115- Memref<i256> intt (1 , NUM_COEFFS );
117+ Memref<i256> intt (NUM_COEFFS, 1 );
116118 for (auto _ : state) {
117- _mlir_ciface_intt_mont (&intt, &ntt);
119+ state.PauseTiming ();
120+ memcpy (intt.pget (0 , 0 ), ntt.pget (0 , 0 ), sizeof (i256) * NUM_COEFFS);
121+ state.ResumeTiming ();
122+ _mlir_ciface_intt_mont (&intt);
118123 }
119124
120125 for (int i = 0 ; i < NUM_COEFFS; i++) {
121126 for (int j = 0 ; j < 4 ; j++) {
122- EXPECT_EQ (intt.pget (0 , i )->limbs [j], input.pget (0 , i )->limbs [j]);
127+ EXPECT_EQ (intt.pget (i, 0 )->limbs [j], input.pget (i, 0 )->limbs [j]);
123128 }
124129 }
125130}
@@ -136,12 +141,12 @@ BENCHMARK(BM_intt_mont_benchmark)->Unit(::benchmark::kMillisecond);
136141// L1 Data 64 KiB
137142// L1 Instruction 128 KiB
138143// L2 Unified 4096 KiB (x14)
139- // Load Average: 6.49, 5.64, 5.49
140- // -------------------------------------------------------------------------
141- // Benchmark Time CPU Iterations
142- // -------------------------------------------------------------------------
143- // BM_ntt_benchmark 1656 ms 1050 ms 1
144- // BM_intt_benchmark/iterations:1 1791 ms 1090 ms 1
145- // BM_ntt_mont_benchmark 38.6 ms 18.6 ms 40
146- // BM_intt_mont_benchmark 99.4 ms 56.4 ms 11
144+ // Load Average: 8.66, 7.19, 7.37
145+ // -----------------------------------------------------------------
146+ // Benchmark Time CPU Iterations
147+ // -----------------------------------------------------------------
148+ // BM_ntt_benchmark 1603 ms 1085 ms 1
149+ // BM_intt_benchmark 1585 ms 1120 ms 1
150+ // BM_ntt_mont_benchmark 34.7 ms 16.8 ms 42
151+ // BM_intt_mont_benchmark 33.8 ms 16.6 ms 42
147152// NOLINTEND()
0 commit comments