@@ -159,6 +159,28 @@ def __eq__(self, other):
159159 return (isinstance (other , Shape ) and
160160 self .width == other .width and self .signed == other .signed )
161161
162+ @staticmethod
163+ def _unify (shapes ):
164+ """Returns the minimal shape that contains all shapes from the list.
165+
166+ If no shapes passed in, returns unsigned(0).
167+ """
168+ unsigned_width = signed_width = 0
169+ has_signed = False
170+ for shape in shapes :
171+ assert isinstance (shape , Shape )
172+ if shape .signed :
173+ has_signed = True
174+ signed_width = max (signed_width , shape .width )
175+ else :
176+ unsigned_width = max (unsigned_width , shape .width )
177+ # If all shapes unsigned, simply take max.
178+ if not has_signed :
179+ return unsigned (unsigned_width )
180+ # Otherwise, result is signed. All unsigned inputs, if any,
181+ # need to be converted to signed by adding a zero bit.
182+ return signed (max (signed_width , unsigned_width + 1 ))
183+
162184
163185def unsigned (width ):
164186 """Returns :py:`Shape(width, signed=False)`."""
@@ -1524,20 +1546,6 @@ def operands(self):
15241546 return self ._operands
15251547
15261548 def shape (self ):
1527- def _bitwise_binary_shape (a_shape , b_shape ):
1528- if not a_shape .signed and not b_shape .signed :
1529- # both operands unsigned
1530- return unsigned (max (a_shape .width , b_shape .width ))
1531- elif a_shape .signed and b_shape .signed :
1532- # both operands signed
1533- return signed (max (a_shape .width , b_shape .width ))
1534- elif not a_shape .signed and b_shape .signed :
1535- # first operand unsigned (add sign bit), second operand signed
1536- return signed (max (a_shape .width + 1 , b_shape .width ))
1537- else :
1538- # first signed, second operand unsigned (add sign bit)
1539- return signed (max (a_shape .width , b_shape .width + 1 ))
1540-
15411549 op_shapes = list (map (lambda x : x .shape (), self .operands ))
15421550 if len (op_shapes ) == 1 :
15431551 a_shape , = op_shapes
@@ -1554,10 +1562,10 @@ def _bitwise_binary_shape(a_shape, b_shape):
15541562 elif len (op_shapes ) == 2 :
15551563 a_shape , b_shape = op_shapes
15561564 if self .operator == "+" :
1557- o_shape = _bitwise_binary_shape ( * op_shapes )
1565+ o_shape = Shape . _unify ( op_shapes )
15581566 return Shape (o_shape .width + 1 , o_shape .signed )
15591567 if self .operator == "-" :
1560- o_shape = _bitwise_binary_shape ( * op_shapes )
1568+ o_shape = Shape . _unify ( op_shapes )
15611569 return Shape (o_shape .width + 1 , True )
15621570 if self .operator == "*" :
15631571 return Shape (a_shape .width + b_shape .width , a_shape .signed or b_shape .signed )
@@ -1568,7 +1576,7 @@ def _bitwise_binary_shape(a_shape, b_shape):
15681576 if self .operator in ("<" , "<=" , "==" , "!=" , ">" , ">=" ):
15691577 return Shape (1 , False )
15701578 if self .operator in ("&" , "|" , "^" ):
1571- return _bitwise_binary_shape ( * op_shapes )
1579+ return Shape . _unify ( op_shapes )
15721580 if self .operator == "<<" :
15731581 assert not b_shape .signed
15741582 return Shape (a_shape .width + 2 ** b_shape .width - 1 , a_shape .signed )
@@ -1578,7 +1586,7 @@ def _bitwise_binary_shape(a_shape, b_shape):
15781586 elif len (op_shapes ) == 3 :
15791587 if self .operator == "m" :
15801588 s_shape , a_shape , b_shape = op_shapes
1581- return _bitwise_binary_shape ( a_shape , b_shape )
1589+ return Shape . _unify (( a_shape , b_shape ) )
15821590 raise NotImplementedError # :nocov:
15831591
15841592 def _lhs_signals (self ):
@@ -2254,27 +2262,9 @@ def _iter_as_values(self):
22542262 return (Value .cast (elem ) for elem in self .elems )
22552263
22562264 def shape (self ):
2257- unsigned_width = signed_width = 0
2258- has_unsigned = has_signed = False
2259- for elem_shape in (elem .shape () for elem in self ._iter_as_values ()):
2260- if elem_shape .signed :
2261- has_signed = True
2262- signed_width = max (signed_width , elem_shape .width )
2263- else :
2264- has_unsigned = True
2265- unsigned_width = max (unsigned_width , elem_shape .width )
22662265 # The shape of the proxy must be such that it preserves the mathematical value of the array
22672266 # elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree.
2268- # To ensure this holds, if the array contains both signed and unsigned values, make sure
2269- # that every unsigned value is zero-extended by at least one bit.
2270- if has_signed and has_unsigned and unsigned_width >= signed_width :
2271- # Array contains both signed and unsigned values, and at least one of the unsigned
2272- # values won't be zero-extended otherwise.
2273- return signed (unsigned_width + 1 )
2274- else :
2275- # Array contains values of the same signedness, or else all of the unsigned values
2276- # are zero-extended.
2277- return Shape (max (unsigned_width , signed_width ), has_signed )
2267+ return Shape ._unify (elem .shape () for elem in self ._iter_as_values ())
22782268
22792269 def _lhs_signals (self ):
22802270 signals = union ((elem ._lhs_signals () for elem in self ._iter_as_values ()),
0 commit comments