Skip to content

Commit fb1cabe

Browse files
[New Model] TabICLv2 (#259)
* add: TabICLv2
* maint: slurm ray fix and misc
* maint: final PR changes
* fix: typo
* maint: update venv for PR
* fix: ensure correct n_est for v1
1 parent ac9e0d5 commit fb1cabe

File tree

12 files changed

+235
-85
lines changed

12 files changed

+235
-85
lines changed

tabarena/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ dependencies = [
4141
tabpfn = [
4242
"tabpfn>=6.0.5", # We used version 6.0.5
4343
]
44-
tabicl = ["tabicl>=0.1.1"]
44+
tabicl = ["tabicl>=2.0.0"]
4545
ebm = ["interpret-core>=0.7.3"]
4646
search_spaces = ["configspace>=1.2,<2.0"]
4747
realmlp = ["pytabkit>=1.5.0,<2.0"]
@@ -57,7 +57,7 @@ tabprep = []
5757
# union of all above extras (mirrors your "benchmark" extra)
5858
benchmark = [
5959
"tabpfn>=6.0.5",
60-
"tabicl>=0.1.1",
60+
"tabicl>=2.0.0",
6161
"interpret-core>=0.7.3",
6262
"configspace>=1.2,<2.0",
6363
"pytabkit>=1.5.0,<2.0",

tabarena/tabarena/benchmark/models/ag/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from tabarena.benchmark.models.ag.realmlp.realmlp_model import RealMLPModel
77
from tabarena.benchmark.models.ag.sap_rpt_oss.sap_rpt_oss_model import SAPRPTOSSModel
88
from tabarena.benchmark.models.ag.tabdpt.tabdpt_model import TabDPTModel
9-
from tabarena.benchmark.models.ag.tabicl.tabicl_model import TabICLModel
9+
from tabarena.benchmark.models.ag.tabicl.tabicl_model import TabICLModel, TabICLv2Model
1010
from tabarena.benchmark.models.ag.tabm.tabm_model import TabMModel
1111
from tabarena.benchmark.models.ag.tabpfnv2_5.tabpfnv2_5_model import RealTabPFNv25Model
1212
from tabarena.benchmark.models.ag.xrfm.xrfm_model import XRFMModel
@@ -20,6 +20,7 @@
2020
"SAPRPTOSSModel",
2121
"TabDPTModel",
2222
"TabICLModel",
23+
"TabICLv2Model",
2324
"TabMModel",
2425
"XRFMModel",
2526
]
Lines changed: 127 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -1,21 +1,21 @@
11
from __future__ import annotations
22

33
import logging
4-
5-
import pandas as pd
4+
from typing import TYPE_CHECKING
65

76
from autogluon.common.utils.pandas_utils import get_approximate_df_mem_usage
87
from autogluon.common.utils.resource_utils import ResourceManager
98
from autogluon.core.models import AbstractModel
109
from autogluon.tabular import __version__
1110

11+
if TYPE_CHECKING:
12+
import pandas as pd
13+
1214
logger = logging.getLogger(__name__)
1315

1416

15-
# TODO: Verify if crashes when weights are not yet downloaded and fit in parallel
16-
class TabICLModel(AbstractModel):
17-
"""
18-
TabICL is a foundation model for tabular data using in-context learning
17+
class TabICLModelBase(AbstractModel):
18+
"""TabICL is a foundation model for tabular data using in-context learning
1919
that is scalable to larger datasets than TabPFNv2. It is pretrained purely on synthetic data.
2020
TabICL currently only supports classification tasks.
2121
@@ -26,27 +26,57 @@ class TabICLModel(AbstractModel):
2626
Codebase: https://github.com/soda-inria/tabicl
2727
License: BSD-3-Clause
2828
"""
29-
ag_key = "TA-TABICL"
30-
ag_name = "TA-TabICL"
29+
30+
ag_key = "NOTSET"
31+
ag_name = "NOTSET"
3132
ag_priority = 65
33+
seed_name = "random_state"
3234

33-
def get_model_cls(self):
34-
from tabicl import TabICLClassifier
35+
default_classification_model: str | None = None
36+
default_regression_model: str | None = None
3537

38+
def get_model_cls(self):
3639
if self.problem_type in ["binary", "multiclass"]:
40+
from tabicl import TabICLClassifier
41+
3742
model_cls = TabICLClassifier
3843
else:
39-
raise AssertionError(f"Unsupported problem_type: {self.problem_type}")
44+
from tabicl import TabICLRegressor
45+
46+
model_cls = TabICLRegressor
4047
return model_cls
4148

49+
def get_checkpoint_version(self, hyperparameter: dict) -> str:
50+
clf_checkpoint = self.default_classification_model
51+
reg_checkpoint = self.default_regression_model
52+
53+
# Resolve HPO
54+
if "checkpoint_version" in hyperparameter:
55+
if isinstance(hyperparameter["checkpoint_version"], str):
56+
clf_checkpoint = hyperparameter["checkpoint_version"]
57+
reg_checkpoint = hyperparameter["checkpoint_version"]
58+
elif isinstance(hyperparameter["checkpoint_version"], tuple):
59+
clf_checkpoint = hyperparameter["checkpoint_version"][0]
60+
reg_checkpoint = hyperparameter["checkpoint_version"][1]
61+
else:
62+
raise ValueError(
63+
"checkpoint_version hyperparameter must be either "
64+
"a string or a tuple of two strings (clf, reg)."
65+
)
66+
67+
if self.problem_type in ["binary", "multiclass"]:
68+
return clf_checkpoint
69+
70+
return reg_checkpoint
71+
72+
# TODO: is this still correct for TabICLv2?
4273
@staticmethod
4374
def _get_batch_size(n_cells: int):
4475
if n_cells <= 4_000_000:
4576
return 8
46-
elif n_cells <= 6_000_000:
77+
if n_cells <= 6_000_000:
4778
return 4
48-
else:
49-
return 2
79+
return 2
5080

5181
def _fit(
5282
self,
@@ -78,7 +108,11 @@ def _fit(
78108

79109
model_cls = self.get_model_cls()
80110
hyp = self._get_model_params()
81-
hyp["batch_size"] = hyp.get("batch_size", self._get_batch_size(X.shape[0] * X.shape[1]))
111+
hyp["batch_size"] = hyp.get(
112+
"batch_size", self._get_batch_size(X.shape[0] * X.shape[1])
113+
)
114+
hyp["checkpoint_version"] = self.get_checkpoint_version(hyperparameter=hyp)
115+
82116
self.model = model_cls(
83117
**hyp,
84118
device=device,
@@ -90,77 +124,76 @@ def _fit(
90124
y=y,
91125
)
92126

93-
def _set_default_params(self):
94-
default_params = {
95-
"random_state": 42,
96-
}
97-
for param, val in default_params.items():
98-
self._set_default_param_value(param, val)
99-
100-
@classmethod
101-
def supported_problem_types(cls) -> list[str] | None:
102-
return ["binary", "multiclass"]
103-
104127
def _get_default_resources(self) -> tuple[int, int]:
105128
# Use only physical cores for better performance based on benchmarks
106129
num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
107130

108131
num_gpus = min(1, ResourceManager.get_gpu_count_torch(cuda_only=True))
109132
return num_cpus, num_gpus
110133

111-
def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str, int | float]:
134+
def get_minimum_resources(
135+
self, is_gpu_available: bool = False
136+
) -> dict[str, int | float]:
112137
return {
113138
"num_cpus": 1,
114139
"num_gpus": 1 if is_gpu_available else 0,
115140
}
116141

117142
def _estimate_memory_usage(self, X: pd.DataFrame, **kwargs) -> int:
118143
hyperparameters = self._get_model_params()
119-
return self.estimate_memory_usage_static(X=X, problem_type=self.problem_type, num_classes=self.num_classes, hyperparameters=hyperparameters, **kwargs)
144+
return self.estimate_memory_usage_static(
145+
X=X,
146+
problem_type=self.problem_type,
147+
num_classes=self.num_classes,
148+
hyperparameters=hyperparameters,
149+
**kwargs,
150+
)
120151

152+
# TODO: move memory estimate to specific models below.
121153
@classmethod
122154
def _estimate_memory_usage_static(
123155
cls,
124156
*,
125157
X: pd.DataFrame,
126-
hyperparameters: dict = None,
158+
hyperparameters: dict | None = None,
127159
**kwargs,
128160
) -> int:
129-
"""
130-
Heuristic memory estimate that is very primitive.
161+
"""Heuristic memory estimate that is very primitive.
131162
Can be vastly improved.
132163
"""
133164
if hyperparameters is None:
134165
hyperparameters = {}
135166

136-
dataset_size_mem_est = 3 * get_approximate_df_mem_usage(X).sum() # roughly 3x DataFrame memory size
167+
dataset_size_mem_est = (
168+
3 * get_approximate_df_mem_usage(X).sum()
169+
) # roughly 3x DataFrame memory size
137170
baseline_overhead_mem_est = 1e9 # 1 GB generic overhead
138171

139172
n_rows = X.shape[0]
140173
n_features = X.shape[1]
141-
batch_size = hyperparameters.get("batch_size", cls._get_batch_size(X.shape[0] * X.shape[1]))
174+
batch_size = hyperparameters.get(
175+
"batch_size", cls._get_batch_size(X.shape[0] * X.shape[1])
176+
)
142177
embedding_dim = 128
143178
bytes_per_float = 4
144-
model_mem_estimate = 2 * batch_size * embedding_dim * bytes_per_float * (4 + n_rows) * n_features
179+
model_mem_estimate = (
180+
2 * batch_size * embedding_dim * bytes_per_float * (4 + n_rows) * n_features
181+
)
145182

146183
model_mem_estimate *= 1.3 # add 30% buffer
147184

148185
# TODO: Observed memory spikes above expected values on large datasets, increasing mem estimate to compensate
149186
model_mem_estimate *= 2.0 # Note: 1.5 is not large enough, still gets OOM
150187

151-
mem_estimate = model_mem_estimate + dataset_size_mem_est + baseline_overhead_mem_est
152-
153-
return mem_estimate
188+
return model_mem_estimate + dataset_size_mem_est + baseline_overhead_mem_est
154189

155190
@classmethod
156191
def _get_default_ag_args_ensemble(cls, **kwargs) -> dict:
157-
"""
158-
Set fold_fitting_strategy to sequential_local,
192+
"""Set fold_fitting_strategy to sequential_local,
159193
as parallel folding crashes if model weights aren't pre-downloaded.
160194
"""
161195
default_ag_args_ensemble = super()._get_default_ag_args_ensemble(**kwargs)
162196
extra_ag_args_ensemble = {
163-
# FIXME: If parallel, uses way more memory, seems to behave incorrectly, so we force sequential.
164197
"fold_fitting_strategy": "sequential_local",
165198
"refit_folds": True, # Better to refit the model for faster inference and similar quality as the bag.
166199
}
@@ -173,3 +206,57 @@ def _class_tags(cls) -> dict:
173206

174207
def _more_tags(self) -> dict:
175208
return {"can_refit_full": True}
209+
210+
@staticmethod
211+
def checkpoint_search_space() -> list[str | tuple[str, str]]:
212+
raise NotImplementedError("This method must be implemented in the subclass.")
213+
214+
215+
class TabICLModel(TabICLModelBase):
216+
"""TabICLv1.1 model as used on TabArena."""
217+
218+
ag_key = "TA-TABICL"
219+
ag_name = "TA-TabICL"
220+
221+
default_classification_model: str | None = "tabicl-classifier-v1.1-20250506.ckpt"
222+
223+
@classmethod
224+
def supported_problem_types(cls) -> list[str] | None:
225+
return ["binary", "multiclass"]
226+
227+
@staticmethod
228+
def checkpoint_search_space() -> list[str]:
229+
return [
230+
"tabicl-classifier-v1.1-20250506.ckpt",
231+
"tabicl-classifier-v1-20250208.ckpt",
232+
]
233+
234+
def _set_default_params(self):
235+
default_params = {
236+
"n_estimators": 32, # default of TabICLv1
237+
}
238+
for param, val in default_params.items():
239+
self._set_default_param_value(param, val)
240+
241+
class TabICLv2Model(TabICLModelBase):
242+
"""TabICLv2 model as used on TabArena."""
243+
244+
ag_key = "TA-TABICLv2"
245+
ag_name = "TA-TabICLv2"
246+
247+
default_classification_model: str | None = "tabicl-classifier-v2-20260212.ckpt"
248+
default_regression_model: str | None = "tabicl-regressor-v2-20260212.ckpt"
249+
250+
@classmethod
251+
def supported_problem_types(cls) -> list[str] | None:
252+
return ["binary", "multiclass", "regression"]
253+
254+
# TODO: search over v1 checkpoints too?
255+
@staticmethod
256+
def checkpoint_search_space() -> list[tuple[str, str]]:
257+
return [
258+
(
259+
"tabicl-classifier-v2-20260212.ckpt",
260+
"tabicl-regressor-v2-20260212.ckpt",
261+
)
262+
]

tabarena/tabarena/benchmark/models/model_registry.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
SAPRPTOSSModel,
1414
TabDPTModel,
1515
TabICLModel,
16+
TabICLv2Model,
1617
TabMModel,
1718
XRFMModel,
1819
)
@@ -30,6 +31,7 @@
3031
KNNNewModel,
3132
RealTabPFNv25Model,
3233
SAPRPTOSSModel,
34+
TabICLv2Model,
3335
]
3436

3537
for _model_cls in _models_to_add:
Lines changed: 31 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,21 @@
11
from __future__ import annotations
22

3+
from copy import deepcopy
4+
35
from autogluon.common.space import Categorical, Real
46

5-
from tabarena.benchmark.models.ag.tabicl.tabicl_model import TabICLModel
7+
from tabarena.benchmark.models.ag.tabicl.tabicl_model import (
8+
TabICLModel,
9+
TabICLModelBase,
10+
TabICLv2Model,
11+
)
612
from tabarena.utils.config_utils import ConfigGenerator
713

8-
name = "TabICL"
9-
manual_configs = [
10-
# Default config with refit after cross-validation.
11-
{"ag_args_ensemble": {"refit_folds": True}},
12-
]
13-
1414
# Unofficial search space
15-
search_space = {
16-
"checkpoint_version": Categorical("tabicl-classifier-v1.1-0506.ckpt", "tabicl-classifier-v1-0208.ckpt"),
17-
"norm_methods": Categorical("none", "power", "robust", "quantile_rtdl", ["none", "power"]),
15+
base_search_space = {
16+
"norm_methods": Categorical(
17+
"none", "power", "robust", "quantile_rtdl", ["none", "power"]
18+
),
1819
# just in case, tuning between TabICL and TabPFN defaults
1920
"outlier_threshold": Real(4.0, 12.0),
2021
"average_logits": Categorical(False, True),
@@ -24,9 +25,20 @@
2425
"ag_args_ensemble": Categorical({"refit_folds": True}),
2526
}
2627

27-
gen_tabicl = ConfigGenerator(
28-
model_cls=TabICLModel, manual_configs=manual_configs, search_space=search_space
29-
)
28+
29+
def get_gen_function(model_cls: TabICLModelBase):
30+
search_space = deepcopy(base_search_space)
31+
search_space["checkpoint_version"] = Categorical(
32+
*model_cls.checkpoint_search_space()
33+
)
34+
return ConfigGenerator(
35+
model_cls=model_cls, manual_configs=[{}], search_space=search_space
36+
)
37+
38+
39+
gen_tabicl = get_gen_function(TabICLModel)
40+
41+
gen_tabiclv2 = get_gen_function(TabICLv2Model)
3042

3143
if __name__ == "__main__":
3244
from tabarena.benchmark.experiment import YamlExperimentSerializer
@@ -36,3 +48,9 @@
3648
experiments=gen_tabicl.generate_all_bag_experiments(num_random_configs=0),
3749
),
3850
)
51+
52+
print(
53+
YamlExperimentSerializer.to_yaml_str(
54+
experiments=gen_tabiclv2.generate_all_bag_experiments(num_random_configs=0),
55+
),
56+
)

tabarena/tabarena/models/utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ def get_configs_generator_from_name(model_name: str):
4848
"xRFM": lambda: importlib.import_module("tabarena.models.xrfm.generate").gen_xrfm,
4949
"RealTabPFN-v2.5": lambda: importlib.import_module("tabarena.models.tabpfnv2_5.generate").gen_realtabpfnv25,
5050
"SAP-RPT-OSS": lambda: importlib.import_module("tabarena.models.sap_rpt_oss.generate").gen_sap_rpt_oss,
51+
"TabICLv2": lambda: importlib.import_module("tabarena.models.tabicl.generate").gen_tabiclv2,
5152
}
5253

5354
if model_name not in name_to_import_map:

0 commit comments

Comments (0)