Commit f05a23b

Use weakref instead of id for DataIter cache. (dmlc#9445)
- Fix case where Python reuses id from freed objects.
- Small optimization to column matrix with QDM by using `realloc` instead of copying data.

1 parent d495a18 commit f05a23b
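
Background for the first bullet: CPython's `id()` is only guaranteed to be unique among objects that are alive at the same time, so once an old batch is freed, a newly allocated batch can receive the same id and an id-keyed cache will report a false hit. Holding a `weakref.ref` to the cached input avoids this, because the reference can be dereferenced and identity-checked against the incoming object. The snippet below is an illustrative sketch only (the `_BatchCache` class is made up for this explanation; the actual change is in the `python-package/xgboost/core.py` diff further down):

```python
import weakref
from typing import Any, Optional


class _BatchCache:
    """Hypothetical cache showing why id() keys are unsafe across batches."""

    def __init__(self) -> None:
        self._ref: Optional[weakref.ReferenceType] = None
        self._transformed: Any = None

    def get(self, data: Any) -> Any:
        # Dereference the stored weak reference: it yields None once the old
        # batch is garbage collected, so a recycled id can never produce a hit.
        if self._ref is not None and self._ref() is data:
            return self._transformed
        return None

    def put(self, data: Any, transformed: Any) -> None:
        try:
            self._ref = weakref.ref(data)
        except TypeError:
            # Some inputs (e.g. plain lists or dicts) cannot be weakly referenced.
            self._ref = None
        self._transformed = transformed
```

Note that not every object supports weak references, which is why the change below wraps `weakref.ref(data)` in a `try`/`except TypeError` and only uses the cache when a reference could be taken.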

14 files changed: 193 additions, 63 deletions

demo/guide-python/external_memory.py (5 additions, 1 deletion)

@@ -22,7 +22,10 @@


 def make_batches(
-    n_samples_per_batch: int, n_features: int, n_batches: int, tmpdir: str,
+    n_samples_per_batch: int,
+    n_features: int,
+    n_batches: int,
+    tmpdir: str,
 ) -> List[Tuple[str, str]]:
     files: List[Tuple[str, str]] = []
     rng = np.random.RandomState(1994)
@@ -38,6 +41,7 @@ def make_batches(

 class Iterator(xgboost.DataIter):
     """A custom iterator for loading files in batches."""
+
     def __init__(self, file_paths: List[Tuple[str, str]]):
         self._file_paths = file_paths
         self._it = 0

doc/python/python_api.rst (4 additions, 0 deletions)

@@ -23,12 +23,16 @@ Core Data Structure
    :show-inheritance:

 .. autoclass:: xgboost.QuantileDMatrix
+   :members:
    :show-inheritance:

 .. autoclass:: xgboost.Booster
    :members:
    :show-inheritance:

+.. autoclass:: xgboost.DataIter
+   :members:
+   :show-inheritance:

 Learning API
 ------------

python-package/xgboost/_typing.py (9 additions, 2 deletions)

@@ -8,7 +8,9 @@
     Callable,
     Dict,
     List,
+    Optional,
     Sequence,
+    Tuple,
     Type,
     TypeVar,
     Union,
@@ -20,8 +22,6 @@

 DataType = Any

-# xgboost accepts some other possible types in practice due to historical reason, which is
-# lesser tested. For now we encourage users to pass a simple list of string.
 FeatureInfo = Sequence[str]
 FeatureNames = FeatureInfo
 FeatureTypes = FeatureInfo
@@ -97,6 +97,13 @@
     ctypes._Pointer,
 ]

+# The second arg is actually Optional[List[cudf.Series]], skipped for easier type check.
+# The cudf Series is the obtained cat codes, preserved in the `DataIter` to prevent it
+# being freed.
+TransformedData = Tuple[
+    Any, Optional[List], Optional[FeatureNames], Optional[FeatureTypes]
+]
+
 # template parameter
 _T = TypeVar("_T")
 _F = TypeVar("_F", bound=Callable[..., Any])

python-package/xgboost/core.py (35 additions, 8 deletions)

@@ -9,6 +9,7 @@
 import re
 import sys
 import warnings
+import weakref
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from enum import IntEnum, unique
@@ -51,6 +52,7 @@
     FeatureTypes,
     ModelIn,
     NumpyOrCupy,
+    TransformedData,
     c_bst_ulong,
 )
 from .compat import PANDAS_INSTALLED, DataFrame, py_str
@@ -486,7 +488,16 @@ def _prediction_output(


 class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
-    """The interface for user defined data iterator.
+    """The interface for user defined data iterator. The iterator facilitates
+    distributed training, :py:class:`QuantileDMatrix`, and external memory support using
+    :py:class:`DMatrix`. Most of time, users don't need to interact with this class
+    directly.
+
+    .. note::
+
+        The class caches some intermediate results using the `data` input (predictor
+        `X`) as key. Don't repeat the `X` for multiple batches with different meta data
+        (like `label`), make a copy if necessary.

     Parameters
     ----------
@@ -510,13 +521,13 @@ def __init__(
         self._allow_host = True
         self._release = release_data
         # Stage data in Python until reset or next is called to avoid data being free.
-        self._temporary_data: Optional[Tuple[Any, Any, Any, Any]] = None
-        self._input_id: int = 0
+        self._temporary_data: Optional[TransformedData] = None
+        self._data_ref: Optional[weakref.ReferenceType] = None

     def get_callbacks(
         self, allow_host: bool, enable_categorical: bool
     ) -> Tuple[Callable, Callable]:
-        """Get callback functions for iterating in C."""
+        """Get callback functions for iterating in C. This is an internal function."""
         assert hasattr(self, "cache_prefix"), "__init__ is not called."
         self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
             self._reset_wrapper
@@ -591,7 +602,19 @@ def input_data(
         from .data import _proxy_transform, dispatch_proxy_set_data

         # Reduce the amount of transformation that's needed for QuantileDMatrix.
-        if self._temporary_data is not None and id(data) == self._input_id:
+        #
+        # To construct the QDM, one needs 4 iterations on CPU, or 2 iterations on
+        # GPU. If the QDM has only one batch of input (most of the cases), we can
+        # avoid transforming the data repeatly.
+        try:
+            ref = weakref.ref(data)
+        except TypeError:
+            ref = None
+        if (
+            self._temporary_data is not None
+            and ref is not None
+            and ref is self._data_ref
+        ):
             new, cat_codes, feature_names, feature_types = self._temporary_data
         else:
             new, cat_codes, feature_names, feature_types = _proxy_transform(
@@ -608,7 +631,7 @@
             feature_types=feature_types,
             **kwargs,
         )
-        self._input_id = id(data)
+        self._data_ref = ref

         # pylint: disable=not-callable
         return self._handle_exception(lambda: self.next(input_data), 0)
@@ -1134,7 +1157,7 @@ def get_data(self) -> scipy.sparse.csr_matrix:
         testing purposes. If this is a quantized DMatrix then quantized values are
         returned instead of input values.

-            .. versionadded:: 1.7.0
+        .. versionadded:: 1.7.0

         """
         indptr = np.empty(self.num_row() + 1, dtype=np.uint64)
@@ -1155,7 +1178,11 @@ def get_data(self) -> scipy.sparse.csr_matrix:
         return ret

     def get_quantile_cut(self) -> Tuple[np.ndarray, np.ndarray]:
-        """Get quantile cuts for quantization."""
+        """Get quantile cuts for quantization.
+
+        .. versionadded:: 2.0.0
+
+        """
         n_features = self.num_col()

         c_sindptr = ctypes.c_char_p()
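
The note added to the `DataIter` docstring above is worth illustrating: intermediate transforms are cached per `X` object, so each batch should carry its own predictor array rather than one shared array paired with different labels. Below is a minimal, hedged sketch of a conforming iterator feeding a `QuantileDMatrix` (the class name and random data are invented for illustration; the bundled `demo/guide-python/external_memory.py` shows the full external-memory variant):

```python
import numpy as np
import xgboost


class InMemoryIter(xgboost.DataIter):
    """Yield pre-built (X, y) batches; each batch keeps its own X object."""

    def __init__(self, batches):
        self._batches = batches  # list of (X, y) pairs, each X a distinct array
        self._it = 0
        super().__init__()

    def next(self, input_data) -> int:
        if self._it == len(self._batches):
            return 0  # no more batches in this pass
        X, y = self._batches[self._it]
        input_data(data=X, label=y)
        self._it += 1
        return 1  # ask XGBoost to fetch another batch

    def reset(self) -> None:
        self._it = 0  # rewind so construction can make another pass


rng = np.random.default_rng(0)
batches = [(rng.normal(size=(256, 8)), rng.normal(size=256)) for _ in range(3)]
Xy = xgboost.QuantileDMatrix(InMemoryIter(batches))
booster = xgboost.train({"tree_method": "hist"}, Xy, num_boost_round=10)
```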

python-package/xgboost/data.py (3 additions, 7 deletions)

@@ -5,7 +5,7 @@
 import json
 import os
 import warnings
-from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, cast

 import numpy as np

@@ -17,6 +17,7 @@
     FloatCompatible,
     NumpyDType,
     PandasDType,
+    TransformedData,
     c_bst_ulong,
 )
 from .compat import DataFrame, lazy_isinstance
@@ -1268,12 +1269,7 @@ def _proxy_transform(
     feature_names: Optional[FeatureNames],
     feature_types: Optional[FeatureTypes],
     enable_categorical: bool,
-) -> Tuple[
-    Union[bool, ctypes.c_void_p, np.ndarray],
-    Optional[list],
-    Optional[FeatureNames],
-    Optional[FeatureTypes],
-]:
+) -> TransformedData:
     if _is_cudf_df(data) or _is_cudf_ser(data):
         return _transform_cudf_df(
             data, feature_names, feature_types, enable_categorical

python-package/xgboost/testing/__init__.py (12 additions, 6 deletions)

@@ -230,7 +230,7 @@ def reset(self) -> None:

     def as_arrays(
         self,
-    ) -> Tuple[Union[np.ndarray, sparse.csr_matrix], ArrayLike, ArrayLike]:
+    ) -> Tuple[Union[np.ndarray, sparse.csr_matrix], ArrayLike, Optional[ArrayLike]]:
         if isinstance(self.X[0], sparse.csr_matrix):
             X = sparse.vstack(self.X, format="csr")
         else:
@@ -244,7 +244,12 @@ def as_arrays(


 def make_batches(
-    n_samples_per_batch: int, n_features: int, n_batches: int, use_cupy: bool = False
+    n_samples_per_batch: int,
+    n_features: int,
+    n_batches: int,
+    use_cupy: bool = False,
+    *,
+    vary_size: bool = False,
 ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
     X = []
     y = []
@@ -255,10 +260,11 @@ def make_batches(
         rng = cupy.random.RandomState(1994)
     else:
         rng = np.random.RandomState(1994)
-    for _ in range(n_batches):
-        _X = rng.randn(n_samples_per_batch, n_features)
-        _y = rng.randn(n_samples_per_batch)
-        _w = rng.uniform(low=0, high=1, size=n_samples_per_batch)
+    for i in range(n_batches):
+        n_samples = n_samples_per_batch + i * 10 if vary_size else n_samples_per_batch
+        _X = rng.randn(n_samples, n_features)
+        _y = rng.randn(n_samples)
+        _w = rng.uniform(low=0, high=1, size=n_samples)
         X.append(_X)
         y.append(_y)
         w.append(_w)
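
The new keyword-only `vary_size` flag makes each successive test batch 10 rows larger than the previous one, which helps exercise iterators whose batches are not uniformly sized. A small usage sketch, assuming the helper is imported from the private `xgboost.testing` module (which may require the optional test dependencies to be installed):

```python
from xgboost.testing import make_batches

# With vary_size=True the three batches have 32, 42 and 52 rows respectively.
X, y, w = make_batches(
    n_samples_per_batch=32, n_features=4, n_batches=3, vary_size=True
)
assert [batch.shape[0] for batch in X] == [32, 42, 52]
```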

src/common/column_matrix.h (21 additions, 14 deletions)

@@ -9,12 +9,12 @@
 #define XGBOOST_COMMON_COLUMN_MATRIX_H_

 #include <algorithm>
-#include <cstddef>  // for size_t
+#include <cstddef>  // for size_t, byte
 #include <cstdint>  // for uint8_t
 #include <limits>
 #include <memory>
-#include <utility>  // for move
-#include <vector>
+#include <type_traits>  // for enable_if_t, is_same_v, is_signed_v
+#include <utility>      // for move

 #include "../data/adapter.h"
 #include "../data/gradient_index.h"
@@ -112,9 +112,6 @@ class SparseColumnIter : public Column<BinIdxT> {
  */
 template <typename BinIdxT, bool any_missing>
 class DenseColumnIter : public Column<BinIdxT> {
- public:
-  using ByteType = bool;
-
  private:
   using Base = Column<BinIdxT>;
   /* flags for missing values in dense columns */
@@ -153,8 +150,17 @@ class ColumnMatrix {
    * @brief A bit set for indicating whether an element in a dense column is missing.
    */
   struct MissingIndicator {
-    LBitField32 missing;
-    RefResourceView<std::uint32_t> storage;
+    using BitFieldT = LBitField32;
+    using T = typename BitFieldT::value_type;
+
+    BitFieldT missing;
+    RefResourceView<T> storage;
+    static_assert(std::is_same_v<T, std::uint32_t>);
+
+    template <typename U>
+    [[nodiscard]] std::enable_if_t<!std::is_signed_v<U>, U> static InitValue(bool init) {
+      return init ? ~U{0} : U{0};
+    }

     MissingIndicator() = default;
     /**
@@ -163,7 +169,7 @@
      */
     MissingIndicator(std::size_t n_elements, bool init) {
       auto m_size = missing.ComputeStorageSize(n_elements);
-      storage = common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0});
+      storage = common::MakeFixedVecWithMalloc(m_size, InitValue<T>(init));
       this->InitView();
     }
     /** @brief Set the i^th element to be a valid element (instead of missing). */
@@ -181,11 +187,12 @@
       if (m_size == storage.size()) {
         return;
       }
+      // grow the storage
+      auto resource = std::dynamic_pointer_cast<common::MallocResource>(storage.Resource());
+      CHECK(resource);
+      resource->Resize(m_size * sizeof(T), InitValue<std::byte>(init));
+      storage = RefResourceView<T>{resource->DataAs<T>(), m_size, resource};

-      auto new_storage =
-          common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0});
-      std::copy_n(storage.cbegin(), storage.size(), new_storage.begin());
-      storage = std::move(new_storage);
       this->InitView();
     }
   };
@@ -210,7 +217,6 @@ class ColumnMatrix {
   }

  public:
-  using ByteType = bool;
   // get number of features
   [[nodiscard]] bst_feature_t GetNumFeature() const {
     return static_cast<bst_feature_t>(type_.size());
@@ -408,6 +414,7 @@
   // IO procedures for external memory.
   [[nodiscard]] bool Read(AlignedResourceReadStream* fi, uint32_t const* index_base);
   [[nodiscard]] std::size_t Write(AlignedFileWriteStream* fo) const;
+  [[nodiscard]] MissingIndicator const& Missing() const { return missing_; }

  private:
   RefResourceView<std::uint8_t> index_;

src/common/io.h (3 additions, 3 deletions)

@@ -10,7 +10,7 @@
 #include <dmlc/io.h>
 #include <rabit/rabit.h>

-#include <algorithm>  // for min
+#include <algorithm>  // for min, fill_n, copy_n
 #include <array>      // for array
 #include <cstddef>    // for byte, size_t
 #include <cstdlib>    // for malloc, realloc, free
@@ -207,7 +207,7 @@ class MallocResource : public ResourceHandler {
    * @param n_bytes The new size.
    */
   template <bool force_malloc = false>
-  void Resize(std::size_t n_bytes) {
+  void Resize(std::size_t n_bytes, std::byte init = std::byte{0}) {
     // realloc(ptr, 0) works, but is deprecated.
     if (n_bytes == 0) {
       this->Clear();
@@ -236,7 +236,7 @@
       std::copy_n(reinterpret_cast<std::byte*>(ptr_), n_, reinterpret_cast<std::byte*>(new_ptr));
     }
     // default initialize
-    std::memset(reinterpret_cast<std::byte*>(new_ptr) + n_, '\0', n_bytes - n_);
+    std::fill_n(reinterpret_cast<std::byte*>(new_ptr) + n_, n_bytes - n_, init);
     // free the old ptr if malloc is used.
     if (need_copy) {
       this->Clear();

tests/ci_build/lint_python.py (1 addition, 0 deletions)

@@ -42,6 +42,7 @@ class LintersPaths:
         "demo/guide-python/feature_weights.py",
         "demo/guide-python/sklearn_parallel.py",
         "demo/guide-python/spark_estimator_examples.py",
+        "demo/guide-python/external_memory.py",
         "demo/guide-python/individual_trees.py",
         "demo/guide-python/quantile_regression.py",
         "demo/guide-python/multioutput_regression.py",
