@@ -1528,43 +1528,43 @@ def __init__(self, value_or_size, dtype=None, batch_axis: int = None):
15281528class VariableView (Variable ):
15291529 """A view of a Variable instance.
15301530
1531- This class is used to create a slice view of ``brainpy.math.Variable``.
1531+ This class is used to create a subset view of ``brainpy.math.Variable``.
1532+
1533+ >>> import brainpy.math as bm
1534+ >>> bm.random.seed(123)
1535+ >>> origin = bm.Variable(bm.random.random(5))
1536+ >>> view = bm.VariableView(origin, slice(None, 2, None)) # origin[:2]
1537+ VariableView([0.02920651, 0.19066381], dtype=float32)
15321538
15331539 ``VariableView`` can be used to update the subset of the original
15341540 Variable instance, and make operations on this subset of the Variable.
1541+
1542+ >>> view[:] = 1.
1543+ >>> view
1544+ VariableView([1., 1.], dtype=float32)
1545+ >>> origin
1546+ Variable([1. , 1. , 0.5482849, 0.6564884, 0.8446237], dtype=float32)
1547+ >>> view + 10
1548+ DeviceArray([11., 11.], dtype=float32)
1549+ >>> view *= 10
1550+ VariableView([10., 10.], dtype=float32)
1551+
1552+ The above example demonstrates that the updating of an ``VariableView`` instance
1553+ is actually made in the original ``Variable`` instance.
1554+
1555+ Moreover, it's worthy to note that ``VariableView`` is not a PyTree.
15351556 """
15361557 def __init__ (self , value : Variable , index ):
15371558 self .index = index
15381559 if not isinstance (value , Variable ):
15391560 raise ValueError ('Must be instance of Variable.' )
1540- temp_shape = tuple ([1 ] * len (index ))
1541- super (VariableView , self ).__init__ (jnp .zeros (temp_shape ), batch_axis = value .batch_axis )
1561+ super (VariableView , self ).__init__ (value .value , batch_axis = value .batch_axis )
15421562 self ._value = value
15431563
15441564 @property
15451565 def value (self ):
15461566 return self ._value [self .index ]
15471567
1548- @value .setter
1549- def value (self , value ):
1550- int_shape = self .shape
1551- if self .batch_axis is None :
1552- ext_shape = value .shape
1553- else :
1554- ext_shape = value .shape [:self .batch_axis ] + value .shape [self .batch_axis + 1 :]
1555- int_shape = int_shape [:self .batch_axis ] + int_shape [self .batch_axis + 1 :]
1556- if ext_shape != int_shape :
1557- error = f"The shape of the original data is { int_shape } , while we got { value .shape } "
1558- if self .batch_axis is None :
1559- error += '. Do you forget to set "batch_axis" when initialize this variable?'
1560- else :
1561- error += f' with batch_axis={ self .batch_axis } .'
1562- raise MathError (error )
1563- if value .dtype != self ._value .dtype :
1564- raise MathError (f"The dtype of the original data is { self ._value .dtype } , "
1565- f"while we got { value .dtype } ." )
1566- self ._value [self .index ] = value
1567-
15681568 def __setitem__ (self , index , value ):
15691569 # value is JaxArray
15701570 if isinstance (value , JaxArray ):
@@ -1653,3 +1653,23 @@ def fill(self, value):
16531653 def sort (self , axis = - 1 , kind = None , order = None ):
16541654 """Sort an array in-place."""
16551655 self ._value [self .index ] = self .value .sort (axis = axis , kind = kind , order = order )
1656+
1657+ def update (self , value ):
1658+ if self .batch_axis is None :
1659+ ext_shape = value .shape
1660+ int_shape = self .shape
1661+ else :
1662+ ext_shape = value .shape [:self .batch_axis ] + value .shape [self .batch_axis + 1 :]
1663+ int_shape = self .shape [:self .batch_axis ] + self .shape [self .batch_axis + 1 :]
1664+ if ext_shape != int_shape :
1665+ error = f"The shape of the original data is { self .shape } , while we got { value .shape } "
1666+ if self .batch_axis is None :
1667+ error += '. Do you forget to set "batch_axis" when initialize this variable?'
1668+ else :
1669+ error += f' with batch_axis={ self .batch_axis } .'
1670+ raise MathError (error )
1671+ if value .dtype != self ._value .dtype :
1672+ raise MathError (f"The dtype of the original data is { self ._value .dtype } , "
1673+ f"while we got { value .dtype } ." )
1674+ self ._value [self .index ] = value .value if isinstance (value , JaxArray ) else value
1675+
0 commit comments