@@ -1741,117 +1741,140 @@ void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor) {
 }
 
 #if !defined(DISABLE_SPARSE_TENSORS)
-static Status CopySparseData(size_t n_sparse_elements,
+static Status CopySparseData(const std::string& name,
+                             int64_t nnz_elements,
                              const ONNX_NAMESPACE::TensorProto& indices,
                              const std::filesystem::path& model_path,
-                             gsl::span<const int64_t>
-                                 dims,
-                             std::function<void(size_t from_idx, size_t to_idx)>
-                                 copier) {
+                             gsl::span<const int64_t> dense_dims,
+                             int64_t dense_elements,
+                             std::function<void(size_t from_idx, size_t to_idx)> copier) {
   Status status = Status::OK();
   TensorShape indices_shape(indices.dims().data(), indices.dims().size());
-  const auto elements = narrow<size_t>(indices_shape.Size());
+  const int64_t indices_elements = indices_shape.Size();
 
-  std::vector<int64_t> indices_values;  // used for conversion of smaller size indices
+  InlinedVector<int64_t> indices_values;  // used for conversion of smaller size indices
   std::vector<uint8_t> unpack_buffer;
   gsl::span<const int64_t> indices_data;
-  const bool has_raw_data = indices.has_raw_data();
+  const bool needs_unpack = utils::HasRawData(indices) || utils::HasExternalData(indices);
   switch (indices.data_type()) {
     case ONNX_NAMESPACE::TensorProto_DataType_INT64:
-      if (has_raw_data) {
-        ORT_RETURN_IF_NOT(indices.raw_data().size() == (elements * sizeof(int64_t)),
-                          "Sparse Indices raw data size does not match expected.");
+      if (needs_unpack) {
+        ORT_RETURN_IF_NOT(indices.raw_data().size() == (narrow<size_t>(indices_elements) * sizeof(int64_t)),
+                          "Sparse tensor: ", name, " indices raw data size does not match expected: ",
+                          indices_elements * sizeof(int64_t));
         ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer));
         indices_data = ReinterpretAsSpan<const int64_t>(gsl::make_span(unpack_buffer));
       } else {
-        ORT_RETURN_IF_NOT(indices.int64_data_size() == static_cast<int64_t>(elements),
-                          "Sparse indices int64 data size does not match expected");
-        indices_data = gsl::make_span(indices.int64_data().data(), elements);
+        ORT_RETURN_IF_NOT(indices.int64_data_size() == indices_elements,
+                          "Sparse tensor: ", name, " indices int64 data size does not match expected: ",
+                          indices_elements);
+        indices_data = gsl::make_span(indices.int64_data().data(), narrow<size_t>(indices_elements));
       }
       break;
     case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
-      if (has_raw_data) {
-        ORT_RETURN_IF_NOT(indices.raw_data().size() == (elements * sizeof(int32_t)),
-                          "Sparse Indices raw data size does not match expected.");
+      if (needs_unpack) {
+        ORT_RETURN_IF_NOT(indices.raw_data().size() == (narrow<size_t>(indices_elements) * sizeof(int32_t)),
+                          "Sparse tensor: ", name, " indices raw data size does not match expected: ",
+                          indices_elements * sizeof(int32_t));
         ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer));
         auto int32_span = ReinterpretAsSpan<const int32_t>(gsl::make_span(unpack_buffer));
         indices_values.insert(indices_values.cend(), int32_span.begin(), int32_span.end());
         unpack_buffer.clear();
         unpack_buffer.shrink_to_fit();
       } else {
-        ORT_RETURN_IF_NOT(indices.int32_data_size() == static_cast<int64_t>(elements),
-                          "Sparse indices int32 data size does not match expected");
+        ORT_RETURN_IF_NOT(indices.int32_data_size() == indices_elements,
+                          "Sparse tensor: ", name, " indices int32 data size does not match expected: ",
+                          indices_elements);
         indices_values.insert(indices_values.cend(), indices.int32_data().cbegin(), indices.int32_data().cend());
       }
       indices_data = gsl::make_span(indices_values);
       break;
     }
     case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
-      if (has_raw_data) {
-        ORT_RETURN_IF_NOT(indices.raw_data().size() == (elements * sizeof(int16_t)),
-                          "Sparse Indices raw data size does not match expected.");
+      if (needs_unpack) {
+        ORT_RETURN_IF_NOT(indices.raw_data().size() == (narrow<size_t>(indices_elements) * sizeof(int16_t)),
+                          "Sparse tensor: ", name, " indices raw data size does not match expected: ",
+                          indices_elements * sizeof(int16_t));
         ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer));
         auto int16_span = ReinterpretAsSpan<const int16_t>(gsl::make_span(unpack_buffer));
         indices_values.insert(indices_values.cend(), int16_span.begin(), int16_span.end());
-        indices_data = gsl::make_span(indices_values);
         unpack_buffer.clear();
         unpack_buffer.shrink_to_fit();
       } else {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
-                               "Invalid SparseTensor indices. INT16 indices must be in the raw data of indices tensor");
+        ORT_RETURN_IF_NOT(indices.int32_data_size() == indices_elements,
+                          "Sparse tensor: ", name, " indices int16 data size does not match expected: ",
+                          indices_elements);
+        indices_values.insert(indices_values.cend(), indices.int32_data().cbegin(), indices.int32_data().cend());
       }
+      indices_data = gsl::make_span(indices_values);
       break;
     }
     case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
-      if (has_raw_data) {
-        ORT_RETURN_IF_NOT(indices.raw_data().size() == elements,
-                          "Sparse Indices raw data size does not match expected.");
+      if (needs_unpack) {
+        ORT_RETURN_IF_NOT(indices.raw_data().size() == narrow<size_t>(indices_elements),
+                          "Sparse tensor: ", name, " indices raw data size does not match expected: ",
+                          indices_elements * sizeof(int8_t));
         ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer));
         auto int8_span = ReinterpretAsSpan<const int8_t>(gsl::make_span(unpack_buffer));
         indices_values.insert(indices_values.cend(), int8_span.begin(), int8_span.end());
-        indices_data = gsl::make_span(indices_values);
         unpack_buffer.clear();
         unpack_buffer.shrink_to_fit();
       } else {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
-                               "Invalid SparseTensor indices. INT8 indices must be in the raw data of indices tensor");
+        ORT_RETURN_IF_NOT(indices.int32_data_size() == indices_elements,
+                          "Sparse tensor: ", name, " indices int8 data size does not match expected: ",
+                          indices_elements);
+        indices_values.insert(indices_values.cend(), indices.int32_data().cbegin(), indices.int32_data().cend());
      }
+      indices_data = gsl::make_span(indices_values);
       break;
     }
     default:
       return ORT_MAKE_STATUS(
           ONNXRUNTIME, INVALID_GRAPH,
-          "Invalid SparseTensor indices. Should one of the following types: int8, int16, int32 or int64");
+          "Sparse tensor: ", name, " indices. Should be one of the following types: int8, int16, int32 or int64");
   }
 
-  if (indices_shape.NumDimensions() == 1) {
+  const auto indices_rank = indices_shape.NumDimensions();
+  if (indices_rank == 1) {
     // flattened indexes
-    for (size_t i = 0; i < n_sparse_elements; ++i) {
-      copier(i, narrow<size_t>(indices_data[i]));
+    for (size_t i = 0, lim = narrow<size_t>(nnz_elements); i < lim; ++i) {
+      const auto idx = indices_data[i];
+      ORT_RETURN_IF_NOT(idx >= 0 && idx < dense_elements,
+                        "Sparse tensor: ", name, " index is out of bounds. Got:", idx,
+                        " expected to be in [0, ", dense_elements, ")");
+
+      copier(i, narrow<size_t>(idx));
     }
-  } else if (indices_shape.NumDimensions() == 2) {
+  } else if (indices_rank == 2) {
     // entries in format {NNZ, rank}
-    ORT_ENFORCE(indices_shape[1] > 0 && static_cast<size_t>(indices_shape[1]) == dims.size());
-    auto rank = static_cast<size_t>(indices_shape[1]);
+    ORT_ENFORCE(indices_shape[1] > 0 && static_cast<size_t>(indices_shape[1]) == dense_dims.size());
+    const auto rank = static_cast<size_t>(indices_shape[1]);
     auto cur_index = indices_data.begin();
-    std::vector<size_t> multipliers;
+    InlinedVector<size_t> multipliers;
     multipliers.resize(rank);
 
     // calculate the number of elements spanned by the inner dimensions (the row-major stride) for each dimension.
     // e.g. if shape is {2,3,4}, the result should be {3*4, 4, 1}
     multipliers[rank - 1] = 1;
     for (auto r = rank - 1; r > 0; --r) {
-      multipliers[r - 1] = SafeInt<size_t>(dims[r]) * multipliers[r];
+      multipliers[r - 1] = SafeInt<size_t>(dense_dims[r]) * multipliers[r];
     }
 
     // calculate the offset for the entry
     // e.g. if shape is {2,3,4} and the entry is (1, 0, 2), the offset is 1*12 + 0*4 + 2*1 = 14,
     // as we skip one full row of 3*4 = 12 entries and then 2 more within the row
-    for (size_t i = 0; i < n_sparse_elements; ++i) {
+    for (size_t i = 0, lim = narrow<size_t>(nnz_elements); i < lim; ++i) {
       SafeInt<int64_t> idx = 0;
       for (size_t j = 0; j < rank; ++j) {
-        idx += SafeInt<int64_t>(cur_index[j]) * multipliers[j];
+        const auto dim_index = cur_index[j];
+        ORT_RETURN_IF_NOT(dim_index >= 0 && dim_index < dense_dims[j],
+                          "Sparse tensor: ", name, " index is out of bounds. Got:", dim_index,
+                          " expected to be in [0, ", dense_dims[j], ")");
+        idx += SafeInt<int64_t>(dim_index) * multipliers[j];
       }
+      ORT_RETURN_IF_NOT(idx >= 0 && idx < dense_elements,
+                        "Sparse tensor: ", name, " index is out of bounds. Got:", static_cast<int64_t>(idx),
+                        " expected to be in [0, ", dense_elements, ")");
 
       copier(i, static_cast<size_t>(idx));
       cur_index += rank;
@@ -1860,7 +1883,7 @@ static Status CopySparseData(size_t n_sparse_elements,
     ORT_ENFORCE(cur_index == indices_data.end());
   } else {
     status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
-                             "Invalid SparseTensor indices. Should be rank 0 or 1. Got:", indices_shape);
+                             "Sparse tensor: ", name, " indices shape. Expected to be rank 1 or 2. Got:", indices_shape);
   }
 
   return status;
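
Aside: the rank-2 branch above flattens each {rank}-sized COO index tuple into a linear offset via row-major multipliers. Below is a minimal standalone sketch of that arithmetic, with plain std types standing in for ORT's SafeInt/InlinedVector/narrow and asserts standing in for ORT_RETURN_IF_NOT; the helper names are illustrative, not part of the ORT API.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Row-major multipliers (strides) for a dense shape, e.g. {2,3,4} -> {12, 4, 1}.
std::vector<int64_t> RowMajorMultipliers(const std::vector<int64_t>& dims) {
  std::vector<int64_t> multipliers(dims.size(), 1);
  for (size_t r = dims.size(); r > 1; --r) {
    multipliers[r - 2] = dims[r - 1] * multipliers[r - 1];
  }
  return multipliers;
}

// Flatten one index tuple into a linear offset, with the same per-dimension
// bounds check the patch adds via ORT_RETURN_IF_NOT.
int64_t FlattenIndex(const std::vector<int64_t>& index,
                     const std::vector<int64_t>& dims,
                     const std::vector<int64_t>& multipliers) {
  int64_t offset = 0;
  for (size_t j = 0; j < dims.size(); ++j) {
    assert(index[j] >= 0 && index[j] < dims[j]);  // reject out-of-bounds entries
    offset += index[j] * multipliers[j];
  }
  return offset;
}

int main() {
  const std::vector<int64_t> dims{2, 3, 4};
  const auto multipliers = RowMajorMultipliers(dims);  // {12, 4, 1}
  std::cout << FlattenIndex({1, 0, 2}, dims, multipliers) << "\n";  // prints 14
}

The example reproduces the worked case from the comments: entry (1, 0, 2) in shape {2,3,4} lands at 1*12 + 0*4 + 2*1 = 14.
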
@@ -1869,53 +1892,110 @@ static Status CopySparseData(size_t n_sparse_elements,
 common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse,
                                                    const std::filesystem::path& model_path,
                                                    ONNX_NAMESPACE::TensorProto& dense) {
-  Status status = Status::OK();
+  Status status;
 
   const auto& sparse_values = sparse.values();
-  auto type = sparse_values.data_type();
-  dense.set_data_type(type);
-  *dense.mutable_name() = sparse_values.name();
+  const auto& name = sparse_values.name();
 
-  SafeInt<size_t> n_sparse_elements = 1;
-  for (auto dim : sparse_values.dims()) {
-    n_sparse_elements *= dim;
+  const auto values_rank = sparse_values.dims_size();
+  if (values_rank != 1) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
+                           "Sparse tensor: ", name, " values should be rank 1 for COO format. Got:", values_rank);
   }
 
-  SafeInt<size_t> n_dense_elements = 1;
+  auto type = sparse_values.data_type();
+  dense.set_data_type(type);
+  *dense.mutable_name() = name;
+  SafeInt<int64_t> dense_elements = 1;
+
   for (auto dim : sparse.dims()) {
-    n_dense_elements *= dim;
+    if (dim < 0) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
+                             "Sparse tensor: ", name, " dense dims expected to be non-negative. Got:", dim);
+    }
+    dense_elements *= dim;
     dense.add_dims(dim);
   }
 
+  const auto dense_dims = gsl::make_span<const int64_t>(dense.dims().data(), dense.dims().size());
+
+  SafeInt<int64_t> nnz_elements = 1;
+  for (auto dim : sparse_values.dims()) {
+    if (dim < 0) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
+                             "Sparse tensor: ", name, " tensor dims expected to be non-negative. Got:", dim);
+    }
+    nnz_elements *= dim;
+  }
+
   const auto& indices = sparse.indices();
-  auto dims = gsl::make_span<const int64_t>(dense.dims().data(), dense.dims().size());
+  const auto indices_rank = indices.dims_size();
+  if (indices_rank != 1 && indices_rank != 2) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
+                           "Sparse tensor: ", name, " indices should be rank 1 or 2 for supported COO format. Got:", indices_rank);
+  }
 
-  if (type != TensorProto_DataType_STRING) {
-    auto ml_data = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType();
-    size_t element_size = ml_data->Size();
+  const auto indices_dims = gsl::make_span(indices.dims().data(), indices.dims().size());
+
+  if (indices_dims[0] != nnz_elements) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
+                           "Sparse tensor: ", name,
+                           " indices outer dimension should match the number of non-zero values. Got:",
+                           indices_dims[0], " expected: ", static_cast<int64_t>(nnz_elements));
+  }
 
-    // need to read in sparse data first as it could be in a type specific field, in raw data, or in external data
-    std::vector<uint8_t> sparse_data_storage;
-    ORT_RETURN_IF_ERROR(UnpackInitializerData(sparse_values, model_path, sparse_data_storage));
-    void* sparse_data = sparse_data_storage.data();
+  if (indices_rank == 2 && dense_dims.size() != narrow<size_t>(indices_dims[1])) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
+                           "Sparse tensor: ", name,
+                           " indices is rank 2, its inner dimension should match the rank of the dense tensor. Got:",
+                           indices_dims[1], " expected: ", dense_dims.size());
+  }
+
+  if (indices_rank == 2) {
+    const auto num_indices = TensorShape(indices_dims).Size();
+    const int64_t expected_indices_entries = SafeInt<int64_t>(nnz_elements) * indices_dims[1];
+    if (num_indices != expected_indices_entries) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
+                             "Sparse tensor: ", name,
+                             " indices is rank 2, it should have NNZ values * indices_dims[1] entries. Got:",
+                             num_indices, " expected: ", expected_indices_entries);
+    }
+  }
+
+  if (dense_elements == 0) {
+    // if there are no elements in the dense tensor, we can return early with an empty tensor proto
+    return status;
+  }
+
+  if (type != ONNX_NAMESPACE::TensorProto_DataType_STRING) {
+    auto ml_data = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType();
+    const size_t element_size = ml_data->Size();
 
     // by putting the data into a std::string we can avoid a copy as set_raw_data can do a std::move
     // into the TensorProto.
-    std::string dense_data_storage(n_dense_elements * element_size, 0);
-    if (n_sparse_elements > 0) {
+    std::string dense_data_storage(narrow<size_t>(dense_elements) * element_size, 0);
+    if (nnz_elements > 0) {
+      // need to read in sparse data first as it could be in a type specific field, in raw data, or in external data
+      std::vector<uint8_t> values_data;
+      ORT_RETURN_IF_ERROR(UnpackInitializerData(sparse_values, model_path, values_data));
+      ORT_RETURN_IF_NOT(values_data.size() == static_cast<size_t>(nnz_elements) * element_size,
+                        "Sparse tensor: ", name, " values data size does not match expected: ",
+                        static_cast<size_t>(nnz_elements) * element_size);
+      void* sparse_data = values_data.data();
       void* dense_data = dense_data_storage.data();
 
       switch (element_size) {
         case 1: {
           status = CopySparseData(
-              n_sparse_elements, indices, model_path, dims, [sparse_data, dense_data](size_t from_idx, size_t to_idx) {
+              name, nnz_elements, indices, model_path, dense_dims, dense_elements,
+              [sparse_data, dense_data](size_t from_idx, size_t to_idx) {
                 static_cast<uint8_t*>(dense_data)[to_idx] = static_cast<const uint8_t*>(sparse_data)[from_idx];
               });
 
           break;
         }
         case 2: {
-          status = CopySparseData(n_sparse_elements, indices, model_path, dims,
+          status = CopySparseData(name, nnz_elements, indices, model_path, dense_dims, dense_elements,
                                   [sparse_data, dense_data](size_t from_idx, size_t to_idx) {
                                     const auto* src = static_cast<const uint16_t*>(sparse_data) + from_idx;
                                     auto* dst = static_cast<uint16_t*>(dense_data) + to_idx;
@@ -1925,7 +2005,7 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT
           break;
         }
         case 4: {
-          status = CopySparseData(n_sparse_elements, indices, model_path, dims,
+          status = CopySparseData(name, nnz_elements, indices, model_path, dense_dims, dense_elements,
                                   [sparse_data, dense_data](size_t from_idx, size_t to_idx) {
                                     const auto* src = static_cast<const uint32_t*>(sparse_data) + from_idx;
                                     auto* dst = static_cast<uint32_t*>(dense_data) + to_idx;
@@ -1935,7 +2015,7 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT
           break;
         }
         case 8: {
-          status = CopySparseData(n_sparse_elements, indices, model_path, dims,
+          status = CopySparseData(name, nnz_elements, indices, model_path, dense_dims, dense_elements,
                                   [sparse_data, dense_data](size_t from_idx, size_t to_idx) {
                                     const auto* src = static_cast<const uint64_t*>(sparse_data) + from_idx;
                                     auto* dst = static_cast<uint64_t*>(dense_data) + to_idx;
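
For reference, once indices are flattened the whole conversion for numeric element types reduces to zero-filling a dense buffer and scattering the NNZ values to their offsets. A minimal sketch under the same assumptions as above (plain std types, rank-1 flattened indices; CooToDense is a hypothetical helper, not the ORT entry point, which is SparseTensorProtoToDenseTensorProto):

#include <cassert>
#include <cstdint>
#include <vector>

template <typename T>
std::vector<T> CooToDense(const std::vector<T>& values,         // NNZ values
                          const std::vector<int64_t>& indices,  // rank-1 flattened indices
                          int64_t dense_elements) {
  assert(values.size() == indices.size());  // mirrors the indices_dims[0] == NNZ check
  std::vector<T> dense(static_cast<size_t>(dense_elements), T{});  // zero-initialized
  for (size_t i = 0; i < values.size(); ++i) {
    const int64_t idx = indices[i];
    assert(idx >= 0 && idx < dense_elements);  // mirrors the new bounds check
    dense[static_cast<size_t>(idx)] = values[i];
  }
  return dense;
}

int main() {
  // A 2x3 dense tensor with non-zeros at flattened offsets 1 and 5:
  // {0, 10, 0, 0, 0, 20}
  const auto dense = CooToDense<float>({10.f, 20.f}, {1, 5}, 6);
  return (dense[1] == 10.f && dense[5] == 20.f) ? 0 : 1;
}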