Skip to content

Commit da739e4

Browse files
committed
Implement a basic MPC test
1 parent d939e5c commit da739e4

File tree

3 files changed

+61
-19
lines changed

3 files changed

+61
-19
lines changed

libc/utils/MPCWrapper/MPCUtils.cpp

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,20 @@ class MPCNumber {
155155

156156
MPCNumber carg() const {
157157
mpfr_t res;
158+
mpc_t res_mpc;
159+
158160
mpfr_init2(res, this->mpc_real_precision);
161+
mpc_init3(res_mpc, this->mpc_real_precision, this->mpc_imag_precision);
162+
159163
mpc_arg(res, value, MPC_RND_RE(this->mpc_rounding));
160-
mpc_t res_mpc;
161-
mpc_set_fr(res_mpc, res, MPC_RND_RE(this->mpc_rounding));
162-
return MPCNumber(res_mpc, this->mpc_real_precision,
163-
this->mpc_imag_precision, this->mpc_rounding);
164+
mpc_set_fr(res_mpc, res, this->mpc_rounding);
165+
166+
MPCNumber result(res_mpc, this->mpc_real_precision, this->mpc_imag_precision, this->mpc_rounding);
167+
168+
mpfr_clear(res);
169+
mpc_clear(res_mpc);
170+
171+
return result;
164172
}
165173
};
166174

@@ -215,8 +223,10 @@ bool compare_unary_operation_single_output_different_type(
215223
MPCNumber mpc_result;
216224
mpc_result = unary_operation(op, input, precision, rounding);
217225
mpc_t mpc_result_val;
226+
mpc_init3(mpc_result_val, precision, precision);
218227
mpc_result.getValue(mpc_result_val);
219228
mpfr_t real;
229+
mpfr_init2(real, precision);
220230
mpc_real(real, mpc_result_val, get_mpfr_rounding_mode(rounding.Rrnd));
221231
mpfr::MPFRNumber mpfr_real(real, precision, rounding.Rrnd);
222232
double ulp_real = mpfr_real.ulp(libc_result);
@@ -227,6 +237,42 @@ template bool compare_unary_operation_single_output_different_type(
227237
Operation, _Complex float, float, double, MPCRoundingMode);
228238
template bool compare_unary_operation_single_output_different_type(
229239
Operation, _Complex double, double, double, MPCRoundingMode);
240+
241+
template <typename InputType, typename OutputType>
242+
void explain_unary_operation_single_output_different_type_error(
243+
Operation op, InputType input, OutputType libc_result, double ulp_tolerance,
244+
MPCRoundingMode rounding) {
245+
246+
unsigned int precision = get_precision<get_real_t<InputType>>(ulp_tolerance);
247+
248+
MPCNumber mpc_result;
249+
mpc_result = unary_operation(op, input, precision, rounding);
250+
251+
mpc_t mpc_result_val;
252+
mpc_init3(mpc_result_val, precision, precision);
253+
mpc_result.getValue(mpc_result_val);
254+
255+
mpfr_t real;
256+
mpfr_init2(real, precision);
257+
mpc_real(real, mpc_result_val, get_mpfr_rounding_mode(rounding.Rrnd));
258+
259+
mpfr::MPFRNumber mpfr_real(real, precision, rounding.Rrnd);
260+
261+
double ulp_real = mpfr_real.ulp(libc_result);
262+
263+
if(ulp_real > ulp_tolerance) {
264+
cpp::array<char, 1024> msg_buf;
265+
cpp::StringStream msg(msg_buf);
266+
// TODO: Add information to the error message.
267+
}
268+
269+
}
270+
271+
template void explain_unary_operation_single_output_different_type_error(
272+
Operation, _Complex float, float, double, MPCRoundingMode);
273+
template void explain_unary_operation_single_output_different_type_error(
274+
Operation, _Complex double, double, double, MPCRoundingMode);
275+
230276
} // namespace internal
231277

232278
} // namespace mpc

libc/utils/MPCWrapper/MPCUtils.h

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -250,19 +250,17 @@ get_mpc_matcher(InputType input, [[maybe_unused]] OutputType output,
250250
MPCRND::ForceRoundingMode __i##i##j(Irounding); \
251251
if (__r##i##j.success && __i##i##j.success) { \
252252
EXPECT_MPC_MATCH_ROUNDING( \
253-
match_value, \
254-
LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>( \
255-
input, match_value, ulp_tolerance, Rrounding, Irounding)) \
253+
op, input, match_value, ulp_tolerance, Rrounding, Irounding); \
256254
} \
257255
}
258256

259257
#define EXPECT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance) \
260258
{ \
261259
namespace MPCRND = LIBC_NAMESPACE::fputil::testing; \
262-
for (int i = 0; i < 5; i++) { \
263-
for (int j = 0; j < 5; j++) { \
264-
RoundingMode r_mode = static_cast<RoundingMode>(i); \
265-
RoundingMode i_mode = static_cast<RoundingMode>(j); \
260+
for (int i = 0; i < 4; i++) { \
261+
for (int j = 0; j < 4; j++) { \
262+
MPCRND::RoundingMode r_mode = static_cast<MPCRND::RoundingMode>(i); \
263+
MPCRND::RoundingMode i_mode = static_cast<MPCRND::RoundingMode>(j); \
266264
EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(i, j, op, input, match_value, \
267265
ulp_tolerance, r_mode, i_mode); \
268266
} \
@@ -299,19 +297,17 @@ get_mpc_matcher(InputType input, [[maybe_unused]] OutputType output,
299297
MPCRND::ForceRoundingMode __i##i##j(Irounding); \
300298
if (__r##i##j.success && __i##i##j.success) { \
301299
ASSERT_MPC_MATCH_ROUNDING( \
302-
match_value, \
303-
LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>( \
304-
input, match_value, ulp_tolerance, Rrounding, Irounding)) \
300+
op, input, match_value, ulp_tolerance, Rrounding, Irounding); \
305301
} \
306302
}
307303

308304
#define ASSERT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance) \
309305
{ \
310306
namespace MPCRND = LIBC_NAMESPACE::fputil::testing; \
311-
for (int i = 0; i < 5; i++) { \
312-
for (int j = 0; j < 5; j++) { \
313-
RoundingMode r_mode = static_cast<RoundingMode>(i); \
314-
RoundingMode i_mode = static_cast<RoundingMode>(j); \
307+
for (int i = 0; i < 4; i++) { \
308+
for (int j = 0; j < 4; j++) { \
309+
MPCRND::RoundingMode r_mode = static_cast<MPCRND::RoundingMode>(i); \
310+
MPCRND::RoundingMode i_mode = static_cast<MPCRND::RoundingMode>(j); \
315311
ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER(i, j, op, input, match_value, \
316312
ulp_tolerance, r_mode, i_mode); \
317313
} \

libc/utils/MPFRWrapper/MPFRUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ class MPFRNumber {
188188
mpfr_set(value, other.value, mpfr_rounding);
189189
}
190190

191-
MPFRNumber(const mpfr_t &x, unsigned int precision, RoundingMode rounding)
191+
MPFRNumber(const mpfr_t x, unsigned int precision, RoundingMode rounding)
192192
: mpfr_precision(precision),
193193
mpfr_rounding(get_mpfr_rounding_mode(rounding)) {
194194
mpfr_init2(value, mpfr_precision);

0 commit comments

Comments
 (0)