Skip to content

Commit ce83a37

Browse files
committed
Fix bug in exponentiation.
1 parent b47c9bb commit ce83a37

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

Compiler/GC/types.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,10 +1625,14 @@ def get_input_from(cls, player, size=1):
16251625
"""
16261626
return cls._new(cls.int_type.get_input_from(player, size=size,
16271627
f=cls.f))
1628-
def __init__(self, value=None, *args, **kwargs):
1628+
def __init__(self, value=None, k=None, *args, **kwargs):
16291629
if isinstance(value, (list, tuple)):
1630-
self.v = self.int_type.from_vec(sbitvec([x.v for x in value]))
1630+
super(sbitfixvec, self).__init__(None, k=k, *args, **kwargs)
1631+
self.int_type = sbitintvec.get_type(self.k)
1632+
self.v = self.int_type.from_vec(sbitvec([x.v for x in value]).v)
16311633
else:
1634+
self.k = k or self.k
1635+
self.int_type = sbitintvec.get_type(self.k)
16321636
if isinstance(value, sbitvec):
16331637
value = self.int_type(value)
16341638
super(sbitfixvec, self).__init__(value, *args, **kwargs)

Compiler/mpc_math.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,11 +274,12 @@ class my_fix(type(a)):
274274
# improve precision
275275
my_fix.set_precision(a.k - 2, a.k)
276276
n_shift = a.k - 2 - a.f
277+
res_k = 2 * a.k - n_shift
277278
x = my_fix._new(frac.v << n_shift)
278279
# evaluates fractional part of a in p_1045
279280
e = p_eval(p_1045, x)
280281
g = a._new(whole_exp.TruncMul(e.v, 2 * a.k, n_shift,
281-
nearest=a.round_nearest), a.k, a.f)
282+
nearest=a.round_nearest), res_k, a.f)
282283
return g
283284
# how many bits to use from integer part
284285
n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
@@ -368,15 +369,15 @@ class my_fix(type(a)):
368369
pow2_bits = [sint.conv(x) for x in higher_bits]
369370
d = floatingpoint.Pow2_from_bits(pow2_bits)
370371
g = exp_from_parts(d, c)
371-
small_result = a._new(g.v.round(a.f + 2 ** n_int_bits,
372+
small_result = a._new(g.v.round(a.f + 2 ** n_int_bits + 1,
372373
2 ** n_int_bits, signed=False,
373374
nearest=a.round_nearest),
374375
k=a.k, f=a.f)
375376
if zero_output:
376377
t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
377378
bits_to_check))
378379
small_result = t.if_else(small_result, 0)
379-
return s.if_else(small_result, g)
380+
return s.if_else(small_result, a._new(g.v, k=a.k, f=a.f))
380381
else:
381382
assert not zero_output
382383
# obtain absolute value of a

Compiler/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4744,7 +4744,7 @@ def conv(cls, other):
47444744
@classmethod
47454745
def _new(cls, other, k=None, f=None):
47464746
res = cls(k=k, f=f, initialize=False)
4747-
res.v = cls.int_type.conv(other)
4747+
res.v = res.int_type.conv(other)
47484748
return res
47494749

47504750
@vectorize_init

0 commit comments

Comments
 (0)