Skip to content

Commit a8db6e0

Browse files
authored
[doc] Update the iterator demo. (dmlc#11222)
1 parent 3f727c2 commit a8db6e0

File tree

3 files changed

+35
-21
lines changed

3 files changed

+35
-21
lines changed

demo/guide-python/external_memory.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
- rmm
2626
- python-cuda
2727
28+
.. seealso::
29+
30+
:ref:`sphx_glr_python_examples_distributed_extmem_basic.py`
31+
2832
"""
2933

3034
import argparse

demo/guide-python/quantile_data_iterator.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,24 @@
55
.. versionadded:: 1.2.0
66
77
The demo that defines a customized iterator for passing batches of data into
8-
:py:class:`xgboost.QuantileDMatrix` and use this ``QuantileDMatrix`` for
9-
training. The feature is used primarily designed to reduce the required GPU
10-
memory for training on distributed environment.
8+
:py:class:`xgboost.QuantileDMatrix` and use this ``QuantileDMatrix`` for training. The
9+
feature is primarily designed to reduce the required GPU memory for training in a
10+
distributed environment.
1111
12-
Aftering going through the demo, one might ask why don't we use more native
13-
Python iterator? That's because XGBoost requires a `reset` function, while
14-
using `itertools.tee` might incur significant memory usage according to:
12+
After going through the demo, one might ask why we don't use a more native Python
13+
iterator? That's because XGBoost requires a `reset` function, while using
14+
`itertools.tee` might incur significant memory usage according to:
1515
1616
https://docs.python.org/3/library/itertools.html#itertools.tee.
1717
18+
.. seealso::
19+
20+
:ref:`sphx_glr_python_examples_external_memory.py`
21+
1822
"""
1923

24+
from typing import Callable
25+
2026
import cupy
2127
import numpy
2228

@@ -35,7 +41,7 @@ class IterForDMatrixDemo(xgboost.core.DataIter):
3541
3642
"""
3743

38-
def __init__(self):
44+
def __init__(self) -> None:
3945
"""Generate some random data for demonstration.
4046
4147
Actual data can be anything that is currently supported by XGBoost.
@@ -50,41 +56,44 @@ def __init__(self):
5056
self.it = 0 # set iterator to 0
5157
super().__init__()
5258

53-
def as_array(self):
59+
def as_array(self) -> cupy.ndarray:
5460
return cupy.concatenate(self._data)
5561

56-
def as_array_labels(self):
62+
def as_array_labels(self) -> cupy.ndarray:
5763
return cupy.concatenate(self._labels)
5864

59-
def as_array_weights(self):
65+
def as_array_weights(self) -> cupy.ndarray:
6066
return cupy.concatenate(self._weights)
6167

62-
def data(self):
68+
def data(self) -> cupy.ndarray:
6369
"""Utility function for obtaining current batch of data."""
6470
return self._data[self.it]
6571

66-
def labels(self):
72+
def labels(self) -> cupy.ndarray:
6773
"""Utility function for obtaining current batch of label."""
6874
return self._labels[self.it]
6975

70-
def weights(self):
76+
def weights(self) -> cupy.ndarray:
7177
return self._weights[self.it]
7278

73-
def reset(self):
79+
def reset(self) -> None:
7480
"""Reset the iterator"""
7581
self.it = 0
7682

77-
def next(self, input_data):
78-
"""Yield next batch of data."""
83+
def next(self, input_data: Callable) -> bool:
84+
"""Yield the next batch of data."""
7985
if self.it == len(self._data):
80-
# Return 0 when there's no more batch.
81-
return 0
86+
# Return False to let XGBoost know this is the end of iteration
87+
return False
88+
89+
# input_data is a keyword-only function passed in by XGBoost and has a similar
90+
# signature to the ``DMatrix`` constructor.
8291
input_data(data=self.data(), label=self.labels(), weight=self.weights())
8392
self.it += 1
84-
return 1
93+
return True
8594

8695

87-
def main():
96+
def main() -> None:
8897
rounds = 100
8998
it = IterForDMatrixDemo()
9099

@@ -103,7 +112,7 @@ def main():
103112

104113
assert m_with_it.num_col() == m.num_col()
105114
assert m_with_it.num_row() == m.num_row()
106-
# Tree meethod must be `hist`.
115+
# Tree method must be `hist`.
107116
reg_with_it = xgboost.train(
108117
{"tree_method": "hist", "device": "cuda"},
109118
m_with_it,

ops/script/lint_python.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ class LintersPaths:
122122
"demo/guide-python/model_parser.py",
123123
"demo/guide-python/individual_trees.py",
124124
"demo/guide-python/quantile_regression.py",
125+
"demo/guide-python/quantile_data_iterator.py",
125126
"demo/guide-python/multioutput_regression.py",
126127
"demo/guide-python/learning_to_rank.py",
127128
"demo/aft_survival/aft_survival_viz_demo.py",

0 commit comments

Comments
 (0)