Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
6756e39
first implementation in variables and TorchModel
pluflou Jul 23, 2025
714cebf
remove typo added by mistake
pluflou Jul 31, 2025
a4836d6
validate scalars by shape/dim, and refactor scalar validation in torc…
pluflou Jul 31, 2025
71f7153
update input validation in prob model base
pluflou Jul 31, 2025
12e0230
Merge branch 'slaclab:main' into add-arrays
pluflou Jan 9, 2026
9ae9a55
add arrayvariable
pluflou Jan 12, 2026
3ca8eeb
rm print statement
pluflou Jan 12, 2026
8e77e86
test arrayvariable and refactor validation for vars
pluflou Jan 14, 2026
ea4f5e4
fix arranging of inputs for array, clean up input validation
pluflou Jan 15, 2026
c047366
clean up prob model base input validation
pluflou Jan 15, 2026
ddd7d45
rm redundant property
pluflou Jan 15, 2026
fa49494
resolve merge conflicts
pluflou Feb 4, 2026
c5c3048
fix how enums are behaving
pluflou Feb 4, 2026
f021700
merge in enum fix
pluflou Feb 4, 2026
0e303ca
initial implementation of torch-specific ScalarVariable and NDVariable
pluflou Feb 4, 2026
79ad791
update pyproject.toml with my lume-base branch
pluflou Feb 4, 2026
f642b98
update pyproject.toml with my lume-base branch
pluflou Feb 4, 2026
9ff1f8f
add read_only validation and tests for variables
pluflou Feb 5, 2026
2bb40df
clean refs to numpy arrays
pluflou Feb 5, 2026
b72cec3
add torch tensor encoder for proper serialization
pluflou Feb 5, 2026
8d1d3d7
clean up torch_model
pluflou Feb 5, 2026
8e5ace7
clean up gp_model
pluflou Feb 5, 2026
8733628
clean up prob_model_base
pluflou Feb 5, 2026
f78c2bb
rename ProbModelBaseModel -> ProbabilisticBaseModel
pluflou Feb 5, 2026
ddee906
update model w/ ProbabilisticBaseModel
pluflou Feb 5, 2026
c6e3f54
Merge branch 'lume-torch' into ndvariable
pluflou Feb 12, 2026
9dfd26c
move to TorchScalarVariable and add deprecation warning
pluflou Feb 12, 2026
154f000
update NDVariable and ScalarVariable based on new lume-base changes
pluflou Feb 17, 2026
fa503f1
adjust based on lume-base changes to NDVariable
pluflou Feb 18, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions lume_torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
import numpy as np
from pydantic import BaseModel, ConfigDict, field_validator

from lume_torch.variables import ScalarVariable, get_variable, ConfigEnum
from lume_torch.variables import (
TorchScalarVariable,
get_variable,
ConfigEnum,
DistributionVariable,
TorchNDVariable,
)
from lume_torch.utils import (
try_import_module,
verify_unique_variable_names,
Expand All @@ -34,6 +40,11 @@
np.float64: lambda x: float(x),
}

# Add torch.Tensor encoder if torch is available
torch = try_import_module("torch")
if torch is not None:
JSON_ENCODERS[torch.Tensor] = lambda x: x.tolist()


def process_torch_module(
module,
Expand Down Expand Up @@ -341,9 +352,9 @@ class LUMETorch(BaseModel, ABC):

Attributes
----------
input_variables : list of ScalarVariable
input_variables : list of TorchScalarVariable
List defining the input variables and their order.
output_variables : list of ScalarVariable
output_variables : list of TorchScalarVariable
List defining the output variables and their order.
input_validation_config : dict of str to ConfigEnum, optional
Determines the behavior during input validation by specifying the validation
Expand Down Expand Up @@ -378,8 +389,10 @@ class LUMETorch(BaseModel, ABC):

"""

input_variables: list[ScalarVariable]
output_variables: list[ScalarVariable]
input_variables: list[Union[TorchScalarVariable, TorchNDVariable]]
output_variables: list[
Union[TorchScalarVariable, TorchNDVariable, DistributionVariable]
]
input_validation_config: Optional[dict[str, ConfigEnum]] = None
output_validation_config: Optional[dict[str, ConfigEnum]] = None

Expand All @@ -396,7 +409,7 @@ def validate_input_variables(cls, value):

Returns
-------
list of ScalarVariable
list of TorchScalarVariable
List of validated variable instances.

Raises
Expand All @@ -411,7 +424,14 @@ def validate_input_variables(cls, value):
if isinstance(val, dict):
variable_class = get_variable(val["variable_class"])
new_value.append(variable_class(name=name, **val))
elif isinstance(val, ScalarVariable):
elif isinstance(
val,
(
TorchScalarVariable,
TorchNDVariable,
DistributionVariable,
),
):
new_value.append(val)
else:
raise TypeError(f"type {type(val)} not supported")
Expand Down Expand Up @@ -510,13 +530,18 @@ def input_validation(self, input_dict: dict[str, Any]) -> dict[str, Any]:

"""
for name, value in input_dict.items():
_config = (
"none"
if self.input_validation_config is None
else self.input_validation_config.get(name)
)
var = self.input_variables[self.input_names.index(name)]
var.validate_value(value, config=_config)
if name in self.input_names:
_config = (
None
if self.input_validation_config is None
else self.input_validation_config.get(name)
)
var = self.input_variables[self.input_names.index(name)]
var.validate_value(value, config=_config)
else:
raise ValueError(
f"Input variable {name} not found in model input variables."
)
return input_dict

def output_validation(self, output_dict: dict[str, Any]) -> dict[str, Any]:
Expand Down
10 changes: 5 additions & 5 deletions lume_torch/models/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from torch.distributions import Normal
from torch.distributions.distribution import Distribution as TDistribution

from lume_torch.models.prob_model_base import ProbModelBaseModel
from lume_torch.models.prob_model_base import ProbabilisticBaseModel
from lume_torch.models.torch_model import TorchModel

logger = logging.getLogger(__name__)


class NNEnsemble(ProbModelBaseModel):
class NNEnsemble(ProbabilisticBaseModel):
"""LUME-model class for neural network ensembles.

This class allows for the evaluation of multiple neural network models as an
Expand All @@ -43,9 +43,9 @@ def __init__(self, *args, **kwargs):
Parameters
----------
*args
Positional arguments forwarded to :class:`ProbModelBaseModel`.
Positional arguments forwarded to :class:`ProbabilisticBaseModel`.
**kwargs
Keyword arguments forwarded to :class:`ProbModelBaseModel`.
Keyword arguments forwarded to :class:`ProbabilisticBaseModel`.

Notes
-----
Expand Down Expand Up @@ -111,7 +111,7 @@ def _get_predictions(
) -> dict[str, TDistribution]:
"""Get the predictions of the ensemble of models.

This implements the abstract method from :class:`ProbModelBaseModel` by
This implements the abstract method from :class:`ProbabilisticBaseModel` by
evaluating each model in the ensemble and aggregating their outputs.

Parameters
Expand Down
Loading
Loading