@@ -1584,3 +1584,85 @@ def split(input, num_or_sections, dim=-1):
            'axis': dim
        })
    return outs
+
+
+def matmul(x, y, transpose_x=False, transpose_y=False):
+    """
+    Applies matrix multiplication to two tensors.
+
+    This operator is used to perform (batched) matrix multiplication
+    over the last two dimensions of the input tensors `X` and `Y`.
+
+    If a transpose flag is specified, the last two dimensions of the
+    tensor are transposed. If the tensor is rank-1 of shape [D], then
+    for `X` it is treated as [1, D] in nontransposed form and as [D, 1]
+    in transposed form, whereas for `Y` it is the opposite: it is treated
+    as [D, 1] in nontransposed form and as [1, D] in transposed form.
+
+    Examples without transpose:
+    - X: [K], Y: [K] => Out: [1]
+    - X: [K], Y: [K, N] => Out: [N]
+    - X: [B, M, K], Y: [K] => Out: [B, M]
+    - X: [M, K], Y: [B, K, N] => Out: [B, M, N]
+    - X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
+
+    The behavior is designed to be similar to the `numpy.matmul` function.
+    The differences are:
+    - Currently only rank 1 to rank 3 input tensors are supported.
+    - We add `transpose_X` and `transpose_Y` flags.
+
+    Both the input `X` and `Y` can carry the LoD (Level of Details) information,
+    or not. But the output only shares the LoD information with input `X`.
+
+    Args:
+        x (Variable): The input variable which is a Tensor or LoDTensor.
+        y (Variable): The input variable which is a Tensor or LoDTensor.
+        transpose_x (bool): Whether to transpose :math:`x` before multiplication.
+        transpose_y (bool): Whether to transpose :math:`y` before multiplication.
+
+    Returns:
+        Variable: The product Tensor variable.
+
+    Examples:
+        .. code-block:: python
+
+            # x is a Tensor variable with shape [B, M, K],
+            # y is a Tensor variable with shape [B, K, N].
+            out = fluid.layers.matmul(x, y)  # out.shape: [B, M, N]
+    """
+    helper = LayerHelper('matmul', **locals())
+    # The output shares the dtype of the first input.
+    out = helper.create_tmp_variable(dtype=x.dtype)
+    helper.append_op(
+        type='matmul',
+        inputs={'X': x,
+                'Y': y},
+        outputs={'Out': out},
+        attrs={'transpose_X': transpose_x,
+               'transpose_Y': transpose_y})
+    return out
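
The shape rules listed in the new docstring mirror `numpy.matmul` for the
non-transposed cases. As a minimal illustrative check (not part of the patch
itself; the sizes B, M, K, N below are arbitrary placeholders), the same
combinations can be reproduced with numpy:

.. code-block:: python

    import numpy as np

    # Placeholder sizes for the cases listed in the docstring.
    B, M, K, N = 2, 3, 4, 5
    np.matmul(np.ones(K), np.ones(K)).shape                    # ()  (fluid returns [1])
    np.matmul(np.ones(K), np.ones((K, N))).shape               # (N,)
    np.matmul(np.ones((B, M, K)), np.ones(K)).shape            # (B, M)
    np.matmul(np.ones((M, K)), np.ones((B, K, N))).shape       # (B, M, N)
    np.matmul(np.ones((B, M, K)), np.ones((B, K, N))).shape    # (B, M, N)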