@@ -1521,41 +1521,13 @@ static paddle::Tensor& GetTensorFromPyObject(const std::string& op_type,
15211521 }
15221522}
15231523
1524- // For Intermediate State Dygraph,
1525- // we use an uninitialized Tensor to represent dispensable Tensor
1526- paddle::Tensor& GetTensorFromArgs (const std::string& op_type,
1527- const std::string& arg_name,
1528- PyObject* args,
1529- ssize_t arg_idx,
1530- bool dispensable) {
1531- PyObject* obj = PyTuple_GET_ITEM (args, arg_idx);
1532- return GetTensorFromPyObject (op_type, arg_name, obj, arg_idx, dispensable);
1533- }
1534-
1535- paddle::Tensor& GetTensorFromArgsOrKWArgs (
1536- const std::string& op_type,
1537- const std::string& arg_name,
1538- PyObject* args,
1539- ssize_t arg_idx,
1540- PyObject* kwargs,
1541- const std::vector<std::string>& keywords,
1542- const int nargs,
1543- int * remaining_kwargs,
1544- bool dispensable) {
1545- PyObject* obj = GetItemFromArgsOrKWArgs (
1546- args, arg_idx, kwargs, keywords, nargs, remaining_kwargs);
1547- return GetTensorFromPyObject (op_type, arg_name, obj, arg_idx, dispensable);
1548- }
1549-
1550- std::vector<paddle::Tensor> GetTensorListFromArgs (
1524+ std::vector<paddle::Tensor> GetTensorListFromPyObject_ (
15511525 const std::string& op_type,
15521526 const std::string& arg_name,
1553- PyObject* args ,
1527+ PyObject* list ,
15541528 ssize_t arg_idx,
15551529 bool dispensable,
15561530 const phi::distributed::ProcessMesh* mesh) {
1557- PyObject* list = PyTuple_GET_ITEM (args, arg_idx);
1558-
15591531 if (list == nullptr ) {
15601532 if (!dispensable) {
15611533 PADDLE_THROW (common::errors::InvalidArgument (
@@ -1671,6 +1643,61 @@ std::vector<paddle::Tensor> GetTensorListFromArgs(
16711643 return result;
16721644}
16731645
1646+ // For Intermediate State Dygraph,
1647+ // we use an uninitialized Tensor to represent dispensable Tensor
1648+ paddle::Tensor& GetTensorFromArgs (const std::string& op_type,
1649+ const std::string& arg_name,
1650+ PyObject* args,
1651+ ssize_t arg_idx,
1652+ bool dispensable) {
1653+ PyObject* obj = PyTuple_GET_ITEM (args, arg_idx);
1654+ return GetTensorFromPyObject (op_type, arg_name, obj, arg_idx, dispensable);
1655+ }
1656+
1657+ paddle::Tensor& GetTensorFromArgsOrKWArgs (
1658+ const std::string& op_type,
1659+ const std::string& arg_name,
1660+ PyObject* args,
1661+ ssize_t arg_idx,
1662+ PyObject* kwargs,
1663+ const std::vector<std::string>& keywords,
1664+ const int nargs,
1665+ int * remaining_kwargs,
1666+ bool dispensable) {
1667+ PyObject* obj = GetItemFromArgsOrKWArgs (
1668+ args, arg_idx, kwargs, keywords, nargs, remaining_kwargs);
1669+ return GetTensorFromPyObject (op_type, arg_name, obj, arg_idx, dispensable);
1670+ }
1671+
1672+ std::vector<paddle::Tensor> GetTensorListFromArgs (
1673+ const std::string& op_type,
1674+ const std::string& arg_name,
1675+ PyObject* args,
1676+ ssize_t arg_idx,
1677+ bool dispensable,
1678+ const phi::distributed::ProcessMesh* mesh) {
1679+ PyObject* list = PyTuple_GET_ITEM (args, arg_idx);
1680+ return GetTensorListFromPyObject_ (
1681+ op_type, arg_name, list, arg_idx, dispensable, mesh);
1682+ }
1683+
1684+ std::vector<paddle::Tensor> GetTensorListFromArgsOrKWArgs (
1685+ const std::string& op_type,
1686+ const std::string& arg_name,
1687+ PyObject* args,
1688+ ssize_t arg_idx,
1689+ PyObject* kwargs,
1690+ const std::vector<std::string>& keywords,
1691+ const int nargs,
1692+ int * remaining_kwargs,
1693+ bool dispensable,
1694+ const phi::distributed::ProcessMesh* mesh) {
1695+ PyObject* list = GetItemFromArgsOrKWArgs (
1696+ args, arg_idx, kwargs, keywords, nargs, remaining_kwargs);
1697+ return GetTensorListFromPyObject_ (
1698+ op_type, arg_name, list, arg_idx, dispensable, mesh);
1699+ }
1700+
16741701paddle::optional<std::vector<paddle::Tensor>> GetOptionalTensorListFromArgs (
16751702 const std::string& op_type,
16761703 const std::string& arg_name,
0 commit comments