
Commit 9d30dfa

v1.4.1: updates to imports, TabM, scheduler
1 parent 763d9b5 commit 9d30dfa

File tree

11 files changed: +253 -83 lines


README.md

Lines changed: 8 additions & 1 deletion
@@ -170,6 +170,14 @@ and https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html
 
 ## Releases (see git tags)
 
+- v1.4.1:
+  - moved dill to optional dependencies
+  - updated TabM code to a newer version:
+    added option share_training_batches=False (old version: True),
+    exclude certain parameters from weight decay.
+  - added [documentation](https://pytabkit.readthedocs.io/en/latest/bench/using_the_scheduler.html) for using the scheduler with custom jobs.
+  - fixed bug in RealMLP refitting.
+  - updated process start method for scheduler to speed up benchmarking
 - v1.4.0:
   - moved some imports to the new `models` optional dependencies
     to have a more light-weight RealMLP installation
@@ -237,4 +245,3 @@ and https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html
 - v0.0.1: First release for arXiv v1.
   Code and data are archived at [DaRUS](https://doi.org/10.18419/darus-4255).
 
-
docs/source/bench/using_the_scheduler.md

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# Using the scheduler

`pytabkit` includes a flexible scheduler that can schedule jobs within Python using `ray` and `multiprocessing`.
Essentially, it is a much fancier version of `multiprocessing.Pool`.
Custom jobs need to provide an estimate of their required resources. The scheduler will
- run as many jobs in parallel as possible on the current hardware while respecting the RAM and resource constraints
- try to run the slowest jobs first, to avoid waiting for a few slow jobs at the end
- measure free CPU RAM in the beginning, and add the fixed RAM that a CPU process uses to the requested RAM.
  For processes requesting a GPU, the fixed RAM used by a process using torch CUDA will be added to the requested RAM.
- print info including remaining-time estimates after each newly started job, after failed jobs, etc.
  (unless the jobs run so fast that multiple ones are started at once).
  The time estimates are based on the time estimates provided by the jobs,
  but they are adapted by a factor learned from the actual time taken by already finished jobs.
  Hence, the time estimate is only accurate after a few jobs have finished.
  It often underestimates the actually needed time to some extent.
  (This is probably also due to selection bias, since the estimated longest jobs are run first.)

The scheduler also works on multi-GPU systems,
and it even works on multi-node systems thanks to `ray`'s multi-node support.
See [`ray_slurm_launch.py`](https://github.com/dholzmueller/pytabkit/blob/main/scripts/ray_slurm_launch.py)
and [`ray_slurm_template.sh`](https://github.com/dholzmueller/pytabkit/blob/main/scripts/ray_slurm_template.sh).
To use the scheduler, install `pytabkit[models,bench]`.

Here is some example code:

```python
from pytabkit.models.alg_interfaces.base import RequiredResources
from pytabkit.bench.scheduling.execution import RayJobManager
from pytabkit.bench.scheduling.jobs import AbstractJob
from pytabkit.bench.scheduling.resources import NodeResources
from pytabkit.bench.scheduling.schedulers import SimpleJobScheduler


class CustomJob(AbstractJob):
    def get_group(self):
        # group name; for all jobs with the same group name,
        # one joint time multiplier will be fitted in the scheduler
        return 'default'

    def get_desc(self) -> str:
        return 'CustomJob'  # name used for displaying

    def __call__(self, assigned_resources: NodeResources) -> bool:
        # the main job; it should only use the assigned resources
        print(f'Running job with {assigned_resources.get_n_threads()} threads', flush=True)
        return True  # job finished successfully

    def get_required_resources(self) -> RequiredResources:
        # Return the resources requested by this job (RAM should be an upper bound, time doesn't need to be)
        return RequiredResources(time_s=1.0, n_threads=1, cpu_ram_gb=0.1, n_gpus=0, gpu_ram_gb=0.0, gpu_usage=1.0)


sched = SimpleJobScheduler(RayJobManager(available_gpu_ram_multiplier=0.7))
sched.add_jobs([CustomJob() for _ in range(1000)])
sched.run()
```
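For jobs that need a GPU, the same `RequiredResources` fields are used; as described above, the fixed RAM of a torch CUDA process is then added to the request automatically. The sketch below builds on the `CustomJob` class from the example above; the concrete numbers are purely illustrative assumptions, not recommendations from the package:

```python
from pytabkit.models.alg_interfaces.base import RequiredResources


class CustomGPUJob(CustomJob):  # CustomJob as defined in the example above
    def get_required_resources(self) -> RequiredResources:
        # illustrative upper bounds: 4 threads, up to 2 GB of CPU RAM,
        # and one GPU with up to 4 GB of GPU RAM (time_s is only an estimate)
        return RequiredResources(time_s=120.0, n_threads=4, cpu_ram_gb=2.0,
                                 n_gpus=1, gpu_ram_gb=4.0, gpu_usage=1.0)
```

GPU jobs can be added to the same scheduler run as CPU-only jobs.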

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ Tabular benchmarking using pytabkit.bench
    bench/03_code
    bench/download_results
    bench/refine_then_calibrate
+   bench/using_the_scheduler
 
 
 
pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -38,7 +38,6 @@ dependencies = [
     # can also install the newer lightning package with more dependencies instead, it will be prioritized
     "pytorch_lightning>=2.0",
     "psutil>=5.0",  # used for getting logical CPU count in the sklearn base and for getting process RAM usage
-    "dill",  # more powerful pickle, used for file-saving and multiprocessing
 ]
 
 [project.optional-dependencies]
@@ -62,6 +61,9 @@ models = [
     # not necessary unless these things are actually used
     "probmetrics>=0.0.1",
 
+    # more powerful pickle, used for file-saving and multiprocessing.
+    # Unfortunately it can't save certain torch objects
+    "dill",
     # saving objects in yaml/msgpack
     # needed if used in utils.serialize() / deserialize()
     "pyyaml>=5.0",

pytabkit/__about__.py

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-__version__ = "1.4.0"
+__version__ = "1.4.1"

pytabkit/bench/scheduling/execution.py

Lines changed: 3 additions & 2 deletions
@@ -5,7 +5,6 @@
 import traceback
 from typing import Tuple, Optional, List
 
-import dill
 import numpy as np
 
 from pytabkit.bench.scheduling.jobs import JobRunner
@@ -69,7 +68,7 @@ def measure_node_resources(node_id: int) -> Tuple[NodeResources, NodeResources]:
 
 
 def node_runner(feedback_queue, job_queue, node_id: int):
-    mp.set_start_method('spawn', force=True)
+    mp.set_start_method('fork', force=True)
 
     # get resources in separate process so CUDA runtime is shut down when the process is terminated
     # this means that this process will not use up CUDA memory all the time
@@ -96,6 +95,7 @@ def node_runner(feedback_queue, job_queue, node_id: int):
             # cannot use None as termination signal since that is already the timeout signal
            return  # or check if processes are still running?
 
+        import dill
         job_data = dill.loads(job_str)
         # print(f'DEBUG: got job data', flush=True)
         processes.append(FunctionProcess(JobRunner(*job_data)).start())
@@ -193,6 +193,7 @@ def get_resource_manager(self) -> ResourceManager:
         return self.resource_manager
 
     def submit_job(self, job_info: JobInfo) -> None:
+        import dill
         if self.resource_manager is None:
             raise RuntimeError('called submit_job() before start()')
         job = job_info.job
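The change from 'spawn' to 'fork' above corresponds to the release note about speeding up benchmarking: on Linux, a forked child inherits the parent process's memory and already-imported modules, while 'spawn' starts a fresh interpreter and re-imports everything. A minimal standalone sketch (not from the package) comparing the startup cost of the two start methods:

```python
import multiprocessing as mp
import time


def child() -> None:
    pass  # trivial worker; we only measure process startup cost


if __name__ == '__main__':
    for method in ('fork', 'spawn'):  # 'fork' is only available on Unix
        ctx = mp.get_context(method)
        t0 = time.perf_counter()
        p = ctx.Process(target=child)
        p.start()
        p.join()
        print(f'{method}: {time.perf_counter() - t0:.3f} s')
```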

pytabkit/models/alg_interfaces/tabm_interface.py

Lines changed: 22 additions & 11 deletions
@@ -18,7 +18,7 @@
 from pytabkit.models.nn_models import rtdl_num_embeddings
 from pytabkit.models.nn_models.base import Fitter
 from pytabkit.models.nn_models.models import PreprocessingFactory
-from pytabkit.models.nn_models.tabm import Model
+from pytabkit.models.nn_models.tabm import Model, make_parameter_groups
 from pytabkit.models.training.logging import Logger
 
 
@@ -56,6 +56,8 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
         allow_amp = self.config.get('allow_amp', False)
         n_blocks = self.config.get('n_blocks', 'auto')
         num_emb_n_bins = self.config.get('num_emb_n_bins', 48)
+        # the newer TabM version defaults to share_training_batches=False (the old version used True)
+        share_training_batches = self.config.get("share_training_batches", False)
 
         weight_decay = self.config.get('weight_decay', 0.0)
         gradient_clipping_norm = self.config.get('gradient_clipping_norm', None)
@@ -180,8 +182,9 @@
             ),
             arch_type=arch_type,
             k=tabm_k,
+            share_training_batches=share_training_batches,
         ).to(device)
-        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+        optimizer = torch.optim.AdamW(make_parameter_groups(model), lr=lr, weight_decay=weight_decay)
 
 
         if compile_model:
@@ -210,8 +213,11 @@ def loss_fn(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             # TabM produces k predictions per object. Each of them must be trained separately.
             # (regression)     y_pred.shape == (batch_size, k)
             # (classification) y_pred.shape == (batch_size, k, n_classes)
-            k = y_pred.shape[-1 if task_type == 'regression' else -2]
-            return base_loss_fn(y_pred.flatten(0, 1), y_true.repeat_interleave(k))
+            k = y_pred.shape[1]
+            return base_loss_fn(
+                y_pred.flatten(0, 1),
+                y_true.repeat_interleave(k) if model.share_training_batches else y_true.squeeze(-1),
+            )
 
         @evaluation_mode()
         def evaluate(part: str) -> float:
@@ -270,17 +276,22 @@ def evaluate(part: str) -> float:
             if self.config.get('verbosity', 0) >= 1:
                 from tqdm.std import tqdm
             else:
-                tqdm = lambda arr, desc, total: arr
+                tqdm = lambda arr, desc: arr
         except ImportError:
-            tqdm = lambda arr, desc, total: arr
+            tqdm = lambda arr, desc: arr
 
         logger.log(1, '-' * 88 + '\n')
         for epoch in range(n_epochs):
-            for batch_idx in tqdm(
-                    torch.randperm(len(data['train']['y']), device=device).split(batch_size),
-                    desc=f'Epoch {epoch}',
-                    total=epoch_size,
-            ):
+            batches = (
+                torch.randperm(n_train, device=device).split(batch_size)
+                if model.share_training_batches
+                else [
+                    x.transpose(0, 1).flatten()
+                    for x in torch.rand((model.k, n_train), device=device).argsort(dim=1).split(batch_size, dim=1)
+                ]
+            )
+
+            for batch_idx in tqdm(batches, desc=f"Epoch {epoch}"):
                 model.train()
                 optimizer.zero_grad()
                 loss = loss_fn(apply_model('train', batch_idx), Y_train[batch_idx])
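To clarify the new batch construction above: when share_training_batches is False, each of the k ensemble members gets its own random permutation of the training indices, and the chunks are interleaved so that consecutive groups of k indices in a batch hold one index per member. Here is a minimal standalone sketch (with made-up sizes, independent of pytabkit):

```python
import torch

k, n_train, batch_size = 3, 8, 4  # made-up sizes for illustration

# one independent random permutation of the training indices per ensemble member
perms = torch.rand((k, n_train)).argsort(dim=1)  # shape (k, n_train)

# split along the sample dimension and interleave, as in the diff above
batches = [
    x.transpose(0, 1).flatten()              # shape (chunk_size * k,)
    for x in perms.split(batch_size, dim=1)  # chunks of shape (k, chunk_size)
]

for batch_idx in batches:
    # each row holds one index per ensemble member for the same batch position
    print(batch_idx.reshape(-1, k))
```

This also explains why the loss above uses `y_true.squeeze(-1)` instead of `repeat_interleave` in this case: the batch already contains one target per ensemble member.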
