Skip to content

Commit 2ac64cd

Browse files
authored
Move dtype to config (#189)
1 parent baf59c7 commit 2ac64cd

File tree

5 files changed

+85
-38
lines changed

5 files changed

+85
-38
lines changed

src/pyoframe/_arithmetic.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from pyoframe._constants import (
1010
COEF_KEY,
1111
CONST_TERM,
12-
KEY_TYPE,
1312
QUAD_VAR_KEY,
1413
RESERVED_COL_KEYS,
1514
VAR_KEY,
@@ -243,7 +242,9 @@ def add(*expressions: Expression) -> Expression:
243242
if any(QUAD_VAR_KEY in df.columns for df in expr_data):
244243
expr_data = [
245244
(
246-
df.with_columns(pl.lit(CONST_TERM).alias(QUAD_VAR_KEY).cast(KEY_TYPE))
245+
df.with_columns(
246+
pl.lit(CONST_TERM).alias(QUAD_VAR_KEY).cast(Config.id_dtype)
247+
)
247248
if QUAD_VAR_KEY not in df.columns
248249
else df
249250
)
@@ -523,7 +524,7 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
523524
if df.is_empty():
524525
df = pl.DataFrame(
525526
{VAR_KEY: [CONST_TERM], COEF_KEY: [0]},
526-
schema={VAR_KEY: KEY_TYPE, COEF_KEY: pl.Float64},
527+
schema={VAR_KEY: Config.id_dtype, COEF_KEY: pl.Float64},
527528
)
528529

529530
if QUAD_VAR_KEY in df.columns and (df.get_column(QUAD_VAR_KEY) == CONST_TERM).all():

src/pyoframe/_constants.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@
1717
SOLUTION_KEY = "solution"
1818
DUAL_KEY = "dual"
1919

20-
# TODO: move as configuration since this could be too small... also add a test to make sure errors occur on overflow.
21-
KEY_TYPE = pl.UInt32
22-
2320

2421
@dataclass
2522
class _Solver:
@@ -101,6 +98,7 @@ class ConfigDefaults:
10198
)
10299
print_max_terms: int = 5
103100
maintain_order: bool = True
101+
id_dtype = pl.UInt32
104102

105103

106104
class _Config:
@@ -322,6 +320,41 @@ def maintain_order(self) -> bool:
322320
def maintain_order(self, value: bool):
323321
self._settings.maintain_order = value
324322

323+
@property
324+
def id_dtype(self):
325+
"""The Polars data type to use for variable and constraint IDs.
326+
327+
Defaults to `pl.UInt32` which should be ideal for most users.
328+
329+
Users with more than 4 billion variables or constraints can change this to `pl.UInt64`.
330+
331+
Users concerned with memory usage who have fewer than roughly 65,000 variables or constraints can change this to `pl.UInt16`.
332+
333+
!!! warning
334+
Changing this setting after creating a model will lead to errors.
335+
You should only change this setting before creating any models.
336+
337+
Examples:
338+
An error is automatically raised if the number of variables or constraints exceeds what the chosen data type can represent:
339+
>>> pf.Config.id_dtype = pl.UInt8
340+
>>> m = pf.Model()
341+
>>> big_set = pf.Set(x=range(2**8 + 1))
342+
>>> m.X = pf.Variable()
343+
>>> m.constraint = m.X.over("x") <= big_set
344+
Traceback (most recent call last):
345+
...
346+
TypeError: Number of constraints exceeds the current data type (UInt8). Consider increasing the data type by changing Config.id_dtype.
347+
>>> m.X_large = pf.Variable(big_set)
348+
Traceback (most recent call last):
349+
...
350+
TypeError: Number of variables exceeds the current data type (UInt8). Consider increasing the data type by changing Config.id_dtype.
351+
"""
352+
return self._settings.id_dtype
353+
354+
@id_dtype.setter
355+
def id_dtype(self, value):
356+
self._settings.id_dtype = value
357+
325358
def reset_defaults(self):
326359
"""Resets all configuration options to their default values.
327360

src/pyoframe/_core.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
CONST_TERM,
2323
CONSTRAINT_KEY,
2424
DUAL_KEY,
25-
KEY_TYPE,
2625
QUAD_VAR_KEY,
2726
RESERVED_COL_KEYS,
2827
SOLUTION_KEY,
@@ -701,7 +700,7 @@ def constant(cls, constant: int | float) -> Expression:
701700
COEF_KEY: [constant],
702701
VAR_KEY: [CONST_TERM],
703702
},
704-
schema={COEF_KEY: pl.Float64, VAR_KEY: KEY_TYPE},
703+
schema={COEF_KEY: pl.Float64, VAR_KEY: Config.id_dtype},
705704
),
706705
name=str(constant),
707706
)
@@ -1215,11 +1214,11 @@ def _add_const(self, const: int | float) -> Expression:
12151214
if CONST_TERM not in data[VAR_KEY]:
12161215
const_df = pl.DataFrame(
12171216
{COEF_KEY: [0.0], VAR_KEY: [CONST_TERM]},
1218-
schema={COEF_KEY: pl.Float64, VAR_KEY: KEY_TYPE},
1217+
schema={COEF_KEY: pl.Float64, VAR_KEY: Config.id_dtype},
12191218
)
12201219
if self.is_quadratic:
12211220
const_df = const_df.with_columns(
1222-
pl.lit(CONST_TERM).alias(QUAD_VAR_KEY).cast(KEY_TYPE)
1221+
pl.lit(CONST_TERM).alias(QUAD_VAR_KEY).cast(Config.id_dtype)
12231222
)
12241223
data = pl.concat(
12251224
[data, const_df],
@@ -1229,11 +1228,11 @@ def _add_const(self, const: int | float) -> Expression:
12291228
keys = (
12301229
data.select(dim)
12311230
.unique(maintain_order=Config.maintain_order)
1232-
.with_columns(pl.lit(CONST_TERM).alias(VAR_KEY).cast(KEY_TYPE))
1231+
.with_columns(pl.lit(CONST_TERM).alias(VAR_KEY).cast(Config.id_dtype))
12331232
)
12341233
if self.is_quadratic:
12351234
keys = keys.with_columns(
1236-
pl.lit(CONST_TERM).alias(QUAD_VAR_KEY).cast(KEY_TYPE)
1235+
pl.lit(CONST_TERM).alias(QUAD_VAR_KEY).cast(Config.id_dtype)
12371236
)
12381237
data = data.join(
12391238
keys,
@@ -1276,7 +1275,7 @@ def constant_terms(self) -> pl.DataFrame:
12761275
if len(constant_terms) == 0:
12771276
return pl.DataFrame(
12781277
{COEF_KEY: [0.0], VAR_KEY: [CONST_TERM]},
1279-
schema={COEF_KEY: pl.Float64, VAR_KEY: KEY_TYPE},
1278+
schema={COEF_KEY: pl.Float64, VAR_KEY: Config.id_dtype},
12801279
)
12811280
return constant_terms
12821281

@@ -1770,23 +1769,25 @@ def _assign_ids(self):
17701769
if is_quadratic
17711770
else poi.ScalarAffineFunction.from_numpy # when called only once from_numpy is faster
17721771
)
1773-
df = self.data.with_columns(
1774-
pl.lit(
1775-
add_constraint(
1776-
create_expression(
1777-
*(
1778-
df.get_column(c).to_numpy()
1779-
for c in ([COEF_KEY] + self.lhs._variable_columns)
1780-
)
1781-
),
1782-
sense,
1783-
0,
1784-
name,
1785-
).index
1772+
constr_id = add_constraint(
1773+
create_expression(
1774+
*(
1775+
df.get_column(c).to_numpy()
1776+
for c in ([COEF_KEY] + self.lhs._variable_columns)
1777+
)
1778+
),
1779+
sense,
1780+
0,
1781+
name,
1782+
).index
1783+
try:
1784+
df = self.data.with_columns(
1785+
pl.lit(constr_id).alias(CONSTRAINT_KEY).cast(Config.id_dtype)
17861786
)
1787-
.alias(CONSTRAINT_KEY)
1788-
.cast(KEY_TYPE)
1789-
)
1787+
except TypeError as e:
1788+
raise TypeError(
1789+
f"Number of constraints exceeds the current data type ({Config.id_dtype}). Consider increasing the data type by changing Config.id_dtype."
1790+
) from e
17901791
else:
17911792
create_expression = (
17921793
poi.ScalarQuadraticFunction
@@ -1875,9 +1876,14 @@ def _assign_ids(self):
18751876
).index
18761877
for s0, s1 in pairwise(split)
18771878
]
1878-
df = df_unique.with_columns(
1879-
pl.Series(ids, dtype=KEY_TYPE).alias(CONSTRAINT_KEY)
1880-
)
1879+
try:
1880+
df = df_unique.with_columns(
1881+
pl.Series(ids, dtype=Config.id_dtype).alias(CONSTRAINT_KEY)
1882+
)
1883+
except TypeError as e:
1884+
raise TypeError(
1885+
f"Number of constraints exceeds the current data type ({Config.id_dtype}). Consider increasing the data type by changing Config.id_dtype."
1886+
) from e
18811887

18821888
self._data = df
18831889

@@ -2399,7 +2405,14 @@ def _assign_ids(self):
23992405
else:
24002406
ids = [poi_add_var(lb, ub, name=name).index for _ in range(n)]
24012407

2402-
df = self.data.with_columns(pl.Series(ids, dtype=KEY_TYPE).alias(VAR_KEY))
2408+
try:
2409+
df = self.data.with_columns(
2410+
pl.Series(ids, dtype=Config.id_dtype).alias(VAR_KEY)
2411+
)
2412+
except TypeError as e:
2413+
raise TypeError(
2414+
f"Number of variables exceeds the current data type ({Config.id_dtype}). Consider increasing the data type by changing Config.id_dtype."
2415+
) from e
24032416

24042417
self._data = df
24052418

src/pyoframe/_model_element.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from pyoframe._arithmetic import _get_dimensions
1111
from pyoframe._constants import (
1212
COEF_KEY,
13-
KEY_TYPE,
1413
QUAD_VAR_KEY,
1514
RESERVED_COL_KEYS,
1615
VAR_KEY,
16+
Config,
1717
)
1818

1919
if TYPE_CHECKING: # pragma: no cover
@@ -41,9 +41,9 @@ def __init__(self, data: pl.DataFrame, name="unnamed") -> None:
4141
if COEF_KEY in data.columns:
4242
data = data.cast({COEF_KEY: pl.Float64})
4343
if VAR_KEY in data.columns:
44-
data = data.cast({VAR_KEY: KEY_TYPE})
44+
data = data.cast({VAR_KEY: Config.id_dtype})
4545
if QUAD_VAR_KEY in data.columns:
46-
data = data.cast({QUAD_VAR_KEY: KEY_TYPE})
46+
data = data.cast({QUAD_VAR_KEY: Config.id_dtype})
4747

4848
self._data = data
4949
self._model: Model | None = None

src/pyoframe/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,12 +313,12 @@ def __init__(self) -> None:
313313
self._ID_COL = VAR_KEY
314314
self.mapping_registry = pl.DataFrame(
315315
{self._ID_COL: [], self.NAME_COL: []},
316-
schema={self._ID_COL: pl.UInt32, self.NAME_COL: pl.String},
316+
schema={self._ID_COL: Config.id_dtype, self.NAME_COL: pl.String},
317317
)
318318
self._extend_registry(
319319
pl.DataFrame(
320320
{self._ID_COL: [CONST_TERM], self.NAME_COL: [self.CONST_TERM_NAME]},
321-
schema={self._ID_COL: pl.UInt32, self.NAME_COL: pl.String},
321+
schema={self._ID_COL: Config.id_dtype, self.NAME_COL: pl.String},
322322
)
323323
)
324324

0 commit comments

Comments
 (0)