Skip to content

Commit 520767a

Browse files
committed
update VariableView
1 parent 6685e1a commit 520767a

File tree

1 file changed

+43
-23
lines changed

1 file changed

+43
-23
lines changed

brainpy/math/jaxarray.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,43 +1528,43 @@ def __init__(self, value_or_size, dtype=None, batch_axis: int = None):
15281528
class VariableView(Variable):
15291529
"""A view of a Variable instance.
15301530
1531-
This class is used to create a slice view of ``brainpy.math.Variable``.
1531+
This class is used to create a subset view of ``brainpy.math.Variable``.
1532+
1533+
>>> import brainpy.math as bm
1534+
>>> bm.random.seed(123)
1535+
>>> origin = bm.Variable(bm.random.random(5))
1536+
>>> view = bm.VariableView(origin, slice(None, 2, None)) # origin[:2]
1537+
VariableView([0.02920651, 0.19066381], dtype=float32)
15321538
15331539
``VariableView`` can be used to update the subset of the original
15341540
Variable instance, and make operations on this subset of the Variable.
1541+
1542+
>>> view[:] = 1.
1543+
>>> view
1544+
VariableView([1., 1.], dtype=float32)
1545+
>>> origin
1546+
Variable([1. , 1. , 0.5482849, 0.6564884, 0.8446237], dtype=float32)
1547+
>>> view + 10
1548+
DeviceArray([11., 11.], dtype=float32)
1549+
>>> view *= 10
1550+
VariableView([10., 10.], dtype=float32)
1551+
1552+
The above example demonstrates that the updating of an ``VariableView`` instance
1553+
is actually made in the original ``Variable`` instance.
1554+
1555+
Moreover, it's worthy to note that ``VariableView`` is not a PyTree.
15351556
"""
15361557
def __init__(self, value: Variable, index):
15371558
self.index = index
15381559
if not isinstance(value, Variable):
15391560
raise ValueError('Must be instance of Variable.')
1540-
temp_shape = tuple([1] * len(index))
1541-
super(VariableView, self).__init__(jnp.zeros(temp_shape), batch_axis=value.batch_axis)
1561+
super(VariableView, self).__init__(value.value, batch_axis=value.batch_axis)
15421562
self._value = value
15431563

15441564
@property
15451565
def value(self):
15461566
return self._value[self.index]
15471567

1548-
@value.setter
1549-
def value(self, value):
1550-
int_shape = self.shape
1551-
if self.batch_axis is None:
1552-
ext_shape = value.shape
1553-
else:
1554-
ext_shape = value.shape[:self.batch_axis] + value.shape[self.batch_axis + 1:]
1555-
int_shape = int_shape[:self.batch_axis] + int_shape[self.batch_axis + 1:]
1556-
if ext_shape != int_shape:
1557-
error = f"The shape of the original data is {int_shape}, while we got {value.shape}"
1558-
if self.batch_axis is None:
1559-
error += '. Do you forget to set "batch_axis" when initialize this variable?'
1560-
else:
1561-
error += f' with batch_axis={self.batch_axis}.'
1562-
raise MathError(error)
1563-
if value.dtype != self._value.dtype:
1564-
raise MathError(f"The dtype of the original data is {self._value.dtype}, "
1565-
f"while we got {value.dtype}.")
1566-
self._value[self.index] = value
1567-
15681568
def __setitem__(self, index, value):
15691569
# value is JaxArray
15701570
if isinstance(value, JaxArray):
@@ -1653,3 +1653,23 @@ def fill(self, value):
16531653
def sort(self, axis=-1, kind=None, order=None):
16541654
"""Sort an array in-place."""
16551655
self._value[self.index] = self.value.sort(axis=axis, kind=kind, order=order)
1656+
1657+
def update(self, value):
1658+
if self.batch_axis is None:
1659+
ext_shape = value.shape
1660+
int_shape = self.shape
1661+
else:
1662+
ext_shape = value.shape[:self.batch_axis] + value.shape[self.batch_axis + 1:]
1663+
int_shape = self.shape[:self.batch_axis] + self.shape[self.batch_axis + 1:]
1664+
if ext_shape != int_shape:
1665+
error = f"The shape of the original data is {self.shape}, while we got {value.shape}"
1666+
if self.batch_axis is None:
1667+
error += '. Do you forget to set "batch_axis" when initialize this variable?'
1668+
else:
1669+
error += f' with batch_axis={self.batch_axis}.'
1670+
raise MathError(error)
1671+
if value.dtype != self._value.dtype:
1672+
raise MathError(f"The dtype of the original data is {self._value.dtype}, "
1673+
f"while we got {value.dtype}.")
1674+
self._value[self.index] = value.value if isinstance(value, JaxArray) else value
1675+

0 commit comments

Comments
 (0)