11from __future__ import annotations
22
33import logging
4-
5- import pandas as pd
4+ from typing import TYPE_CHECKING
65
76from autogluon .common .utils .pandas_utils import get_approximate_df_mem_usage
87from autogluon .common .utils .resource_utils import ResourceManager
98from autogluon .core .models import AbstractModel
109from autogluon .tabular import __version__
1110
11+ if TYPE_CHECKING :
12+ import pandas as pd
13+
1214logger = logging .getLogger (__name__ )
1315
1416
15- # TODO: Verify if crashes when weights are not yet downloaded and fit in parallel
16- class TabICLModel (AbstractModel ):
17- """
18- TabICL is a foundation model for tabular data using in-context learning
17+ class TabICLModelBase (AbstractModel ):
18+ """TabICL is a foundation model for tabular data using in-context learning
1919 that is scalable to larger datasets than TabPFNv2. It is pretrained purely on synthetic data.
2020 TabICL currently only supports classification tasks.
2121
@@ -26,27 +26,57 @@ class TabICLModel(AbstractModel):
2626 Codebase: https://github.com/soda-inria/tabicl
2727 License: BSD-3-Clause
2828 """
29- ag_key = "TA-TABICL"
30- ag_name = "TA-TabICL"
29+
30+ ag_key = "NOTSET"
31+ ag_name = "NOTSET"
3132 ag_priority = 65
33+ seed_name = "random_state"
3234
33- def get_model_cls ( self ):
34- from tabicl import TabICLClassifier
35+ default_classification_model : str | None = None
36+ default_regression_model : str | None = None
3537
38+ def get_model_cls (self ):
3639 if self .problem_type in ["binary" , "multiclass" ]:
40+ from tabicl import TabICLClassifier
41+
3742 model_cls = TabICLClassifier
3843 else :
39- raise AssertionError (f"Unsupported problem_type: { self .problem_type } " )
44+ from tabicl import TabICLRegressor
45+
46+ model_cls = TabICLRegressor
4047 return model_cls
4148
49+ def get_checkpoint_version (self , hyperparameter : dict ) -> str :
50+ clf_checkpoint = self .default_classification_model
51+ reg_checkpoint = self .default_regression_model
52+
53+ # Resolve HPO
54+ if "checkpoint_version" in hyperparameter :
55+ if isinstance (hyperparameter ["checkpoint_version" ], str ):
56+ clf_checkpoint = hyperparameter ["checkpoint_version" ]
57+ reg_checkpoint = hyperparameter ["checkpoint_version" ]
58+ elif isinstance (hyperparameter ["checkpoint_version" ], tuple ):
59+ clf_checkpoint = hyperparameter ["checkpoint_version" ][0 ]
60+ reg_checkpoint = hyperparameter ["checkpoint_version" ][1 ]
61+ else :
62+ raise ValueError (
63+ "checkpoint_version hyperparameter must be either "
64+ "a string or a tuple of two strings (clf, reg)."
65+ )
66+
67+ if self .problem_type in ["binary" , "multiclass" ]:
68+ return clf_checkpoint
69+
70+ return reg_checkpoint
71+
72+ # TODO: is this still correct for TabICLv2?
4273 @staticmethod
4374 def _get_batch_size (n_cells : int ):
4475 if n_cells <= 4_000_000 :
4576 return 8
46- elif n_cells <= 6_000_000 :
77+ if n_cells <= 6_000_000 :
4778 return 4
48- else :
49- return 2
79+ return 2
5080
5181 def _fit (
5282 self ,
@@ -78,7 +108,11 @@ def _fit(
78108
79109 model_cls = self .get_model_cls ()
80110 hyp = self ._get_model_params ()
81- hyp ["batch_size" ] = hyp .get ("batch_size" , self ._get_batch_size (X .shape [0 ] * X .shape [1 ]))
111+ hyp ["batch_size" ] = hyp .get (
112+ "batch_size" , self ._get_batch_size (X .shape [0 ] * X .shape [1 ])
113+ )
114+ hyp ["checkpoint_version" ] = self .get_checkpoint_version (hyperparameter = hyp )
115+
82116 self .model = model_cls (
83117 ** hyp ,
84118 device = device ,
@@ -90,77 +124,76 @@ def _fit(
90124 y = y ,
91125 )
92126
93- def _set_default_params (self ):
94- default_params = {
95- "random_state" : 42 ,
96- }
97- for param , val in default_params .items ():
98- self ._set_default_param_value (param , val )
99-
100- @classmethod
101- def supported_problem_types (cls ) -> list [str ] | None :
102- return ["binary" , "multiclass" ]
103-
104127 def _get_default_resources (self ) -> tuple [int , int ]:
105128 # Use only physical cores for better performance based on benchmarks
106129 num_cpus = ResourceManager .get_cpu_count (only_physical_cores = True )
107130
108131 num_gpus = min (1 , ResourceManager .get_gpu_count_torch (cuda_only = True ))
109132 return num_cpus , num_gpus
110133
111- def get_minimum_resources (self , is_gpu_available : bool = False ) -> dict [str , int | float ]:
134+ def get_minimum_resources (
135+ self , is_gpu_available : bool = False
136+ ) -> dict [str , int | float ]:
112137 return {
113138 "num_cpus" : 1 ,
114139 "num_gpus" : 1 if is_gpu_available else 0 ,
115140 }
116141
117142 def _estimate_memory_usage (self , X : pd .DataFrame , ** kwargs ) -> int :
118143 hyperparameters = self ._get_model_params ()
119- return self .estimate_memory_usage_static (X = X , problem_type = self .problem_type , num_classes = self .num_classes , hyperparameters = hyperparameters , ** kwargs )
144+ return self .estimate_memory_usage_static (
145+ X = X ,
146+ problem_type = self .problem_type ,
147+ num_classes = self .num_classes ,
148+ hyperparameters = hyperparameters ,
149+ ** kwargs ,
150+ )
120151
152+ # TODO: move memory estimate to specific models below.
121153 @classmethod
122154 def _estimate_memory_usage_static (
123155 cls ,
124156 * ,
125157 X : pd .DataFrame ,
126- hyperparameters : dict = None ,
158+ hyperparameters : dict | None = None ,
127159 ** kwargs ,
128160 ) -> int :
129- """
130- Heuristic memory estimate that is very primitive.
161+ """Heuristic memory estimate that is very primitive.
131162 Can be vastly improved.
132163 """
133164 if hyperparameters is None :
134165 hyperparameters = {}
135166
136- dataset_size_mem_est = 3 * get_approximate_df_mem_usage (X ).sum () # roughly 3x DataFrame memory size
167+ dataset_size_mem_est = (
168+ 3 * get_approximate_df_mem_usage (X ).sum ()
169+ ) # roughly 3x DataFrame memory size
137170 baseline_overhead_mem_est = 1e9 # 1 GB generic overhead
138171
139172 n_rows = X .shape [0 ]
140173 n_features = X .shape [1 ]
141- batch_size = hyperparameters .get ("batch_size" , cls ._get_batch_size (X .shape [0 ] * X .shape [1 ]))
174+ batch_size = hyperparameters .get (
175+ "batch_size" , cls ._get_batch_size (X .shape [0 ] * X .shape [1 ])
176+ )
142177 embedding_dim = 128
143178 bytes_per_float = 4
144- model_mem_estimate = 2 * batch_size * embedding_dim * bytes_per_float * (4 + n_rows ) * n_features
179+ model_mem_estimate = (
180+ 2 * batch_size * embedding_dim * bytes_per_float * (4 + n_rows ) * n_features
181+ )
145182
146183 model_mem_estimate *= 1.3 # add 30% buffer
147184
148185 # TODO: Observed memory spikes above expected values on large datasets, increasing mem estimate to compensate
149186 model_mem_estimate *= 2.0 # Note: 1.5 is not large enough, still gets OOM
150187
151- mem_estimate = model_mem_estimate + dataset_size_mem_est + baseline_overhead_mem_est
152-
153- return mem_estimate
188+ return model_mem_estimate + dataset_size_mem_est + baseline_overhead_mem_est
154189
155190 @classmethod
156191 def _get_default_ag_args_ensemble (cls , ** kwargs ) -> dict :
157- """
158- Set fold_fitting_strategy to sequential_local,
192+ """Set fold_fitting_strategy to sequential_local,
159193 as parallel folding crashes if model weights aren't pre-downloaded.
160194 """
161195 default_ag_args_ensemble = super ()._get_default_ag_args_ensemble (** kwargs )
162196 extra_ag_args_ensemble = {
163- # FIXME: If parallel, uses way more memory, seems to behave incorrectly, so we force sequential.
164197 "fold_fitting_strategy" : "sequential_local" ,
165198 "refit_folds" : True , # Better to refit the model for faster inference and similar quality as the bag.
166199 }
@@ -173,3 +206,57 @@ def _class_tags(cls) -> dict:
173206
174207 def _more_tags (self ) -> dict :
175208 return {"can_refit_full" : True }
209+
    @staticmethod
    def checkpoint_search_space() -> list[str | tuple[str, str]]:
        """Return the checkpoint names to search over during HPO.

        Subclasses must override this. Entries are either a single checkpoint
        name or a (classification, regression) checkpoint pair.
        """
        raise NotImplementedError("This method must be implemented in the subclass.")
213+
214+
class TabICLModel(TabICLModelBase):
    """TabICLv1.1 model as used on TabArena."""

    ag_key = "TA-TABICL"
    ag_name = "TA-TabICL"

    default_classification_model: str | None = "tabicl-classifier-v1.1-20250506.ckpt"

    @classmethod
    def supported_problem_types(cls) -> list[str] | None:
        """TabICLv1 checkpoints only cover classification tasks."""
        return ["binary", "multiclass"]

    @staticmethod
    def checkpoint_search_space() -> list[str]:
        """Classifier checkpoints considered during HPO (newest first)."""
        return [
            "tabicl-classifier-v1.1-20250506.ckpt",
            "tabicl-classifier-v1-20250208.ckpt",
        ]

    def _set_default_params(self):
        """Apply TabICLv1 fit defaults unless the user overrode them."""
        self._set_default_param_value("n_estimators", 32)  # default of TabICLv1
240+
class TabICLv2Model(TabICLModelBase):
    """TabICLv2 model as used on TabArena."""

    ag_key = "TA-TABICLv2"
    ag_name = "TA-TabICLv2"

    default_classification_model: str | None = "tabicl-classifier-v2-20260212.ckpt"
    default_regression_model: str | None = "tabicl-regressor-v2-20260212.ckpt"

    @classmethod
    def supported_problem_types(cls) -> list[str] | None:
        """TabICLv2 adds regression support on top of classification."""
        return ["binary", "multiclass", "regression"]

    # TODO: search over v1 checkpoints too?
    @staticmethod
    def checkpoint_search_space() -> list[tuple[str, str]]:
        """(classification, regression) checkpoint pairs considered during HPO."""
        checkpoint_pair = (
            "tabicl-classifier-v2-20260212.ckpt",
            "tabicl-regressor-v2-20260212.ckpt",
        )
        return [checkpoint_pair]
0 commit comments