Skip to content

Commit b5f950b

Browse files
authored
PyMethod inputs use tensor_ptr_maker to convert attensor to etensor
Differential Revision: D78035237 Pull Request resolved: #12315
1 parent 31ba959 commit b5f950b

File tree

1 file changed

+12
-24
lines changed

1 file changed

+12
-24
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,16 +1098,9 @@ struct PyMethod final {
10981098
#ifndef USE_ATEN_LIB // Portable mode
10991099
// So the ETensors and their metadata stay in scope for
11001100
// Module->set_inputs.
1101-
std::vector<torch::executor::TensorImpl> input_tensors;
1102-
std::vector<std::vector<torch::executor::Tensor::SizesType>> input_sizes;
1103-
std::vector<std::vector<torch::executor::Tensor::StridesType>>
1104-
input_strides;
1105-
std::vector<std::vector<torch::executor::Tensor::DimOrderType>>
1106-
input_dim_order;
1101+
std::vector<TensorPtr> input_tensors;
11071102
// We store pointers to these vector elements so important to reserve so
1108-
// that we don't lose those on a vector resize. Don't need to do this for
1109-
// the others since they are vectors of vectors, and we don't store a
1110-
// pointer to the root level vector data.
1103+
// that we don't lose those on a vector resize.
11111104
input_tensors.reserve(inputs_size);
11121105
#endif
11131106

@@ -1127,9 +1120,9 @@ struct PyMethod final {
11271120
size_t dim = at_tensor.dim();
11281121
// cant directly alias at::Tensor sizes and strides due to int64 vs
11291122
// int32 typing conflict
1130-
input_sizes.emplace_back(
1123+
std::vector<int> sizes(
11311124
at_tensor.sizes().begin(), at_tensor.sizes().end());
1132-
input_strides.emplace_back(
1125+
std::vector<int> strides(
11331126
at_tensor.strides().begin(), at_tensor.strides().end());
11341127

11351128
// Only works for MemoryFormat::Contiguous or MemoryFormat::ChannelsLast
@@ -1149,19 +1142,14 @@ struct PyMethod final {
11491142
" should be contiguous or channels-last.";
11501143
throw std::runtime_error(error_msg);
11511144
}
1152-
input_dim_order.push_back(std::move(dim_order));
1153-
input_tensors.emplace_back(
1154-
type,
1155-
dim,
1156-
input_sizes.back().data(),
1157-
nullptr,
1158-
input_dim_order.back().data(),
1159-
input_strides.back().data());
1160-
1161-
torch::executor::Tensor temp =
1162-
torch::executor::Tensor(&input_tensors.back());
1163-
alias_etensor_to_attensor(at_tensor, temp);
1164-
EValue evalue(temp);
1145+
TensorPtr tensor =
1146+
for_blob(at_tensor.data_ptr(), std::move(sizes), type)
1147+
.strides(std::move(strides))
1148+
.dim_order(std::move(dim_order))
1149+
.dynamism(aten::TensorShapeDynamism::STATIC)
1150+
.make_tensor_ptr();
1151+
input_tensors.push_back(tensor);
1152+
EValue evalue(input_tensors.back());
11651153
#endif
11661154

11671155
cpp_inputs.push_back(evalue);

0 commit comments

Comments
 (0)