Skip to content

Commit ca1d04b

Browse files
authored
Release data in cache. (dmlc#10286)
1 parent f1f69ff commit ca1d04b

File tree

5 files changed

+46
-39
lines changed

5 files changed

+46
-39
lines changed

python-package/xgboost/core.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,10 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
504504
cache_prefix :
505505
Prefix to the cache files, only used in external memory.
506506
release_data :
507-
Whether the iterator should release the data during reset. Set it to True if the
508-
data transformation (converting data to np.float32 type) is expensive.
507+
Whether the iterator should release the data during iteration. Set it to True if
508+
the data transformation (converting data to np.float32 type) is memory
509+
intensive. Otherwise, if the transformation is computation intensive then we can
510+
keep the cache.
509511
510512
"""
511513

@@ -517,15 +519,12 @@ def __init__(
517519
self._handle = _ProxyDMatrix()
518520
self._exception: Optional[Exception] = None
519521
self._enable_categorical = False
520-
self._allow_host = True
521522
self._release = release_data
522523
# Stage data in Python until reset or next is called to avoid data being free.
523524
self._temporary_data: Optional[TransformedData] = None
524525
self._data_ref: Optional[weakref.ReferenceType] = None
525526

526-
def get_callbacks(
527-
self, allow_host: bool, enable_categorical: bool
528-
) -> Tuple[Callable, Callable]:
527+
def get_callbacks(self, enable_categorical: bool) -> Tuple[Callable, Callable]:
529528
"""Get callback functions for iterating in C. This is an internal function."""
530529
assert hasattr(self, "cache_prefix"), "__init__ is not called."
531530
self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
@@ -535,7 +534,6 @@ def get_callbacks(
535534
ctypes.c_int,
536535
ctypes.c_void_p,
537536
)(self._next_wrapper)
538-
self._allow_host = allow_host
539537
self._enable_categorical = enable_categorical
540538
return self._reset_callback, self._next_callback
541539

@@ -624,14 +622,17 @@ def input_data(
624622
)
625623
# Stage the data, meta info are copied inside C++ MetaInfo.
626624
self._temporary_data = (new, cat_codes, feature_names, feature_types)
627-
dispatch_proxy_set_data(self.proxy, new, cat_codes, self._allow_host)
625+
dispatch_proxy_set_data(self.proxy, new, cat_codes)
628626
self.proxy.set_info(
629627
feature_names=feature_names,
630628
feature_types=feature_types,
631629
**kwargs,
632630
)
633631
self._data_ref = ref
634632

633+
# Release the data before next batch is loaded.
634+
if self._release:
635+
self._temporary_data = None
635636
# pylint: disable=not-callable
636637
return self._handle_exception(lambda: self.next(input_data), 0)
637638

@@ -911,7 +912,7 @@ def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None:
911912
}
912913
args_cstr = from_pystr_to_cstr(json.dumps(args))
913914
handle = ctypes.c_void_p()
914-
reset_callback, next_callback = it.get_callbacks(True, enable_categorical)
915+
reset_callback, next_callback = it.get_callbacks(enable_categorical)
915916
ret = _LIB.XGDMatrixCreateFromCallback(
916917
None,
917918
it.proxy.handle,
@@ -1437,37 +1438,37 @@ def __init__(self) -> None: # pylint: disable=super-init-not-called
14371438
self.handle = ctypes.c_void_p()
14381439
_check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(self.handle)))
14391440

1440-
def _set_data_from_cuda_interface(self, data: DataType) -> None:
1441-
"""Set data from CUDA array interface."""
1441+
def _ref_data_from_cuda_interface(self, data: DataType) -> None:
1442+
"""Reference data from CUDA array interface."""
14421443
interface = data.__cuda_array_interface__
14431444
interface_str = bytes(json.dumps(interface), "utf-8")
14441445
_check_call(
14451446
_LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, interface_str)
14461447
)
14471448

1448-
def _set_data_from_cuda_columnar(self, data: DataType, cat_codes: list) -> None:
1449-
"""Set data from CUDA columnar format."""
1449+
def _ref_data_from_cuda_columnar(self, data: DataType, cat_codes: list) -> None:
1450+
"""Reference data from CUDA columnar format."""
14501451
from .data import _cudf_array_interfaces
14511452

14521453
interfaces_str = _cudf_array_interfaces(data, cat_codes)
14531454
_check_call(_LIB.XGProxyDMatrixSetDataCudaColumnar(self.handle, interfaces_str))
14541455

1455-
def _set_data_from_array(self, data: np.ndarray) -> None:
1456-
"""Set data from numpy array."""
1456+
def _ref_data_from_array(self, data: np.ndarray) -> None:
1457+
"""Reference data from numpy array."""
14571458
from .data import _array_interface
14581459

14591460
_check_call(
14601461
_LIB.XGProxyDMatrixSetDataDense(self.handle, _array_interface(data))
14611462
)
14621463

1463-
def _set_data_from_pandas(self, data: DataType) -> None:
1464-
"""Set data from a pandas DataFrame. The input is a PandasTransformed instance."""
1464+
def _ref_data_from_pandas(self, data: DataType) -> None:
1465+
"""Reference data from a pandas DataFrame. The input is a PandasTransformed instance."""
14651466
_check_call(
14661467
_LIB.XGProxyDMatrixSetDataColumnar(self.handle, data.array_interface())
14671468
)
14681469

1469-
def _set_data_from_csr(self, csr: scipy.sparse.csr_matrix) -> None:
1470-
"""Set data from scipy csr"""
1470+
def _ref_data_from_csr(self, csr: scipy.sparse.csr_matrix) -> None:
1471+
"""Reference data from scipy csr."""
14711472
from .data import _array_interface
14721473

14731474
_LIB.XGProxyDMatrixSetDataCSR(
@@ -1609,7 +1610,7 @@ def _init(
16091610
it = SingleBatchInternalIter(data=data, **meta)
16101611

16111612
handle = ctypes.c_void_p()
1612-
reset_callback, next_callback = it.get_callbacks(True, enable_categorical)
1613+
reset_callback, next_callback = it.get_callbacks(enable_categorical)
16131614
if it.cache_prefix is not None:
16141615
raise ValueError(
16151616
"QuantileDMatrix doesn't cache data, remove the cache_prefix "

python-package/xgboost/dask/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,7 @@ def __init__(
616616
assert isinstance(self._label_upper_bound, types)
617617

618618
self._iter = 0 # set iterator to 0
619-
super().__init__()
619+
super().__init__(release_data=True)
620620

621621
def _get(self, attr: str) -> Optional[Any]:
622622
if getattr(self, attr) is not None:

python-package/xgboost/data.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,41 +1467,37 @@ def dispatch_proxy_set_data(
14671467
proxy: _ProxyDMatrix,
14681468
data: DataType,
14691469
cat_codes: Optional[list],
1470-
allow_host: bool,
14711470
) -> None:
14721471
"""Dispatch for QuantileDMatrix."""
14731472
if not _is_cudf_ser(data) and not _is_pandas_series(data):
14741473
_check_data_shape(data)
14751474

14761475
if _is_cudf_df(data):
14771476
# pylint: disable=W0212
1478-
proxy._set_data_from_cuda_columnar(data, cast(List, cat_codes))
1477+
proxy._ref_data_from_cuda_columnar(data, cast(List, cat_codes))
14791478
return
14801479
if _is_cudf_ser(data):
14811480
# pylint: disable=W0212
1482-
proxy._set_data_from_cuda_columnar(data, cast(List, cat_codes))
1481+
proxy._ref_data_from_cuda_columnar(data, cast(List, cat_codes))
14831482
return
14841483
if _is_cupy_alike(data):
1485-
proxy._set_data_from_cuda_interface(data) # pylint: disable=W0212
1484+
proxy._ref_data_from_cuda_interface(data) # pylint: disable=W0212
14861485
return
14871486
if _is_dlpack(data):
14881487
data = _transform_dlpack(data)
1489-
proxy._set_data_from_cuda_interface(data) # pylint: disable=W0212
1488+
proxy._ref_data_from_cuda_interface(data) # pylint: disable=W0212
14901489
return
1491-
1492-
err = TypeError("Value type is not supported for data iterator:" + str(type(data)))
1493-
1494-
if not allow_host:
1495-
raise err
1496-
1490+
# Host
14971491
if isinstance(data, PandasTransformed):
1498-
proxy._set_data_from_pandas(data) # pylint: disable=W0212
1492+
proxy._ref_data_from_pandas(data) # pylint: disable=W0212
14991493
return
15001494
if _is_np_array_like(data):
15011495
_check_data_shape(data)
1502-
proxy._set_data_from_array(data) # pylint: disable=W0212
1496+
proxy._ref_data_from_array(data) # pylint: disable=W0212
15031497
return
15041498
if is_scipy_csr(data):
1505-
proxy._set_data_from_csr(data) # pylint: disable=W0212
1499+
proxy._ref_data_from_csr(data) # pylint: disable=W0212
15061500
return
1501+
1502+
err = TypeError("Value type is not supported for data iterator:" + str(type(data)))
15071503
raise err

python-package/xgboost/spark/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __init__(
7777
self._data = data
7878
self._kwargs = kwargs
7979

80-
super().__init__()
80+
super().__init__(release_data=True)
8181

8282
def _fetch(self, data: Optional[Sequence[pd.DataFrame]]) -> Optional[pd.DataFrame]:
8383
if not data:

tests/python/test_data_iterator.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,11 @@ def test_data_iterator(
160160

161161

162162
class IterForCacheTest(xgb.DataIter):
163-
def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray) -> None:
163+
def __init__(
164+
self, x: np.ndarray, y: np.ndarray, w: np.ndarray, release_data: bool
165+
) -> None:
164166
self.kwargs = {"data": x, "label": y, "weight": w}
165-
super().__init__(release_data=False)
167+
super().__init__(release_data=release_data)
166168

167169
def next(self, input_data: Callable) -> int:
168170
if self.it == 1:
@@ -181,7 +183,9 @@ def test_data_cache() -> None:
181183
n_samples_per_batch = 16
182184
data = make_batches(n_samples_per_batch, n_features, n_batches, False)
183185
batches = [v[0] for v in data]
184-
it = IterForCacheTest(*batches)
186+
187+
# Test with a cache.
188+
it = IterForCacheTest(batches[0], batches[1], batches[2], release_data=False)
185189
transform = xgb.data._proxy_transform
186190

187191
called = 0
@@ -196,6 +200,12 @@ def mock(*args: Any, **kwargs: Any) -> Any:
196200
assert it._data_ref is weakref.ref(batches[0])
197201
assert called == 1
198202

203+
# Test without a cache.
204+
called = 0
205+
it = IterForCacheTest(batches[0], batches[1], batches[2], release_data=True)
206+
xgb.QuantileDMatrix(it)
207+
assert called == 4
208+
199209
xgb.data._proxy_transform = transform
200210

201211

0 commit comments

Comments
 (0)