diff --git a/lume_torch/base.py b/lume_torch/base.py index 8fa17dd3..9bb2d133 100644 --- a/lume_torch/base.py +++ b/lume_torch/base.py @@ -10,7 +10,13 @@ import numpy as np from pydantic import BaseModel, ConfigDict, field_validator -from lume_torch.variables import ScalarVariable, get_variable, ConfigEnum +from lume_torch.variables import ( + TorchScalarVariable, + get_variable, + ConfigEnum, + DistributionVariable, + TorchNDVariable, +) from lume_torch.utils import ( try_import_module, verify_unique_variable_names, @@ -34,6 +40,11 @@ np.float64: lambda x: float(x), } +# Add torch.Tensor encoder if torch is available +torch = try_import_module("torch") +if torch is not None: + JSON_ENCODERS[torch.Tensor] = lambda x: x.tolist() + def process_torch_module( module, @@ -341,9 +352,9 @@ class LUMETorch(BaseModel, ABC): Attributes ---------- - input_variables : list of ScalarVariable + input_variables : list of TorchScalarVariable List defining the input variables and their order. - output_variables : list of ScalarVariable + output_variables : list of TorchScalarVariable List defining the output variables and their order. input_validation_config : dict of str to ConfigEnum, optional Determines the behavior during input validation by specifying the validation @@ -378,8 +389,10 @@ class LUMETorch(BaseModel, ABC): """ - input_variables: list[ScalarVariable] - output_variables: list[ScalarVariable] + input_variables: list[Union[TorchScalarVariable, TorchNDVariable]] + output_variables: list[ + Union[TorchScalarVariable, TorchNDVariable, DistributionVariable] + ] input_validation_config: Optional[dict[str, ConfigEnum]] = None output_validation_config: Optional[dict[str, ConfigEnum]] = None @@ -396,7 +409,7 @@ def validate_input_variables(cls, value): Returns ------- - list of ScalarVariable + list of TorchScalarVariable List of validated variable instances. 
Raises @@ -411,7 +424,14 @@ def validate_input_variables(cls, value): if isinstance(val, dict): variable_class = get_variable(val["variable_class"]) new_value.append(variable_class(name=name, **val)) - elif isinstance(val, ScalarVariable): + elif isinstance( + val, + ( + TorchScalarVariable, + TorchNDVariable, + DistributionVariable, + ), + ): new_value.append(val) else: raise TypeError(f"type {type(val)} not supported") @@ -510,13 +530,18 @@ def input_validation(self, input_dict: dict[str, Any]) -> dict[str, Any]: """ for name, value in input_dict.items(): - _config = ( - "none" - if self.input_validation_config is None - else self.input_validation_config.get(name) - ) - var = self.input_variables[self.input_names.index(name)] - var.validate_value(value, config=_config) + if name in self.input_names: + _config = ( + None + if self.input_validation_config is None + else self.input_validation_config.get(name) + ) + var = self.input_variables[self.input_names.index(name)] + var.validate_value(value, config=_config) + else: + raise ValueError( + f"Input variable {name} not found in model input variables." + ) return input_dict def output_validation(self, output_dict: dict[str, Any]) -> dict[str, Any]: diff --git a/lume_torch/models/ensemble.py b/lume_torch/models/ensemble.py index dfa7e30f..0232eabe 100644 --- a/lume_torch/models/ensemble.py +++ b/lume_torch/models/ensemble.py @@ -10,13 +10,13 @@ from torch.distributions import Normal from torch.distributions.distribution import Distribution as TDistribution -from lume_torch.models.prob_model_base import ProbModelBaseModel +from lume_torch.models.prob_model_base import ProbabilisticBaseModel from lume_torch.models.torch_model import TorchModel logger = logging.getLogger(__name__) -class NNEnsemble(ProbModelBaseModel): +class NNEnsemble(ProbabilisticBaseModel): """LUME-model class for neural network ensembles. 
This class allows for the evaluation of multiple neural network models as an @@ -43,9 +43,9 @@ def __init__(self, *args, **kwargs): Parameters ---------- *args - Positional arguments forwarded to :class:`ProbModelBaseModel`. + Positional arguments forwarded to :class:`ProbabilisticBaseModel`. **kwargs - Keyword arguments forwarded to :class:`ProbModelBaseModel`. + Keyword arguments forwarded to :class:`ProbabilisticBaseModel`. Notes ----- @@ -111,7 +111,7 @@ def _get_predictions( ) -> dict[str, TDistribution]: """Get the predictions of the ensemble of models. - This implements the abstract method from :class:`ProbModelBaseModel` by + This implements the abstract method from :class:`ProbabilisticBaseModel` by evaluating each model in the ensemble and aggregating their outputs. Parameters diff --git a/lume_torch/models/gp_model.py b/lume_torch/models/gp_model.py index 11f628d2..19e1ea03 100644 --- a/lume_torch/models/gp_model.py +++ b/lume_torch/models/gp_model.py @@ -13,7 +13,7 @@ from linear_operator.operators import DiagLinearOperator from lume_torch.models.prob_model_base import ( - ProbModelBaseModel, + ProbabilisticBaseModel, TorchDistributionWrapper, ) @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) -class GPModel(ProbModelBaseModel): +class GPModel(ProbabilisticBaseModel): """LUME-model class for Gaussian process (GP) models. 
This class wraps BoTorch/GPyTorch GP models (``SingleTaskGP``, ``MultiTaskGP``, @@ -69,16 +69,6 @@ class GPModel(ProbModelBaseModel): list[OutcomeTransform | ReversibleInputTransform | torch.nn.Linear] | None ) = None - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.input_transformers = ( - [] if self.input_transformers is None else self.input_transformers - ) - self.output_transformers = ( - [] if self.output_transformers is None else self.output_transformers - ) - @field_validator("model", mode="before") def validate_gp_model(cls, v): if isinstance(v, (str, os.PathLike)): @@ -92,6 +82,8 @@ def validate_gp_model(cls, v): @field_validator("input_transformers", "output_transformers", mode="before") def validate_transformers(cls, v): + if v is None: + return [] if not isinstance(v, list): logger.error(f"Transformers must be a list, got {type(v)}") raise ValueError("Transformers must be passed as list.") @@ -105,8 +97,7 @@ def validate_transformers(cls, v): logger.error(f"Transformer file not found: {t}") raise OSError(f"File {t} is not found.") loaded_transformers.append(t) - v = loaded_transformers - return v + return loaded_transformers def get_input_size(self) -> int: """Get the dimensionality of the input space. 
@@ -123,20 +114,27 @@ def get_input_size(self) -> int: """ if isinstance(self.model, SingleTaskGP): - num_inputs = self.model.train_inputs[0].shape[-1] + return self.model.train_inputs[0].shape[-1] elif isinstance(self.model, MultiTaskGP): - num_inputs = self.model.train_inputs[0].shape[-1] - 1 + return self.model.train_inputs[0].shape[-1] - 1 elif isinstance(self.model, ModelListGP): - if isinstance(self.model.models[0], SingleTaskGP): - num_inputs = self.model.models[0].train_inputs[0].shape[-1] - elif isinstance(self.model.models[0], MultiTaskGP): - num_inputs = self.model.models[0].train_inputs[0].shape[-1] - 1 + first_model = self.model.models[0] + if isinstance(first_model, SingleTaskGP): + return first_model.train_inputs[0].shape[-1] + elif isinstance(first_model, MultiTaskGP): + return first_model.train_inputs[0].shape[-1] - 1 + else: + logger.error( + f"Unsupported model type in ModelListGP: {type(first_model)}" + ) + raise ValueError( + "ModelListGP must contain SingleTaskGP or MultiTaskGP models." + ) else: logger.error(f"Unsupported GP model type: {type(self.model)}") raise ValueError( "Model must be an instance of SingleTaskGP, MultiTaskGP or ModelListGP." ) - return num_inputs def get_output_size(self) -> int: """Get the dimensionality of the output space. @@ -153,28 +151,14 @@ def get_output_size(self) -> int: """ if isinstance(self.model, ModelListGP): - num_outputs = sum(model.num_outputs for model in self.model.models) - elif isinstance(self.model, SingleTaskGP) or isinstance( - self.model, MultiTaskGP - ): - num_outputs = self.model.num_outputs + return sum(model.num_outputs for model in self.model.models) + elif isinstance(self.model, (SingleTaskGP, MultiTaskGP)): + return self.model.num_outputs else: + logger.error(f"Unsupported GP model type: {type(self.model)}") raise ValueError( "Model must be an instance of SingleTaskGP, MultiTaskGP or ModelListGP." 
) - return num_outputs - - @property - def _tkwargs(self): - """Return tensor keyword arguments for this GP model. - - Returns - ------- - dict - Dictionary with ``"device"`` and ``"dtype"`` keys. - - """ - return {"device": self.device, "dtype": self.dtype} def likelihood(self): """Return the likelihood module of the underlying GP model. @@ -214,7 +198,7 @@ def _get_predictions( ) -> dict[str, TDistribution]: """Get predictive distributions from the GP model. - This implements the abstract method from :class:`ProbModelBaseModel` by + This implements the abstract method from :class:`ProbabilisticBaseModel` by constructing a BoTorch posterior and wrapping it as a distribution over the outputs. @@ -236,7 +220,7 @@ def _get_predictions( # Create tensor from input_dict x = super()._create_tensor_from_dict(input_dict) # Transform the input - if self.input_transformers is not None: + if self.input_transformers: x = self._transform_inputs(x) # Get the posterior distribution posterior = self._posterior(x, observation_noise=observation_noise) @@ -343,7 +327,7 @@ def _create_output_dict( # Check that the covariance matrix is positive definite _cov = self._check_covariance_matrix(_cov) - if self.output_transformers is not None: + if self.output_transformers: # TODO: make this more robust? # If we have two outputs, but transformer has length 1 (e.g. multitask), # we should apply the same transform to both outputs @@ -375,7 +359,53 @@ def _transform_inputs(self, input_tensor: torch.Tensor) -> torch.Tensor: input_tensor = transformer(input_tensor) return input_tensor - def _transform_mean(self, mean: torch.Tensor, i) -> torch.Tensor: + def _get_transformer_params( + self, transformer, i: int + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Extract scale factor and offset from a transformer. + + Parameters + ---------- + transformer : ReversibleInputTransform or OutcomeTransform + The transformer to extract parameters from. + i : int + Index of the output dimension. 
+ + Returns + ------- + tuple of (torch.Tensor, torch.Tensor | None) + Scale factor and offset (offset is None for covariance transforms). + + Raises + ------ + NotImplementedError + If the transformer type is not supported. + + """ + if isinstance(transformer, ReversibleInputTransform): + try: + scale_fac = transformer.coefficient[i] + offset = transformer.offset[i] + except IndexError: + # If the transformer has only one coefficient, use it for all outputs + # This is needed in the case of multitask models + scale_fac = transformer.coefficient[0] + offset = transformer.offset[0] + elif isinstance(transformer, OutcomeTransform): + try: + scale_fac = transformer.stdvs.squeeze(0)[i] + offset = transformer.means.squeeze(0)[i] + except IndexError: + # If the transformer has only one coefficient, use it for all outputs + scale_fac = transformer.stdvs.squeeze(0)[0] + offset = transformer.means.squeeze(0)[0] + else: + raise NotImplementedError( + f"Output transformer {type(transformer)} is not supported." + ) + return scale_fac, offset + + def _transform_mean(self, mean: torch.Tensor, i: int) -> torch.Tensor: """(Un-)transform the model output mean. 
Parameters @@ -392,29 +422,8 @@ def _transform_mean(self, mean: torch.Tensor, i) -> torch.Tensor: """ for transformer in self.output_transformers: - if isinstance(transformer, ReversibleInputTransform): - try: - scale_fac = transformer.coefficient[i] - offset = transformer.offset[i] - except IndexError: - # If the transformer has only one coefficient, use it for all outputs - # This is needed in the case of multitask models - scale_fac = transformer.coefficient[0] - offset = transformer.offset[0] - mean = offset + scale_fac * mean - elif isinstance(transformer, OutcomeTransform): - try: - scale_fac = transformer.stdvs.squeeze(0)[i] - offset = transformer.means.squeeze(0)[i] - except IndexError: - # If the transformer has only one coefficient, use it for all outputs - scale_fac = transformer.stdvs.squeeze(0)[0] - offset = transformer.means.squeeze(0)[0] - mean = offset + scale_fac * mean - else: - raise NotImplementedError( - f"Output transformer {type(transformer)} is not supported." - ) + scale_fac, offset = self._get_transformer_params(transformer, i) + mean = offset + scale_fac * mean return mean def _transform_covar(self, cov: torch.Tensor, i: int) -> torch.Tensor: @@ -434,28 +443,10 @@ def _transform_covar(self, cov: torch.Tensor, i: int) -> torch.Tensor: """ for transformer in self.output_transformers: - if isinstance(transformer, ReversibleInputTransform): - try: - scale_fac = transformer.coefficient[i] - except IndexError: - # If the transformer has only one coefficient, use it for all outputs - scale_fac = transformer.coefficient[0] - scale_fac = scale_fac.expand(cov.shape[:-1]) - scale_mat = DiagLinearOperator(scale_fac) - cov = scale_mat @ cov @ scale_mat - elif isinstance(transformer, OutcomeTransform): - try: - scale_fac = transformer.stdvs.squeeze(0)[i] - except IndexError: - # If the transformer has only one coefficient, use it for all outputs - scale_fac = transformer.stdvs.squeeze(0)[0] - scale_fac = scale_fac.expand(cov.shape[:-1]) - scale_mat = 
DiagLinearOperator(scale_fac) - cov = scale_mat @ cov @ scale_mat - else: - raise NotImplementedError( - f"Output transformer {type(transformer)} is not supported." - ) + scale_fac, _ = self._get_transformer_params(transformer, i) + scale_fac = scale_fac.expand(cov.shape[:-1]) + scale_mat = DiagLinearOperator(scale_fac) + cov = scale_mat @ cov @ scale_mat return cov def _check_covariance_matrix(self, cov: torch.Tensor) -> torch.Tensor: diff --git a/lume_torch/models/prob_model_base.py b/lume_torch/models/prob_model_base.py index 7bbc54c9..46338e2b 100644 --- a/lume_torch/models/prob_model_base.py +++ b/lume_torch/models/prob_model_base.py @@ -7,13 +7,13 @@ from torch.distributions import Distribution as TDistribution from lume_torch.variables import DistributionVariable -from lume_torch.models.utils import InputDictModel, format_inputs, itemize_dict +from lume_torch.models.utils import format_inputs, itemize_dict from lume_torch.base import LUMETorch logger = logging.getLogger(__name__) -class ProbModelBaseModel(LUMETorch): # TODO: brainstorm a better name +class ProbabilisticBaseModel(LUMETorch): """Abstract base class for probabilistic models. This class provides a common interface for probabilistic models. Subclasses must @@ -36,10 +36,10 @@ class ProbModelBaseModel(LUMETorch): # TODO: brainstorm a better name Methods ------- + _evaluate(input_dict, **kwargs) + Evaluates the model by calling :meth:`_get_predictions`. _get_predictions(input_dict, **kwargs) Abstract method that returns a dictionary of output distributions. - _evaluate(input_dict, **kwargs) - Evaluates the model and returns a dictionary of output distributions. input_validation(input_dict) Validates and normalizes the input dictionary prior to evaluation. 
output_validation(output_dict) @@ -69,13 +69,9 @@ def validate_output_variables(cls, values: dict[str, Any]) -> dict[str, Any]: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # No range validation for probabilistic models currently implemented - # self.input_validation_config = {x: "none" for x in self.input_names} @property - def dtype( - self, - ): + def dtype(self): """Returns the data type for the model.""" if self.precision == "double": return torch.double @@ -87,6 +83,10 @@ def dtype( f"expected one of ['double', 'single']." ) + @property + def _tkwargs(self) -> dict: + return {"device": self.device, "dtype": self.dtype} + def _arrange_inputs( self, d: dict[str, Union[float, torch.Tensor]] ) -> dict[str, Union[float, torch.Tensor]]: @@ -149,53 +149,52 @@ def _create_tensor_from_dict( "All values must be either floats or tensors, and all tensors must have the same length." ) - @abstractmethod - def _get_predictions( - self, input_dict: dict[str, float | torch.Tensor], **kwargs + def _evaluate( + self, input_dict: dict[str, Union[float, torch.Tensor]], **kwargs ) -> dict[str, TDistribution]: - """Get predictions from the model. + """Evaluate the probabilistic model. + + This method bridges the base class evaluation contract with the probabilistic + model's prediction interface by calling :meth:`_get_predictions`. Parameters ---------- input_dict : dict of str to float or torch.Tensor Dictionary of input variable names to values. Values can be floats or - tensors of shape ``n`` or ``b x n`` (batch mode). + tensors of shape ``n`` or ``b × n`` (batch mode). **kwargs - Additional keyword arguments passed through to the concrete - implementation in subclasses. + Additional keyword arguments forwarded to :meth:`_get_predictions`. Returns ------- - dict of str to TDistribution - A dictionary of output variable names to distributions. 
+ dict of str to torch.distributions.Distribution + Dictionary mapping output variable names to predictive distributions. """ - pass + return self._get_predictions(input_dict, **kwargs) - def _evaluate( - self, input_dict: dict[str, Union[float, torch.Tensor]], **kwargs + @abstractmethod + def _get_predictions( + self, input_dict: dict[str, float | torch.Tensor], **kwargs ) -> dict[str, TDistribution]: - """Evaluate the probabilistic model. + """Get predictions from the model. Parameters ---------- input_dict : dict of str to float or torch.Tensor Dictionary of input variable names to values. Values can be floats or - tensors of shape ``n`` or ``b × n`` (batch mode). + tensors of shape ``n`` or ``b x n`` (batch mode). **kwargs - Additional keyword arguments forwarded to :meth:`_get_predictions`. + Additional keyword arguments passed through to the concrete + implementation in subclasses. Returns ------- - dict of str to torch.distributions.Distribution - Dictionary mapping output variable names to predictive distributions. + dict of str to TDistribution + A dictionary of output variable names to distributions. """ - # Evaluate and get mean and variance for each output - output_dict = self._get_predictions(input_dict, **kwargs) - # Split multi-dimensional output into separate distributions and - # return output dictionary - return output_dict + pass def input_validation(self, input_dict: dict[str, Union[float, torch.Tensor]]): """Validates input dictionary before evaluation. @@ -211,28 +210,18 @@ def input_validation(self, input_dict: dict[str, Union[float, torch.Tensor]]): Validated input dictionary. 
""" - # validate input type (ints only are cast to floats for scalars) - validated_input = InputDictModel(input_dict=input_dict).input_dict + # validate original inputs (catches dtype mismatches) + super().input_validation(input_dict) + # format inputs as tensors w/o changing the dtype - formatted_inputs = format_inputs(validated_input) - # itemize inputs for validation - itemized_inputs = itemize_dict(formatted_inputs) - - for ele in itemized_inputs: - # validate values that were in the torch tensor - # any ints in the torch tensor will be cast to floats by Pydantic - # but others will be caught, e.g. booleans - ele = InputDictModel(input_dict=ele).input_dict - # validate each value based on its var class and config - super().input_validation(ele) - - # return the validated input dict for consistency w/ casting ints to floats - if any([isinstance(value, torch.Tensor) for value in validated_input.values()]): - validated_input = { - k: v.to(**self._tkwargs).squeeze(-1) for k, v in validated_input.items() - } - - return validated_input + formatted_inputs = format_inputs(input_dict.copy()) + + # cast tensors to expected dtype and device + formatted_inputs = { + k: v.to(**self._tkwargs).squeeze(-1) for k, v in formatted_inputs.items() + } + + return formatted_inputs def output_validation(self, output_dict: dict[str, TDistribution]): """Itemizes tensors before performing output validation. @@ -307,7 +296,7 @@ def variance(self) -> torch.Tensor: result, attr_name = self._get_attr(attribute_names) if attr_name in ["cov", "covariance", "covariance_matrix"]: - return torch.diagonal(torch.tensor(result)) + return torch.diagonal(result) return result @@ -318,7 +307,7 @@ def covariance_matrix(self) -> torch.Tensor: result, _ = self._get_attr(attribute_names) return result - def confidence_region(self) -> Tuple[torch.tensor, torch.tensor]: + def confidence_region(self) -> Tuple[torch.Tensor, torch.Tensor]: """Return a 2-sigma confidence region around the mean. 
Adapted from :mod:`gpytorch.distributions.multivariate_normal`. @@ -363,7 +352,7 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor: return result # TODO: check fn signature - def rsample(self, sample_shape: torch.Size()) -> torch.Tensor: + def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: """Generate reparameterized samples from the custom distribution. Parameters @@ -382,7 +371,7 @@ def rsample(self, sample_shape: torch.Size()) -> torch.Tensor: result, _ = self._get_attr(attribute_names, sample_shape) return result - def sample(self, sample_shape: torch.Size()) -> torch.Tensor: + def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: """Generate samples from the custom distribution (non-differentiable if using sample). Parameters @@ -398,8 +387,8 @@ def sample(self, sample_shape: torch.Size()) -> torch.Tensor: """ attribute_names = ["sample", "rvs"] # Assume non-torch.Distribution takes an integer sample_shape - sample_shape = sample_shape.numel() - result, _ = self._get_attr(attribute_names, sample_shape) + num_samples = sample_shape.numel() + result, _ = self._get_attr(attribute_names, num_samples) return result def __repr__(self): diff --git a/lume_torch/models/torch_model.py b/lume_torch/models/torch_model.py index 3f2b599d..c46ad6cf 100644 --- a/lume_torch/models/torch_model.py +++ b/lume_torch/models/torch_model.py @@ -8,8 +8,9 @@ from botorch.models.transforms.input import ReversibleInputTransform from lume_torch.base import LUMETorch -from lume_torch.variables import ScalarVariable -from lume_torch.models.utils import itemize_dict, format_inputs, InputDictModel +from lume_torch.variables import TorchScalarVariable, TorchNDVariable +from lume_torch.models.utils import itemize_dict, format_inputs + logger = logging.getLogger(__name__) @@ -24,16 +25,18 @@ class TorchModel(LUMETorch): ---------- model : torch.nn.Module The underlying torch model. 
- input_variables : list of ScalarVariable - List defining the input variables and their order. - output_variables : list of ScalarVariable + input_variables : list of TorchScalarVariable or TorchNDVariable + List defining the input variables and their order. Supports both scalar + variables and multi-dimensional array variables. + output_variables : list of TorchScalarVariable or TorchNDVariable List defining the output variables and their order. - input_transformers : list of callable or modules + input_transformers : list of ReversibleInputTransform, torch.nn.Linear, or Callable Transformer objects applied to the inputs before passing to the model. - output_transformers : list of callable or modules + output_transformers : list of ReversibleInputTransform, torch.nn.Linear, or Callable Transformer objects applied to the outputs of the model. output_format : {"tensor", "variable", "raw"} - Determines format of outputs. + Determines format of outputs. "tensor" returns tensors, "variable" and + "raw" return scalars where possible. device : torch.device or str Device on which the model will be evaluated. Defaults to ``"cpu"``. fixed_model : bool @@ -57,6 +60,11 @@ class TorchModel(LUMETorch): to(device) Move the model, transformers, and default values to a given device. + Notes + ----- + When using TorchNDVariable inputs, all inputs must be TorchNDVariable. + Mixing TorchScalarVariable and TorchNDVariable is not currently supported. + """ model: torch.nn.Module @@ -124,7 +132,27 @@ def _tkwargs(self): return {"device": self.device, "dtype": self.dtype} @field_validator("model", mode="before") - def validate_torch_model(cls, v): + @classmethod + def validate_torch_model( + cls, v: Union[str, os.PathLike, torch.nn.Module] + ) -> torch.nn.Module: + """Validate and load the torch model from file if needed. + + Parameters + ---------- + v : str, os.PathLike, or torch.nn.Module + Model or path to model file. 
+ + Returns + ------- + torch.nn.Module + Loaded or validated torch model. + + Raises + ------ + OSError + If the model file does not exist. + """ if isinstance(v, (str, os.PathLike)): if os.path.exists(v): fname = v @@ -140,8 +168,27 @@ def validate_torch_model(cls, v): return v @field_validator("input_variables") - def verify_input_default_value(cls, value): - """Verifies that input variables have the required default values.""" + @classmethod + def verify_input_default_value( + cls, value: list[Union[TorchScalarVariable, TorchNDVariable]] + ) -> list[Union[TorchScalarVariable, TorchNDVariable]]: + """Verify that input variables have the required default values. + + Parameters + ---------- + value : list of TorchScalarVariable or TorchNDVariable + Input variables to validate. + + Returns + ------- + list of TorchScalarVariable or TorchNDVariable + Validated input variables. + + Raises + ------ + ValueError + If any input variable is missing a default value. + """ for var in value: if var.default_value is None: logger.error( @@ -153,7 +200,27 @@ def verify_input_default_value(cls, value): return value @field_validator("input_transformers", "output_transformers", mode="before") - def validate_transformers(cls, v): + @classmethod + def validate_transformers(cls, v: Union[list, str, os.PathLike]) -> list: + """Validate and load transformers from files if needed. + + Parameters + ---------- + v : list, str, or os.PathLike + List of transformers or paths to transformer files. + + Returns + ------- + list + List of loaded transformers. + + Raises + ------ + ValueError + If transformers are not provided as a list. + OSError + If a transformer file does not exist. 
+ """ if not isinstance(v, list): logger.error(f"Transformers must be a list, got {type(v)}") raise ValueError("Transformers must be passed as list.") @@ -171,7 +238,25 @@ def validate_transformers(cls, v): return v @field_validator("output_format") - def validate_output_format(cls, v): + @classmethod + def validate_output_format(cls, v: str) -> str: + """Validate the output format. + + Parameters + ---------- + v : str + Output format to validate. + + Returns + ------- + str + Validated output format. + + Raises + ------ + ValueError + If output format is not one of the supported formats. + """ supported_formats = ["tensor", "variable", "raw"] if v not in supported_formats: logger.error( @@ -183,12 +268,38 @@ def validate_output_format(cls, v): return v def _set_precision(self, value: torch.dtype): - """Sets the precision of the model.""" + """Sets the precision of the model and transformers. + + Parameters + ---------- + value : torch.dtype + Dtype to set for the model and transformers. + """ self.model.to(dtype=value) for t in self.input_transformers + self.output_transformers: if isinstance(t, torch.nn.Module): t.to(dtype=value) + def _default_to_tensor( + self, default_value: Union[torch.Tensor, float] + ) -> torch.Tensor: + """Convert a default value to a tensor with proper dtype and device. + + Parameters + ---------- + default_value : torch.Tensor or float + Default value to convert. + + Returns + ------- + torch.Tensor + Default value as a tensor with proper dtype and device. + """ + if isinstance(default_value, torch.Tensor): + return default_value.detach().clone().to(**self._tkwargs) + else: + return torch.tensor(default_value, **self._tkwargs) + def _evaluate( self, input_dict: dict[str, Union[float, torch.Tensor]], @@ -229,30 +340,21 @@ def input_validation(self, input_dict: dict[str, Union[float, torch.Tensor]]): Validated input dictionary. 
""" - # validate input type (ints only are cast to floats for scalars) - validated_input = InputDictModel(input_dict=input_dict).input_dict + # validate original inputs (catches dtype mismatches) + super().input_validation(input_dict) + # format inputs as tensors w/o changing the dtype - formatted_inputs = format_inputs(validated_input) + formatted_inputs = format_inputs(input_dict) + + # cast tensors to expected dtype and device + formatted_inputs = { + k: v.to(**self._tkwargs) for k, v in formatted_inputs.items() + } + # check default values for missing inputs filled_inputs = self._fill_default_inputs(formatted_inputs) - # itemize inputs for validation - itemized_inputs = itemize_dict(filled_inputs) - - for ele in itemized_inputs: - # validate values that were in the torch tensor - # any ints in the torch tensor will be cast to floats by Pydantic - # but others will be caught, e.g. booleans - ele = InputDictModel(input_dict=ele).input_dict - # validate each value based on its var class and config - super().input_validation(ele) - - # return the validated input dict for consistency w/ casting ints to floats - if any([isinstance(value, torch.Tensor) for value in validated_input.values()]): - validated_input = { - k: v.to(**self._tkwargs) for k, v in validated_input.items() - } - return validated_input + return filled_inputs def output_validation(self, output_dict: dict[str, Union[float, torch.Tensor]]): """Itemize tensors before performing output validation. @@ -263,9 +365,15 @@ def output_validation(self, output_dict: dict[str, Union[float, torch.Tensor]]): Output dictionary to validate. 
""" - itemized_outputs = itemize_dict(output_dict) - for ele in itemized_outputs: - super().output_validation(ele) + for var in self.output_variables: + if isinstance(var, TorchNDVariable): + # run the validation for TorchNDVariable (arrays/images) + super().output_validation({var.name: output_dict[var.name]}) + elif isinstance(var, TorchScalarVariable): + # itemize scalar tensors for element-wise validation + itemized_outputs = itemize_dict({var.name: output_dict[var.name]}) + for ele in itemized_outputs: + super().output_validation(ele) def random_input(self, n_samples: int = 1) -> dict[str, torch.Tensor]: """Generates random input(s) for the model. @@ -280,15 +388,28 @@ def random_input(self, n_samples: int = 1) -> dict[str, torch.Tensor]: dict of str to torch.Tensor Dictionary of input variable names to tensors. + Notes + ----- + For TorchScalarVariable inputs, generates random values within the variable's + value_range. For TorchNDVariable inputs, repeats the default value for + the requested number of samples. + """ input_dict = {} for var in self.input_variables: - if isinstance(var, ScalarVariable): + if isinstance(var, TorchScalarVariable): input_dict[var.name] = var.value_range[0] + torch.rand( - size=(n_samples,) + size=(n_samples,), **self._tkwargs ) * (var.value_range[1] - var.value_range[0]) - else: - torch.tensor(var.default_value, **self._tkwargs).repeat((n_samples, 1)) + elif isinstance(var, TorchNDVariable): + # For ND variables, repeat the default value for n_samples + # Works for any dimensionality: 1D arrays, 2D matrices, 3D images, etc. 
+ default = self._default_to_tensor(var.default_value) + # Add batch dim and repeat n_samples times (keeping original shape) + # e.g., (3, 64, 64) -> (1, 3, 64, 64) -> (n_samples, 3, 64, 64) + input_dict[var.name] = default.unsqueeze(0).repeat( + n_samples, *([1] * default.ndim) + ) return input_dict def random_evaluate( @@ -366,7 +487,7 @@ def insert_output_transformer( def update_input_variables_to_transformer( self, transformer_loc: int - ) -> list[ScalarVariable]: + ) -> list[TorchScalarVariable]: """Return input variables updated to the transformer at the given location. Updated are the value ranges and defaults of the input variables. This @@ -380,7 +501,7 @@ def update_input_variables_to_transformer( Returns ------- - list of ScalarVariable + list of TorchScalarVariable The updated input variables. """ @@ -424,18 +545,14 @@ def update_input_variables_to_transformer( ) # backtrack through transformers for transformer in self.input_transformers[:transformer_loc][::-1]: - if isinstance( - self.input_transformers[transformer_loc], ReversibleInputTransform - ): + if isinstance(transformer, ReversibleInputTransform): x = transformer.untransform(x) - elif isinstance( - self.input_transformers[transformer_loc], torch.nn.Linear - ): + elif isinstance(transformer, torch.nn.Linear): w, b = transformer.weight, transformer.bias x = torch.matmul((x - b), torch.linalg.inv(w.T)) else: raise NotImplementedError( - f"Reverse transformation for type {type(self.input_transformers[transformer_loc])} is not supported." + f"Reverse transformation for type {type(transformer)} is not supported." 
) x_new[key] = x @@ -464,17 +581,20 @@ def _fill_default_inputs( """ for var in self.input_variables: if var.name not in input_dict.keys(): - input_dict[var.name] = torch.tensor(var.default_value, **self._tkwargs) + input_dict[var.name] = self._default_to_tensor(var.default_value) + return input_dict def _arrange_inputs( self, formatted_inputs: dict[str, torch.Tensor] ) -> torch.Tensor: - """Enforce the order of input variables. + """Enforces ordering, batching, and default filling of inputs. - Enforces the order of the input variables to be passed to the - transformers and model and updates the returned tensor with default - values for any inputs that are missing. + * If all inputs are `TorchNDVariable`, stacks them into shape `(batch, num_arrays, *array_shape)`. + * If all inputs are `TorchScalarVariable`, concatenates them so the last dimension matches the number of inputs, + broadcasting defaults as needed. + * If a mix of array and scalar inputs is provided, raises `NotImplementedError`. + * Missing inputs are filled with their default values before arranging. Parameters ---------- @@ -487,26 +607,103 @@ def _arrange_inputs( Ordered input tensor to be passed to the transformers. """ - default_tensor = torch.tensor( - [var.default_value for var in self.input_variables], **self._tkwargs + contains_array = any( + isinstance(v, TorchNDVariable) for v in self.input_variables ) + contains_scalar = any( + isinstance(v, TorchScalarVariable) for v in self.input_variables + ) + + if contains_array and contains_scalar: + raise NotImplementedError( + "Mixing TorchScalarVariable and TorchNDVariable inputs is not supported." 
+    # All inputs are TorchNDVariable
+ logger.debug(f"Arranged ND inputs into tensor shape: {stacked.shape}") + return stacked + + # All TorchScalarVariables + default_list = [] + for var in self.input_variables: + default_list.append(self._default_to_tensor(var.default_value)) - input_tensor = torch.tile(default_tensor, dims=(*input_shapes[0], 1)) - for key, value in formatted_inputs.items(): - input_tensor[..., self.input_names.index(key)] = value + default_tensor = torch.cat([d.flatten() for d in default_list]).to( + **self._tkwargs + ) + + if formatted_inputs: + batch_shape = None + for var in self.input_variables: + if var.name in formatted_inputs: + value = formatted_inputs[var.name] + if value.ndim > 0 and value.shape[-1] in (0, 1): + batch_shape = value.shape[:-1] + else: + batch_shape = value.shape + break + + if batch_shape and len(batch_shape) > 0: + expanded_shape = (*batch_shape, default_tensor.shape[0]) + input_tensor = ( + default_tensor.unsqueeze(0).expand(expanded_shape).clone() + ) + else: + input_tensor = default_tensor.unsqueeze(0) + + current_idx = 0 + for var in self.input_variables: + if var.name in formatted_inputs: + value = formatted_inputs[var.name] + if value.ndim > 0 and value.shape[-1] == 1: + input_tensor[..., current_idx] = value.squeeze(-1) + else: + input_tensor[..., current_idx] = value + current_idx += 1 + else: + input_tensor = default_tensor.unsqueeze(0) - if input_tensor.shape[-1] != len(self.input_names): + expected_features = len(self.input_variables) + if input_tensor.shape[-1] != expected_features: raise ValueError( - f""" - Last dimension of input tensor doesn't match the expected number of inputs\n - received: {default_tensor.shape}, expected {len(self.input_names)} as the last dimension - """ + "Last dimension of input tensor doesn't match the expected number of features\n" + f"received: {input_tensor.shape}, expected {expected_features} as the last dimension" ) + return input_tensor def _transform_inputs(self, input_tensor: torch.Tensor) -> torch.Tensor: 
@@ -576,13 +773,78 @@ def _parse_outputs(self, output_tensor: torch.Tensor) -> dict[str, torch.Tensor] """ parsed_outputs = {} - if output_tensor.dim() in [0, 1]: + + # Check if all outputs are scalar variables + all_scalars = all( + isinstance(v, TorchScalarVariable) for v in self.output_variables + ) + + # Handle 0D and 1D tensors + if output_tensor.dim() == 0: + # 0D tensor - always add batch dimension at start output_tensor = output_tensor.unsqueeze(0) + elif output_tensor.dim() == 1: + # 1D tensor - interpretation depends on variable types + if all_scalars and len(self.output_names) == 1: + # For single scalar output, 1D means (batch,) -> should become (batch, 1) + output_tensor = output_tensor.unsqueeze(-1) + elif all_scalars and len(self.output_names) > 1: + # For multiple scalar outputs, 1D means (features,) -> should become (1, features) + output_tensor = output_tensor.unsqueeze(0) + else: + # For non-scalar outputs, default to adding batch dimension at start + output_tensor = output_tensor.unsqueeze(0) + if len(self.output_names) == 1: - parsed_outputs[self.output_names[0]] = output_tensor.squeeze() + output = output_tensor + # For scalar variables, ensure shape is (batch, 1) for single-sample batched outputs + # or (batch, samples) for multi-sample outputs + if all_scalars: + if output.dim() == 2: + # Already 2D: could be (batch, 1) or (batch, samples) - keep as is + parsed_outputs[self.output_names[0]] = output + elif output.dim() == 1: + # Shape is (batch,), reshape to (batch, 1) + parsed_outputs[self.output_names[0]] = output.unsqueeze(-1) + else: + # 3D or higher dimensional - squeeze last dim if it's 1 + # This handles multi-sample cases: (batch, samples, 1) -> (batch, samples) + if output.shape[-1] == 1: + parsed_outputs[self.output_names[0]] = output.squeeze(-1) + else: + # Shouldn't happen, but handle by squeezing all and adding feature dim + parsed_outputs[self.output_names[0]] = ( + output.squeeze().unsqueeze(-1) + if output.squeeze().dim() > 0 
+ else output.squeeze().unsqueeze(0).unsqueeze(-1) + ) + else: + # For non-scalar outputs (NDVariable), keep original behavior + parsed_outputs[self.output_names[0]] = output.squeeze() else: for idx, output_name in enumerate(self.output_names): - parsed_outputs[output_name] = output_tensor[..., idx].squeeze() + output = output_tensor[..., idx] + var = self.output_variables[idx] + + # For scalar variables, ensure shape is (batch, 1) for batched outputs + if isinstance(var, TorchScalarVariable): + if output.dim() == 1: + # Shape is (batch,), reshape to (batch, 1) + parsed_outputs[output_name] = output.unsqueeze(-1) + elif output.dim() == 0: + # Scalar output, reshape to (1, 1) + parsed_outputs[output_name] = output.unsqueeze(0).unsqueeze(-1) + else: + # Already has proper dimensions or higher, ensure last dim is 1 + parsed_outputs[output_name] = ( + output.squeeze().unsqueeze(-1) + if output.squeeze().dim() > 0 + else output.squeeze().unsqueeze(0).unsqueeze(-1) + ) + else: + # For non-scalar outputs (NDVariable), keep original behavior + parsed_outputs[output_name] = output.squeeze() + return parsed_outputs def _prepare_outputs( diff --git a/lume_torch/models/torch_module.py b/lume_torch/models/torch_module.py index d19b0708..6d67ae4f 100644 --- a/lume_torch/models/torch_module.py +++ b/lume_torch/models/torch_module.py @@ -111,7 +111,7 @@ def output_order(self): return self._output_order def forward(self, x: torch.Tensor): - # input shape: [n_batch, n_samples, n_dim] + # input shape: [..., n_features] or [..., n_features, 1] for scalar variables x = self._validate_input(x) model_input = self._tensor_to_dictionary(x) y_model = self.evaluate_model(model_input) @@ -243,15 +243,32 @@ def manipulate_output(self, y_model: dict[str, torch.Tensor]): def _tensor_to_dictionary(self, x: torch.Tensor): input_dict = {} - for idx, input_name in enumerate(self.input_order): - input_dict[input_name] = x[..., idx].unsqueeze(-1) + # Handle both old format (..., n_features) and new 
format (..., n_features, 1) + if x.shape[-1] == 1: + # New scalar format: (..., n_features, 1) + # Index the second-to-last dimension and keep trailing 1 + for idx, input_name in enumerate(self.input_order): + input_dict[input_name] = x[..., idx, :] + else: + # Old format: (..., n_features) + # Index last dimension and add trailing 1 + for idx, input_name in enumerate(self.input_order): + input_dict[input_name] = x[..., idx].unsqueeze(-1) return input_dict def _dictionary_to_tensor(self, y_model: dict[str, torch.Tensor]): - output_tensor = torch.stack( - [y_model[output_name].unsqueeze(-1) for output_name in self.output_order], - dim=-1, - ) + # Model outputs have shape (batch, 1) for single-sample or (batch, samples) for multi-sample + # We need to stack them into (..., n_outputs) format + output_list = [] + for output_name in self.output_order: + output = y_model[output_name] + # If output has trailing 1 (single-sample), squeeze it for stacking + # Multi-sample outputs like (batch, samples) are kept as-is + if output.shape[-1] == 1: + output = output.squeeze(-1) + output_list.append(output) + + output_tensor = torch.stack(output_list, dim=-1) return output_tensor @staticmethod diff --git a/lume_torch/models/utils.py b/lume_torch/models/utils.py index 991e9b1d..e293f1bf 100644 --- a/lume_torch/models/utils.py +++ b/lume_torch/models/utils.py @@ -5,13 +5,24 @@ import torch from torch.distributions import Distribution + +def _flatten_and_itemize(value): + if isinstance(value, torch.Tensor): + return [v.item() for v in value.flatten()] + else: + return [value] + + logger = logging.getLogger(__name__) def itemize_dict( d: dict[str, Union[float, torch.Tensor, Distribution]], ) -> list[dict[str, Union[float, torch.Tensor]]]: - """Itemizes the given in-/output dictionary. + """ + Converts a dictionary of values (floats or torch tensors) into a flat list of dictionaries, + each containing the key-value pairs for the scalar elements in the original arrays/tensors. 
+ If the input dictionary contains only scalars (no arrays/tensors), returns a list with the original dict. Parameters ---------- @@ -24,15 +35,16 @@ def itemize_dict( List of in-/output dictionaries, each containing only a single value per in-/output. """ - has_tensors = any([isinstance(value, torch.Tensor) for value in d.values()]) + has_tensors = any(isinstance(value, torch.Tensor) for value in d.values()) itemized_dicts = [] if has_tensors: for k, v in d.items(): - for i, ele in enumerate(v.flatten()): + flat = _flatten_and_itemize(v) + for i, ele in enumerate(flat): if i >= len(itemized_dicts): - itemized_dicts.append({k: ele.item()}) + itemized_dicts.append({k: ele}) else: - itemized_dicts[i][k] = ele.item() + itemized_dicts[i][k] = ele else: itemized_dicts = [d] return itemized_dicts diff --git a/lume_torch/utils.py b/lume_torch/utils.py index 26f5e317..ace50b1e 100644 --- a/lume_torch/utils.py +++ b/lume_torch/utils.py @@ -5,7 +5,7 @@ import importlib from typing import Union, get_origin, get_args -from lume_torch.variables import ScalarVariable, get_variable +from lume_torch.variables import TorchScalarVariable, get_variable logger = logging.getLogger(__name__) @@ -37,14 +37,14 @@ def try_import_module(name: str): return module -def verify_unique_variable_names(variables: list[ScalarVariable]): +def verify_unique_variable_names(variables: list[TorchScalarVariable]): """Verifies that variable names are unique. Raises a ValueError if any reoccurring variable names are found. Parameters ---------- - variables : list of ScalarVariable + variables : list of TorchScalarVariable List of scalar variables. 
Raises @@ -117,17 +117,17 @@ def deserialize_variables(v): def variables_as_yaml( - input_variables: list[ScalarVariable], - output_variables: list[ScalarVariable], + input_variables: list[TorchScalarVariable], + output_variables: list[TorchScalarVariable], file: Union[str, os.PathLike] = None, ) -> str: """Returns and optionally saves YAML formatted string defining the in- and output variables. Parameters ---------- - input_variables : list of ScalarVariable + input_variables : list of TorchScalarVariable List of input variables. - output_variables : list of ScalarVariable + output_variables : list of TorchScalarVariable List of output variables. file : str or os.PathLike, optional If not None, YAML formatted string is saved to given file path. @@ -157,7 +157,7 @@ def variables_as_yaml( def variables_from_dict( config: dict, -) -> tuple[list[ScalarVariable], list[ScalarVariable]]: +) -> tuple[list[TorchScalarVariable], list[TorchScalarVariable]]: """Parses given config and returns in- and output variable lists. Parameters @@ -167,7 +167,7 @@ def variables_from_dict( Returns ------- - tuple of (list of ScalarVariable, list of ScalarVariable) + tuple of (list of TorchScalarVariable, list of TorchScalarVariable) In- and output variable lists. """ @@ -191,7 +191,7 @@ def variables_from_dict( def variables_from_yaml( yaml_obj: Union[str, os.PathLike], -) -> tuple[list[ScalarVariable], list[ScalarVariable]]: +) -> tuple[list[TorchScalarVariable], list[TorchScalarVariable]]: """Parses YAML object and returns in- and output variable lists. Parameters @@ -201,7 +201,7 @@ def variables_from_yaml( Returns ------- - tuple of (list of ScalarVariable, list of ScalarVariable) + tuple of (list of TorchScalarVariable, list of TorchScalarVariable) In- and output variable lists. 
""" diff --git a/lume_torch/variables.py b/lume_torch/variables.py index 023192b6..835e625c 100644 --- a/lume_torch/variables.py +++ b/lume_torch/variables.py @@ -2,22 +2,34 @@ This module contains definitions of LUME-model variables for use with lume tools. Variables are designed as pure descriptors and thus aren't intended to hold actual values, but they can be used to validate encountered values. - -For now, only scalar floating-point variables are supported. """ import logging -from typing import Optional, Type +import warnings +from typing import Optional, Type, Union, ClassVar +import torch +from torch import Tensor from torch.distributions import Distribution as TDistribution -from lume.variables import Variable, ScalarVariable, ConfigEnum +from pydantic import field_validator, model_validator, ConfigDict + +from lume.variables import Variable, ScalarVariable, NDVariable, ConfigEnum logger = logging.getLogger(__name__) +# Rename base ScalarVariable for internal use +_BaseScalarVariable = ScalarVariable + # Re-export base classes for backward compatibility and clean API +# We alias TorchScalarVariable (defined below) as ScalarVariable for backwards compatibility +# with a deprecation warning. See end of this module for the aliasing. +# NOTE: ScalarVariable will be deprecated in the next release - use TorchScalarVariable instead. __all__ = [ "Variable", "ScalarVariable", + "TorchScalarVariable", + "NDVariable", + "TorchNDVariable", "ConfigEnum", "DistributionVariable", "get_variable", @@ -53,11 +65,13 @@ def validate_value(self, value: TDistribution, config: ConfigEnum = None): If the value is not an instance of Distribution. 
""" - _config = self.default_validation_config if config is None else config + config = self.default_validation_config if config is None else config + if isinstance(config, str): + config = ConfigEnum(config) # mandatory validation self._validate_value_type(value) # optional validation - if config != "none": + if config != ConfigEnum.NULL: pass # not implemented @staticmethod @@ -69,6 +83,266 @@ def _validate_value_type(value: TDistribution): ) +class TorchScalarVariable(_BaseScalarVariable): + """Variable for scalar values represented as PyTorch tensors. + + This class extends ScalarVariable to support scalar values as torch.Tensor + with 0 or 1 dimensions (i.e., a scalar tensor or a single-element tensor). + + Attributes + ---------- + default_value : Tensor | None + Default value for the variable (must be 0D or 1D with size 1). + dtype : torch.dtype | None + Optional data type of the tensor. If specified, validates that tensor values + match this exact dtype. If None (default), only validates that the dtype is + a floating-point type without enforcing a specific precision. + value_range : tuple[float, float] | None + Value range that is considered valid for the variable. + unit : str | None + Unit associated with the variable. + + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + default_value: Optional[Union[Tensor, float]] = None + dtype: Optional[torch.dtype] = None + + @field_validator("dtype", mode="before") + @classmethod + def validate_dtype(cls, value): + """Validate that dtype is a torch.dtype and is a floating-point type.""" + if value is None: + return None + + # Validate that value is a torch.dtype instance + if not isinstance(value, torch.dtype): + raise TypeError( + f"dtype must be a torch.dtype instance, " + f"got {type(value).__name__}. 
" + f"Received value: {repr(value)}" + ) + + # Validate that the dtype is a floating-point type + if not value.is_floating_point: + raise ValueError(f"dtype must be a floating-point type, got {value}") + return value + + @model_validator(mode="after") + def validate_default_value(self): + if self.default_value is not None: + self._validate_value_type(self.default_value) + self._validate_dtype(self.default_value) + if self.value_range is not None: + scalar_value = ( + self.default_value.item() + if isinstance(self.default_value, Tensor) + else self.default_value + ) + if not self._value_is_within_range(scalar_value): + raise ValueError( + "Default value ({}) is out of valid range: ([{},{}]).".format( + scalar_value, *self.value_range + ) + ) + return self + + def validate_value(self, value: Union[Tensor, float], config: ConfigEnum = None): + """Validates the given tensor or float value. + + Parameters + ---------- + value : Tensor | float + The value to be validated. If a tensor, must be 0D or 1D with size 1. + config : ConfigEnum, optional + The configuration for validation. Defaults to None. + Allowed values are "none", "warn", and "error". + + Raises + ------ + TypeError + If the value is not a torch.Tensor or float. + ValueError + If a tensor has more than 1 dimension, or if 1D with size != 1, + or if tensor dtype is not a float type, or if value is out of range. 
+ + """ + config = self.default_validation_config if config is None else config + if isinstance(config, str): + config = ConfigEnum(config) + + # mandatory validation + self._validate_value_type(value) + self._validate_dtype(value) + self._validate_read_only(value) + + # optional validation + if config != ConfigEnum.NULL: + scalar_value = value.item() if isinstance(value, Tensor) else value + self._validate_value_is_within_range(scalar_value, config=config) + + def _validate_value_type(self, value): + """Validates that value is a torch.Tensor (0D, 1D or batched 1D) or a regular float/int.""" + if isinstance(value, Tensor): + if value.ndim == 0 or value.shape[-1] == 1: + pass # Batched scalars with shape (batch_size, 1), valid + else: + raise ValueError( + f"Expected tensor with 0 dimensions, or multi-dimensional tensor " + f"with last dimension equal to 1 for batched scalar values, " + f"but got {value.ndim} dimensions with shape {value.shape}." + ) + else: + # Delegate to parent class for non-tensor validation + _BaseScalarVariable._validate_value_type(value) + + def _validate_dtype(self, value): + """Validates the dtype of the tensor is a float type. Skips check for regular floats.""" + if not isinstance(value, Tensor): + return # Regular floats don't have dtype to validate + if not value.dtype.is_floating_point: + raise ValueError( + f"Expected tensor dtype to be a floating-point type, got {value.dtype}." + ) + if self.dtype and value.dtype != self.dtype: + raise ValueError(f"Expected dtype {self.dtype}, got {value.dtype}") + + def _validate_read_only(self, value: Union[Tensor, float]): + """Validates that read-only variables match their default value. + + Handles batched tensors by ensuring ALL values in the batch equal the default. + """ + if not self.read_only: + return + + if self.default_value is None: + raise ValueError( + f"Variable '{self.name}' is read-only but has no default value." 
+ ) + + # Extract scalar value from default if it's a tensor + if isinstance(self.default_value, Tensor): + expected_scalar = self.default_value.item() + else: + expected_scalar = self.default_value + + # Compare based on actual value type + if isinstance(value, Tensor): + # For batched tensors, check that ALL values equal the default + # Broadcast expected to match value's shape for comparison + expected_broadcasted = torch.full_like(value, expected_scalar) + values_match = torch.allclose( + value, expected_broadcasted, rtol=1e-9, atol=1e-9 + ) + else: + # Scalar comparison + values_match = abs(expected_scalar - value) < 1e-9 + + if not values_match: + raise ValueError( + f"Variable '{self.name}' is read-only and must equal its default value " + f"({expected_scalar}), but received {value}." + ) + + +class TorchNDVariable(NDVariable): + """Variable for PyTorch tensor data. + + Attributes + ---------- + default_value : Tensor | None + Default value for the variable. Must match the expected + shape and dtype if provided. Defaults to None. + dtype : torch.dtype + Data type of the tensor. Defaults to torch.float32. + + Examples + -------- + >>> import torch + >>> from lume_torch.variables import TorchNDVariable + >>> + >>> var = TorchNDVariable( + ... name="my_tensor", + ... shape=(3, 4), + ... dtype=torch.float32, + ... unit="m" + ... ) + >>> + >>> tensor = torch.rand(3, 4) + >>> var.validate_value(tensor, config="error") # Passes + + """ + + default_value: Optional[Tensor] = None + dtype: torch.dtype = torch.float32 + array_type: ClassVar[type] = Tensor + + def _validate_read_only(self, value: Tensor) -> None: + """Validates that read-only ND-variables match their default value. + + Uses element-wise comparison for tensors. + Handles batched tensors by comparing each batch element to the default. + """ + if not self.read_only: + return + + if self.default_value is None: + raise ValueError( + f"Variable '{self.name}' is read-only but has no default value." 
+ ) + + # Get the expected shape dimensions + expected_ndim = len(self.shape) + + # Check if value is batched + if value.ndim > expected_ndim: + # Batched input - compare each batch item to default + # Flatten all leading batch dimensions into single dimension + value_flat = value.reshape(-1, *self.shape) + + # Check all batch items against default + for i in range(value_flat.shape[0]): + if not torch.allclose( + value_flat[i], self.default_value, rtol=1e-9, atol=1e-9 + ): + raise ValueError( + f"Variable '{self.name}' is read-only and must equal its default value, " + f"but received different array values in batch element {i}." + ) + else: + # Single input - direct comparison + if not torch.allclose(value, self.default_value, rtol=1e-9, atol=1e-9): + raise ValueError( + f"Variable '{self.name}' is read-only and must equal its default value, " + f"but received different array values." + ) + + def validate_value(self, value: Tensor, config: str = None): + super().validate_value(value, config) + self._validate_read_only(value) + + +# Alias TorchScalarVariable as ScalarVariable for backwards compatibility +# This will be deprecated in the next release +class ScalarVariable(TorchScalarVariable): + """Deprecated alias for TorchScalarVariable. + + .. deprecated:: + ScalarVariable is deprecated and will be removed in the next release. + Use TorchScalarVariable instead. + """ + + def __init__(self, *args, **kwargs): + warnings.warn( + "ScalarVariable is deprecated and will be removed in the next release. " + "Please use TorchScalarVariable instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) + + def get_variable(name: str) -> Type[Variable]: """Returns the Variable subclass with the given name. @@ -83,7 +357,12 @@ def get_variable(name: str) -> Type[Variable]: Variable subclass with the given name. 
""" - classes = [ScalarVariable, DistributionVariable] + classes = [ + TorchScalarVariable, + ScalarVariable, + DistributionVariable, + TorchNDVariable, + ] class_lookup = {c.__name__: c for c in classes} if name not in class_lookup.keys(): logger.error( diff --git a/pyproject.toml b/pyproject.toml index aa737b09..908190d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "numpy", "pyyaml", "botorch>=0.15", - "lume-base @ git+https://github.com/roussel-ryan/lume-base@main" + "lume-base @ git+https://github.com/pluflou/lume-base@add-ndvariable" ] dynamic = ["version"] [tool.setuptools_scm] diff --git a/tests/conftest.py b/tests/conftest.py index a48f0061..b654fa96 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import pytest from lume_torch.utils import variables_from_yaml -from lume_torch.variables import ScalarVariable, DistributionVariable +from lume_torch.variables import TorchScalarVariable, DistributionVariable try: import torch @@ -24,12 +24,17 @@ def rootdir() -> str: # TorchModel fixtures @pytest.fixture(scope="session") -def simple_variables() -> dict[str, Union[list[ScalarVariable], list[ScalarVariable]]]: +def simple_variables() -> dict[ + str, Union[list[TorchScalarVariable], list[TorchScalarVariable]] +]: input_variables = [ - ScalarVariable(name="input1", default_value=1.0, value_range=(0.0, 5.0)), - ScalarVariable(name="input2", default_value=2.0, value_range=(1.0, 3.0)), + TorchScalarVariable(name="input1", default_value=1.0, value_range=(0.0, 5.0)), + TorchScalarVariable(name="input2", default_value=2.0, value_range=(1.0, 3.0)), + ] + output_variables = [ + TorchScalarVariable(name="output1"), + TorchScalarVariable(name="output2"), ] - output_variables = [ScalarVariable(name="output1"), ScalarVariable(name="output2")] return {"input_variables": input_variables, "output_variables": output_variables} @@ -46,7 +51,9 @@ def california_model_info(rootdir) -> dict[str, str]: 
@pytest.fixture(scope="module") -def california_variables(rootdir) -> tuple[list[ScalarVariable], list[ScalarVariable]]: +def california_variables( + rootdir, +) -> tuple[list[TorchScalarVariable], list[TorchScalarVariable]]: try: file = f"{rootdir}/test_files/california_regression/variables.yml" input_variables, output_variables = variables_from_yaml(file) @@ -112,7 +119,8 @@ def california_test_input_tensor(rootdir: str): test_input_tensor = torch.load( f"{rootdir}/test_files/california_regression/test_input_tensor.pt", weights_only=False, - ) + ).unsqueeze(-1) + print(f"Loaded test input tensor with shape {test_input_tensor.shape}") except FileNotFoundError as e: pytest.skip(str(e)) return test_input_tensor @@ -125,9 +133,11 @@ def california_test_input_dict( pytest.importorskip("botorch") test_input_dict = { - key: california_test_input_tensor[0, idx] + key: california_test_input_tensor[:, idx] for idx, key in enumerate(california_model_info["model_in_list"]) } + for key, tensor in test_input_dict.items(): + print(f"Test input for {key} has shape {tensor.shape}") return test_input_dict @@ -148,9 +158,9 @@ def california_module(california_model): # GPModel fixtures @pytest.fixture(scope="session") def gp_variables() -> dict[ - str, Union[list[ScalarVariable], list[DistributionVariable]] + str, Union[list[TorchScalarVariable], list[DistributionVariable]] ]: - input_variables = [ScalarVariable(name="input")] + input_variables = [TorchScalarVariable(name="input")] output_variables = [ DistributionVariable(name="output1"), DistributionVariable(name="output2"), diff --git a/tests/models/test_fixed_variable_model.py b/tests/models/test_fixed_variable_model.py index 48448882..bba411ed 100644 --- a/tests/models/test_fixed_variable_model.py +++ b/tests/models/test_fixed_variable_model.py @@ -4,7 +4,7 @@ try: import torch - from lume_torch.variables import ScalarVariable + from lume_torch.variables import TorchScalarVariable from lume_torch.models.torch_module import ( 
TorchModel, TorchModule, @@ -29,11 +29,11 @@ def forward(self, X): return 0.5 * x**2 + y**2 input_variables = [ - ScalarVariable(name="x", default_value=0.0, value_range=[-3.0, 3.0]), - ScalarVariable(name="y", default_value=1.0, value_range=[0.5, 1.5]), + TorchScalarVariable(name="x", default_value=0.0, value_range=[-3.0, 3.0]), + TorchScalarVariable(name="y", default_value=1.0, value_range=[0.5, 1.5]), ] - output_variables = [ScalarVariable(name="f")] + output_variables = [TorchScalarVariable(name="f")] torch_model = TorchModel( model=PriorTorchModel(), diff --git a/tests/models/test_torch_model.py b/tests/models/test_torch_model.py index 062b0710..7bf46ec3 100644 --- a/tests/models/test_torch_model.py +++ b/tests/models/test_torch_model.py @@ -12,7 +12,7 @@ ReversibleInputTransform, ) from lume_torch.models import TorchModel - from lume_torch.variables import ScalarVariable + from lume_torch.variables import TorchScalarVariable torch.manual_seed(42) except ImportError: @@ -42,7 +42,9 @@ def test_model_from_objects( self, california_model_info: dict[str, str], california_model_kwargs: dict[str, Union[list, dict, str]], - california_variables: tuple[list[ScalarVariable], list[ScalarVariable]], + california_variables: tuple[ + list[TorchScalarVariable], list[TorchScalarVariable] + ], california_transformers: tuple[list, list], california_model, ): @@ -94,7 +96,11 @@ def test_precision(self, california_model): def test_model_evaluate_single_sample( self, california_test_input_dict: dict, california_model ): - results = california_model.evaluate(california_test_input_dict) + # Extract single sample from the batched input + single_sample_input = { + key: value[0] for key, value in california_test_input_dict.items() + } + results = california_model.evaluate(single_sample_input) assert isinstance(results["MedHouseVal"], torch.Tensor) assert torch.isclose( @@ -111,7 +117,7 @@ def test_model_evaluate_n_samples( } results = california_model.evaluate(test_dict) 
target_tensor = torch.tensor( - [4.063651, 2.7774928, 2.792812], dtype=results["MedHouseVal"].dtype + [[4.063651], [2.7774928], [2.792812]], dtype=results["MedHouseVal"].dtype ) assert torch.all(torch.isclose(results["MedHouseVal"], target_tensor)) @@ -124,9 +130,8 @@ def test_model_evaluate_batch_n_samples( # model should be able to handle input of shape [n_batch, n_samples, n_dim] input_dict = { key: california_test_input_tensor[:, idx] - .unsqueeze(-1) - .unsqueeze(1) - .repeat((1, 3, 1)) + .unsqueeze(1) # Add samples dimension: (3, 1) -> (3, 1, 1) + .repeat((1, 3, 1)) # Repeat samples: (3, 1, 1) -> (3, 3, 1) for idx, key in enumerate(california_model.input_names) } results = california_model.evaluate(input_dict) @@ -142,10 +147,10 @@ def test_model_evaluate_raw( kwargs = deepcopy(california_model_kwargs) kwargs["output_format"] = "raw" california_model = TorchModel(**kwargs) - float_dict = { - key: value.item() for key, value in california_test_input_dict.items() + float_dict_single_sample = { + key: value[0].item() for key, value in california_test_input_dict.items() } - results = california_model.evaluate(float_dict) + results = california_model.evaluate(float_dict_single_sample) assert isinstance(results["MedHouseVal"], float) assert results["MedHouseVal"] == pytest.approx(4.063651) @@ -160,10 +165,10 @@ def test_model_evaluate_shuffled_input( results = california_model.evaluate(shuffled_input) assert isinstance(results["MedHouseVal"], torch.Tensor) - assert torch.isclose( - results["MedHouseVal"], - torch.tensor(4.063651, dtype=results["MedHouseVal"].dtype), + target_tensor = torch.tensor( + [[4.063651], [2.7774928], [2.792812]], dtype=results["MedHouseVal"].dtype ) + assert torch.all(torch.isclose(results["MedHouseVal"], target_tensor)) @pytest.mark.parametrize( "test_idx,expected", [(0, 4.063651), (1, 2.7774928), (2, 2.792812)] @@ -191,7 +196,11 @@ def test_model_evaluate_with_no_output_transformers( kwargs = deepcopy(california_model_kwargs) 
kwargs["output_transformers"] = [] model = TorchModel(**kwargs) - results = model.evaluate(california_test_input_dict) + # Use only the first sample from the batched input + single_sample_input = { + key: value[0] for key, value in california_test_input_dict.items() + } + results = model.evaluate(single_sample_input) assert torch.isclose( results["MedHouseVal"], diff --git a/tests/models/test_torch_module.py b/tests/models/test_torch_module.py index d8d4cb6d..b4794ebc 100644 --- a/tests/models/test_torch_module.py +++ b/tests/models/test_torch_module.py @@ -174,7 +174,10 @@ def test_module_call_single_input_bad_shape( model=california_model, input_order=[california_model.input_names[0]], ) - input_tensor = deepcopy(california_test_input_tensor[:, 0]) # shape (3,) + # Create a 1D tensor by squeezing - this should fail validation + input_tensor = deepcopy( + california_test_input_tensor[:, 0].squeeze() + ) # shape (3,) with pytest.raises(ValueError): lume_module(input_tensor) @@ -255,9 +258,11 @@ def manipulate_output(self, y_model: dict[str, torch.Tensor]): def test_module_call_batch_n_samples( self, california_test_input_tensor, california_module ): - # module should be able to handle input of shape [n_batch, n_samples, n_dim] + # module should be able to handle input of shape [n_batch, n_samples, n_features, 1] n_batch = 5 - input_tensor = california_test_input_tensor.unsqueeze(0).repeat((n_batch, 1, 1)) + input_tensor = california_test_input_tensor.unsqueeze(0).repeat( + (n_batch, 1, 1, 1) + ) result = california_module(input_tensor) assert tuple(result.shape) == (n_batch, 3) @@ -267,7 +272,8 @@ def test_module_call_batch_n_samples( def test_module_as_gp_prior_mean( self, california_test_input_tensor, california_module ): - train_x = california_test_input_tensor.double() + # Squeeze trailing dimension for BoTorch compatibility: (3, 8, 1) -> (3, 8) + train_x = california_test_input_tensor.squeeze(-1).double() train_y = california_module(train_x).unsqueeze(-1) with 
warnings.catch_warnings(): warnings.simplefilter( diff --git a/tests/test_base.py b/tests/test_base.py index 8363b846..b72485cc 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -3,7 +3,7 @@ import yaml from lume_torch.base import LUMETorch -from lume_torch.variables import ScalarVariable +from lume_torch.variables import TorchScalarVariable class ExampleModel(LUMETorch): @@ -61,7 +61,7 @@ def test_yaml_serialization(self, simple_variables): yaml_output = example_model.yaml() dict_output = yaml.safe_load(yaml_output) dict_output["input_variables"]["input1"]["variable_class"] = ( - ScalarVariable.__name__ + TorchScalarVariable.__name__ ) # test loading from yaml diff --git a/tests/test_files/california_regression/torch_model.yml b/tests/test_files/california_regression/torch_model.yml index 69ec6a1d..5ec52db4 100644 --- a/tests/test_files/california_regression/torch_model.yml +++ b/tests/test_files/california_regression/torch_model.yml @@ -1,41 +1,41 @@ model_class: TorchModel input_variables: MedInc: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [0.4999000132083893, 15.000100135803223] default_value: 3.7857346534729004 HouseAge: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [1.0, 52.0] default_value: 29.282135009765625 AveRooms: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [0.8461538553237915, 141.90908813476562] default_value: 5.4074907302856445 AveBedrms: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [0.375, 34.06666564941406] default_value: 1.1071722507476807 Population: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [3.0, 28566.0] default_value: 1437.0687255859375 AveOccup: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [0.692307710647583, 599.7142944335938] default_value: 3.035413980484009 Latitude: - variable_class: 
ScalarVariable + variable_class: TorchScalarVariable value_range: [32.65999984741211, 41.95000076293945] default_value: 35.28323745727539 Longitude: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [-124.3499984741211, -114.30999755859375] default_value: -119.11573028564453 input_validation_config: MedInc: "error" output_variables: - MedHouseVal: {variable_class: ScalarVariable} + MedHouseVal: {variable_class: TorchScalarVariable} model: model.pt input_transformers: [input_transformers_0.pt] output_transformers: [output_transformers_0.pt] diff --git a/tests/test_files/california_regression/torch_module.yml b/tests/test_files/california_regression/torch_module.yml index ac15b145..ac771f20 100644 --- a/tests/test_files/california_regression/torch_module.yml +++ b/tests/test_files/california_regression/torch_module.yml @@ -5,39 +5,39 @@ output_order: [MedHouseVal] model: input_variables: MedInc: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [0.4999000132083893, 15.000100135803223] default_value: 3.7857346534729004 HouseAge: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [1.0, 52.0] default_value: 29.282135009765625 AveRooms: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [0.8461538553237915, 141.90908813476562] default_value: 5.4074907302856445 AveBedrms: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [0.375, 34.06666564941406] default_value: 1.1071722507476807 Population: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [3.0, 28566.0] default_value: 1437.0687255859375 AveOccup: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [0.692307710647583, 599.7142944335938] default_value: 3.035413980484009 Latitude: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [32.65999984741211, 
41.95000076293945] default_value: 35.28323745727539 Longitude: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [-124.3499984741211, -114.30999755859375] default_value: -119.11573028564453 output_variables: - MedHouseVal: {variable_class: ScalarVariable} + MedHouseVal: {variable_class: TorchScalarVariable} model: model.pt input_transformers: [input_transformers_0.pt] output_transformers: [output_transformers_0.pt] diff --git a/tests/test_files/california_regression/variables.yml b/tests/test_files/california_regression/variables.yml index 886dd8e0..781076d8 100644 --- a/tests/test_files/california_regression/variables.yml +++ b/tests/test_files/california_regression/variables.yml @@ -1,35 +1,35 @@ input_variables: MedInc: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [0.4999000132083893, 15.000100135803223] default_value: 3.7857346534729004 HouseAge: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [1.0, 52.0] default_value: 29.282135009765625 AveRooms: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [0.8461538553237915, 141.90908813476562] default_value: 5.4074907302856445 AveBedrms: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [0.375, 34.06666564941406] default_value: 1.1071722507476807 Population: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [3.0, 28566.0] default_value: 1437.0687255859375 AveOccup: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [0.692307710647583, 599.7142944335938] default_value: 3.035413980484009 Latitude: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [32.65999984741211, 41.95000076293945] default_value: 35.28323745727539 Longitude: - variable_class: ScalarVariable + variable_class: TorchScalarVariable value_range: [-124.3499984741211, 
-114.30999755859375] default_value: -119.11573028564453 output_variables: - MedHouseVal: {variable_class: ScalarVariable} + MedHouseVal: {variable_class: TorchScalarVariable} diff --git a/tests/test_files/single_task_gp/single_task_gp.yml b/tests/test_files/single_task_gp/single_task_gp.yml index 2e6aa6d0..d37ae886 100644 --- a/tests/test_files/single_task_gp/single_task_gp.yml +++ b/tests/test_files/single_task_gp/single_task_gp.yml @@ -1,7 +1,7 @@ model_class: GPModel input_variables: input1: - variable_class: ScalarVariable + variable_class: TorchScalarVariable is_constant: false value_range: [-.inf, .inf] value_range_tolerance: 1.0e-08 diff --git a/tests/test_mlflow_utils.py b/tests/test_mlflow_utils.py index d49149f9..68ba812f 100644 --- a/tests/test_mlflow_utils.py +++ b/tests/test_mlflow_utils.py @@ -6,7 +6,7 @@ import tempfile from unittest import mock -from lume_torch.variables import ScalarVariable +from lume_torch.variables import TorchScalarVariable from lume_torch.base import LUMETorch @@ -21,10 +21,10 @@ def _evaluate(self, input_dict, **kwargs): def simple_model(): """Create a simple test model.""" input_variables = [ - ScalarVariable(name="input", default_value=1.0, value_range=(0.0, 10.0)) + TorchScalarVariable(name="input", default_value=1.0, value_range=(0.0, 10.0)) ] output_variables = [ - ScalarVariable(name="output", default_value=2.0, value_range=(0.0, 20.0)) + TorchScalarVariable(name="output", default_value=2.0, value_range=(0.0, 20.0)) ] return SimpleModel( input_variables=input_variables, output_variables=output_variables diff --git a/tests/test_variables.py b/tests/test_variables.py index d684ecca..53e0f318 100644 --- a/tests/test_variables.py +++ b/tests/test_variables.py @@ -1,6 +1,672 @@ -from lume_torch.variables import ScalarVariable, get_variable +"""Test suite for lume_torch.variables module.""" +import pytest +import torch +from pydantic import ValidationError +from torch.distributions import Normal, Uniform -def 
test_get_variable(): - var = get_variable(ScalarVariable.__name__)(name="test") - assert isinstance(var, ScalarVariable) +from lume_torch.variables import ( + ScalarVariable, + TorchScalarVariable, + TorchNDVariable, + DistributionVariable, + Variable, + ConfigEnum, + get_variable, +) + + +class TestGetVariable: + """Tests for the get_variable function.""" + + def test_get_scalar_variable(self): + """Test getting ScalarVariable by name.""" + var_cls = get_variable("ScalarVariable") + assert var_cls is ScalarVariable + + def test_get_torch_scalar_variable(self): + """Test getting TorchScalarVariable by name.""" + var_cls = get_variable("TorchScalarVariable") + assert var_cls is TorchScalarVariable + + def test_get_distribution_variable(self): + """Test getting DistributionVariable by name.""" + var_cls = get_variable("DistributionVariable") + assert var_cls is DistributionVariable + + def test_get_torch_nd_variable(self): + """Test getting TorchNDVariable by name.""" + var_cls = get_variable("TorchNDVariable") + assert var_cls is TorchNDVariable + + def test_get_variable_unknown_name_raises(self): + """Test that unknown variable name raises KeyError.""" + with pytest.raises(KeyError, match="No variable named"): + get_variable("UnknownVariable") + + def test_get_variable_creates_instance(self): + """Test that get_variable returns a class that can be instantiated.""" + # Using ScalarVariable will trigger a deprecation warning + with pytest.warns(DeprecationWarning, match="ScalarVariable is deprecated"): + var = get_variable(ScalarVariable.__name__)(name="test") + assert isinstance(var, ScalarVariable) + + +class TestScalarVariableAlias: + """Tests to ensure ScalarVariable is a deprecated subclass of TorchScalarVariable.""" + + def test_scalar_variable_is_torch_scalar_variable(self): + """ScalarVariable should be a deprecated subclass of TorchScalarVariable.""" + # ScalarVariable is now a subclass, not an alias + assert issubclass(ScalarVariable, TorchScalarVariable) + 
# Verify it issues a deprecation warning when instantiated + with pytest.warns(DeprecationWarning, match="ScalarVariable is deprecated"): + var = ScalarVariable(name="test_var", default_value=1.0) + # Verify it still works as expected + assert isinstance(var, TorchScalarVariable) + + +class TestTorchScalarVariable: + """Tests for TorchScalarVariable class.""" + + def test_basic_creation(self): + """Test basic variable creation with minimal parameters.""" + var = TorchScalarVariable(name="test_var") + assert var.name == "test_var" + assert var.default_value is None + assert var.value_range is None + assert var.read_only is False + + def test_creation_with_all_attributes(self): + """Test variable creation with all attributes.""" + var = TorchScalarVariable( + name="full_var", + default_value=torch.tensor(1.5), + value_range=(0.0, 10.0), + unit="meters", + read_only=True, + dtype=torch.float32, + ) + assert var.name == "full_var" + assert torch.isclose(var.default_value, torch.tensor(1.5)) + assert var.value_range == (0.0, 10.0) + assert var.unit == "meters" + assert var.read_only is True + assert var.dtype == torch.float32 + + def test_creation_with_mismatched_dtype_raises(self): + """Test variable creation with all attributes.""" + with pytest.raises(ValidationError): + TorchScalarVariable( + name="full_var", + default_value=torch.tensor(1.5), + value_range=(0.0, 10.0), + unit="meters", + read_only=True, + dtype=torch.float64, + ) + + def test_missing_name_raises_validation_error(self): + """Test that missing name raises ValidationError.""" + with pytest.raises(ValidationError): + TorchScalarVariable(default_value=0.1, value_range=(0, 1)) + + def test_default_value_as_float(self): + """Test that float default values are accepted.""" + var = TorchScalarVariable(name="float_var", default_value=2.5) + assert var.default_value == 2.5 + + def test_default_value_as_tensor(self): + """Test that tensor default values are accepted.""" + var = 
TorchScalarVariable(name="tensor_var", default_value=torch.tensor(3.0)) + assert torch.isclose(var.default_value, torch.tensor(3.0)) + + def test_default_value_out_of_range_raises(self): + """Test that default value out of range raises ValueError.""" + with pytest.raises(ValueError, match="out of valid range"): + TorchScalarVariable( + name="bad_var", default_value=15.0, value_range=(0.0, 10.0) + ) + + # Dtype validation tests + def test_dtype_float32(self): + """Test dtype with torch.float32.""" + var = TorchScalarVariable(name="test", dtype=torch.float32) + assert var.dtype == torch.float32 + + def test_dtype_float64(self): + """Test dtype with torch.float64.""" + var = TorchScalarVariable(name="test", dtype=torch.float64) + assert var.dtype == torch.float64 + + def test_dtype_double(self): + """Test dtype with torch.double.""" + var = TorchScalarVariable(name="test", dtype=torch.double) + assert var.dtype == torch.double + + def test_dtype_invalid_type_raises(self): + """Test that invalid dtype type raises TypeError.""" + with pytest.raises(TypeError, match="dtype must be a"): + TorchScalarVariable(name="test", dtype="invalid_dtype") + + def test_dtype_non_floating_raises(self): + """Test that non-floating dtype raises ValueError.""" + with pytest.raises(ValueError, match="must be a floating-point type"): + TorchScalarVariable(name="test", dtype=torch.int32) + + # Value validation tests + def test_validate_value_float(self): + """Test validation of float values.""" + var = TorchScalarVariable(name="test") + var.validate_value(5.0) # Should not raise + + def test_validate_value_int(self): + """Test validation of int values (allowed as scalars).""" + var = TorchScalarVariable(name="test") + var.validate_value(5) # Should not raise + + def test_validate_value_tensor_0d(self): + """Test validation of 0D tensor.""" + var = TorchScalarVariable(name="test") + var.validate_value(torch.tensor(5.0)) # Should not raise + + def test_validate_value_tensor_1d(self): + """Test 
validation of 1D tensor.""" + var = TorchScalarVariable(name="test") + var.validate_value(torch.tensor([[5.0], [6.0], [7.0]])) # Should not raise + + def test_validate_value_tensor_batched(self): + """Test validation of batched tensor with last dim = 1.""" + var = TorchScalarVariable(name="test") + var.validate_value(torch.tensor([[5.0], [6.0]])) # Should not raise + + def test_validate_value_tensor_invalid_shape_raises(self): + """Test that tensor with invalid shape raises ValueError.""" + var = TorchScalarVariable(name="test") + with pytest.raises(ValueError, match="Expected tensor with 0 dimensions,"): + var.validate_value(torch.tensor([[5.0, 6.0], [7.0, 8.0]])) + + def test_validate_value_invalid_type_raises(self): + """Test that invalid value type raises TypeError.""" + var = TorchScalarVariable(name="test") + with pytest.raises(TypeError): + var.validate_value("not_a_number") + + def test_validate_value_non_float_tensor_raises(self): + """Test that non-float tensor raises ValueError.""" + var = TorchScalarVariable(name="test") + with pytest.raises(ValueError, match="floating-point type"): + var.validate_value(torch.tensor([[1], [2], [3]])) # int64 tensor + + def test_validate_value_wrong_dtype_raises(self): + """Test that tensor with wrong dtype raises ValueError.""" + var = TorchScalarVariable(name="test", dtype=torch.float64) + with pytest.raises(ValueError, match="Expected dtype"): + var.validate_value(torch.tensor(5.0, dtype=torch.float32)) + + # Value range validation tests + def test_validate_value_in_range(self): + """Test validation of value within range.""" + var = TorchScalarVariable( + name="test", value_range=(0.0, 10.0), default_validation_config="error" + ) + var.validate_value(5.0) # Should not raise + + def test_validate_value_out_of_range_error(self): + """Test that out-of-range value raises error with error config.""" + var = TorchScalarVariable( + name="test", value_range=(0.0, 10.0), default_validation_config="error" + ) + with 
pytest.raises(ValueError, match="out of valid range"): + var.validate_value(15.0) + + def test_validate_value_out_of_range_warn(self): + """Test that out-of-range value warns with warn config.""" + var = TorchScalarVariable( + name="test", value_range=(0.0, 10.0), default_validation_config="warn" + ) + with pytest.warns(UserWarning): + var.validate_value(15.0) + + def test_validate_value_out_of_range_none(self): + """Test that out-of-range value passes with none config.""" + var = TorchScalarVariable( + name="test", value_range=(0.0, 10.0), default_validation_config="none" + ) + var.validate_value(15.0) # Should not raise + + def test_validate_value_with_config_override(self): + """Test that config parameter overrides default_validation_config.""" + var = TorchScalarVariable( + name="test", value_range=(0.0, 10.0), default_validation_config="none" + ) + # Default is "none" but we override with "error" + with pytest.raises(ValueError, match="out of valid range"): + var.validate_value(15.0, config="error") + + def test_validate_value_config_enum_object(self): + """Test validation with ConfigEnum object instead of string.""" + var = TorchScalarVariable( + name="test", + value_range=(0.0, 10.0), + default_validation_config=ConfigEnum.NULL, + ) + var.validate_value(15.0) # Should not raise with NULL config + + def test_value_range_validation(self): + """Test that value_range min must be <= max.""" + with pytest.raises(ValueError, match="Minimum value"): + TorchScalarVariable(name="test", value_range=(10.0, 0.0)) + + # Read-only validation tests + def test_read_only_matching_value(self): + """Test read-only variable with matching value.""" + var = TorchScalarVariable(name="test", default_value=5.0, read_only=True) + var.validate_value(5.0) # Should not raise + + def test_read_only_matching_tensor(self): + """Test read-only variable with matching tensor value.""" + var = TorchScalarVariable( + name="test", default_value=torch.tensor(5.0), read_only=True + ) + 
var.validate_value(torch.tensor(5.0)) # Should not raise + + def test_read_only_non_matching_value_raises(self): + """Test read-only variable with non-matching value raises.""" + var = TorchScalarVariable(name="test", default_value=5.0, read_only=True) + with pytest.raises(ValueError, match="read-only"): + var.validate_value(10.0) + + def test_read_only_no_default_raises(self): + """Test read-only variable without default raises.""" + var = TorchScalarVariable(name="test", read_only=True) + with pytest.raises(ValueError, match="no default value"): + var.validate_value(5.0) + + def test_read_only_batched_tensor(self): + """Test read-only validation with batched tensor.""" + var = TorchScalarVariable(name="test", default_value=5.0, read_only=True) + # All values in batch equal to default + var.validate_value(torch.tensor([[5.0], [5.0], [5.0]])) # Should not raise + + def test_read_only_batched_tensor_mismatch_raises(self): + """Test read-only validation with batched tensor containing mismatched values.""" + var = TorchScalarVariable(name="test", default_value=5.0, read_only=True) + with pytest.raises(ValueError, match="read-only"): + var.validate_value(torch.tensor([[5.0], [6.0], [5.0]])) + + def test_read_only_with_tensor_default_float_value(self): + """Test read-only with tensor default but float value.""" + var = TorchScalarVariable( + name="test", default_value=torch.tensor(5.0), read_only=True + ) + var.validate_value(5.0) # Float matching tensor default should work + + def test_read_only_with_float_default_tensor_value(self): + """Test read-only with float default but tensor value.""" + var = TorchScalarVariable(name="test", default_value=5.0, read_only=True) + var.validate_value( + torch.tensor(5.0) + ) # Tensor matching float default should work + + def test_read_only_with_multidim_batched_tensor(self): + """Test read-only with multi-dimensional batched tensor.""" + var = TorchScalarVariable(name="test", default_value=5.0, read_only=True) + # Shape (2, 3, 1) 
- multiple batch dimensions + batched = torch.full((2, 3, 1), 5.0) + var.validate_value(batched) # Should not raise + + def test_read_only_tensor_near_default_within_tolerance(self): + """Test read-only with values very close to default within tolerance.""" + var = TorchScalarVariable(name="test", default_value=5.0, read_only=True) + # Value very close to default (within 1e-9 tolerance) + var.validate_value(torch.tensor(5.0 + 1e-10)) # Should not raise + + def test_model_dump(self): + """Test model_dump includes variable_class.""" + var = TorchScalarVariable(name="test", default_value=1.0) + dump = var.model_dump() + assert "variable_class" in dump + assert dump["variable_class"] == "TorchScalarVariable" + assert dump["name"] == "test" + + def test_dtype_none_allows_any_float_dtype(self): + """Test that dtype=None allows any floating-point dtype.""" + var = TorchScalarVariable(name="test", dtype=None) + var.validate_value(torch.tensor(5.0, dtype=torch.float32)) # Should not raise + var.validate_value(torch.tensor(5.0, dtype=torch.float64)) # Should not raise + + def test_model_dump_includes_all_attributes(self): + """Test model_dump includes all relevant attributes.""" + var = TorchScalarVariable( + name="test", + default_value=5.0, + value_range=(0.0, 10.0), + unit="meters", + read_only=False, + ) + dump = var.model_dump() + assert dump["name"] == "test" + assert dump["default_value"] == 5.0 + assert dump["value_range"] == (0.0, 10.0) + assert dump["unit"] == "meters" + assert dump["read_only"] is False + + def test_numpy_float_value(self): + """Test validation of numpy float values.""" + import numpy as np + + var = TorchScalarVariable(name="test") + var.validate_value(np.float64(5.0)) # Should not raise + + +class TestTorchNDVariable: + """Tests for TorchNDVariable class.""" + + def test_basic_creation(self): + """Test basic ND variable creation.""" + var = TorchNDVariable(name="test_nd", shape=(10, 20)) + assert var.name == "test_nd" + assert var.shape == (10, 
20) + assert var.dtype == torch.float32 + + def test_missing_name_raises_validation_error(self): + """Test that missing name raises ValidationError.""" + with pytest.raises(ValidationError): + TorchNDVariable(shape=(10, 20)) + + def test_missing_shape_raises_validation_error(self): + """Test that missing shape raises ValidationError.""" + with pytest.raises(ValidationError): + TorchNDVariable(name="test") + + def test_creation_with_default_value(self): + """Test ND variable creation with default value.""" + default = torch.randn(10, 20) + var = TorchNDVariable(name="test_nd", shape=(10, 20), default_value=default) + assert torch.allclose(var.default_value, default) + + def test_creation_with_dtype(self): + """Test dtype with torch dtype object.""" + var = TorchNDVariable(name="test", shape=(10,), dtype=torch.float64) + assert var.dtype == torch.float64 + + def test_creation_with_int_dtype(self): + """Test int dtype is accepted.""" + var = TorchNDVariable(name="test", shape=(10,), dtype=torch.int32) + assert var.dtype == torch.int32 + + def test_invalid_dtype_type_raises(self): + """Test invalid dtype type raises TypeError.""" + with pytest.raises(TypeError, match="dtype must be a"): + TorchNDVariable(name="test", shape=(10,), dtype="invalid") + + # Value validation tests + def test_validate_value_correct_tensor(self): + """Test validation of correct tensor value.""" + var = TorchNDVariable(name="test", shape=(10, 20)) + var.validate_value(torch.randn(10, 20)) # Should not raise + + def test_validate_value_batched_tensor(self): + """Test validation of batched tensor.""" + var = TorchNDVariable(name="test", shape=(10, 20)) + var.validate_value(torch.randn(5, 10, 20)) # Batch of 5 + + def test_validate_value_wrong_type_raises(self): + """Test that non-tensor value raises TypeError.""" + var = TorchNDVariable(name="test", shape=(10, 20)) + with pytest.raises(TypeError, match="Expected value to be a Tensor"): + var.validate_value([[1, 2], [3, 4]]) + + def 
test_validate_value_wrong_shape_raises(self): + """Test that wrong shape raises ValueError.""" + var = TorchNDVariable(name="test", shape=(10, 20)) + with pytest.raises(ValueError, match="Expected last"): + var.validate_value(torch.randn(10, 30)) + + def test_validate_value_wrong_dtype_raises(self): + """Test that wrong dtype raises ValueError.""" + var = TorchNDVariable(name="test", shape=(10,), dtype=torch.float32) + with pytest.raises(ValueError, match="Expected dtype"): + var.validate_value(torch.randn(10, dtype=torch.float64)) + + def test_validate_value_insufficient_dims_raises(self): + """Test that insufficient dimensions raise ValueError.""" + var = TorchNDVariable(name="test", shape=(10, 20, 30)) + with pytest.raises(ValueError, match="Expected last 3 dimension"): + var.validate_value(torch.randn(20, 30)) + + # Read-only validation tests + def test_read_only_matching_value(self): + """Test read-only ND variable with matching value.""" + default = torch.randn(10, 20) + var = TorchNDVariable( + name="test", shape=(10, 20), default_value=default, read_only=True + ) + var.validate_value(default.clone()) # Should not raise + + def test_read_only_non_matching_value_raises(self): + """Test read-only ND variable with non-matching value raises.""" + default = torch.randn(10, 20) + var = TorchNDVariable( + name="test", shape=(10, 20), default_value=default, read_only=True + ) + with pytest.raises(ValueError, match="read-only"): + var.validate_value(torch.randn(10, 20)) + + def test_read_only_no_default_raises(self): + """Test read-only ND variable without default raises.""" + var = TorchNDVariable(name="test", shape=(10, 20), read_only=True) + with pytest.raises(ValueError, match="no default value"): + var.validate_value(torch.randn(10, 20)) + + def test_read_only_batched_tensor(self): + """Test read-only ND variable with batched tensor matching default.""" + default = torch.randn(10, 20) + var = TorchNDVariable( + name="test", shape=(10, 20), default_value=default, 
read_only=True + ) + # Batched input where all items match default + batched = default.unsqueeze(0).repeat(3, 1, 1) # (3, 10, 20) + var.validate_value(batched) # Should not raise + + def test_read_only_batched_tensor_mismatch_raises(self): + """Test read-only ND variable with batched tensor not matching default raises.""" + default = torch.randn(10, 20) + var = TorchNDVariable( + name="test", shape=(10, 20), default_value=default, read_only=True + ) + # Batched input where items don't match default + batched = torch.randn(3, 10, 20) + with pytest.raises(ValueError, match="read-only"): + var.validate_value(batched) + + def test_read_only_value_within_tolerance(self): + """Test read-only ND variable with values within tolerance.""" + default = torch.randn(10, 20) + var = TorchNDVariable( + name="test", shape=(10, 20), default_value=default, read_only=True + ) + # Value very close to default (within tolerance) + close_value = default + 1e-10 + var.validate_value(close_value) # Should not raise + + +class TestDistributionVariable: + """Tests for DistributionVariable class.""" + + def test_basic_creation(self): + """Test basic distribution variable creation.""" + var = DistributionVariable(name="dist_var") + assert var.name == "dist_var" + assert var.unit is None + + def test_missing_name_raises_validation_error(self): + """Test that missing name raises ValidationError.""" + with pytest.raises(ValidationError): + DistributionVariable(unit="meters") + + def test_creation_with_unit(self): + """Test distribution variable with unit.""" + var = DistributionVariable(name="dist_var", unit="meters") + assert var.unit == "meters" + + def test_validate_normal_distribution(self): + """Test validation of Normal distribution.""" + var = DistributionVariable(name="test") + dist = Normal(loc=0.0, scale=1.0) + var.validate_value(dist) # Should not raise + + def test_validate_uniform_distribution(self): + """Test validation of Uniform distribution.""" + var = 
DistributionVariable(name="test") + dist = Uniform(low=0.0, high=1.0) + var.validate_value(dist) # Should not raise + + def test_validate_non_distribution_raises(self): + """Test that non-distribution value raises TypeError.""" + var = DistributionVariable(name="test") + with pytest.raises(TypeError, match="Expected value to be of type"): + var.validate_value(5.0) + + def test_validate_tensor_raises(self): + """Test that tensor value raises TypeError.""" + var = DistributionVariable(name="test") + with pytest.raises(TypeError, match="Expected value to be of type"): + var.validate_value(torch.tensor([1.0, 2.0])) + + +class TestConfigEnum: + """Tests for ConfigEnum.""" + + def test_enum_values(self): + """Test ConfigEnum values.""" + assert ConfigEnum.NULL.value == "none" + assert ConfigEnum.WARN.value == "warn" + assert ConfigEnum.ERROR.value == "error" + + def test_enum_from_string(self): + """Test ConfigEnum creation from string.""" + assert ConfigEnum("none") == ConfigEnum.NULL + assert ConfigEnum("warn") == ConfigEnum.WARN + assert ConfigEnum("error") == ConfigEnum.ERROR + + +class TestVariableInheritance: + """Tests for variable inheritance structure.""" + + def test_torch_scalar_variable_is_variable(self): + """Test TorchScalarVariable is a Variable.""" + var = TorchScalarVariable(name="test") + assert isinstance(var, Variable) + + def test_torch_nd_variable_is_variable(self): + """Test TorchNDVariable is a Variable.""" + var = TorchNDVariable(name="test", shape=(10,)) + assert isinstance(var, Variable) + + def test_distribution_variable_is_variable(self): + """Test DistributionVariable is a Variable.""" + var = DistributionVariable(name="test") + assert isinstance(var, Variable) + + +class TestTorchNDVariableEdgeCases: + """Additional edge case tests for TorchNDVariable.""" + + def test_1d_shape(self): + """Test 1D shape variable.""" + var = TorchNDVariable(name="test", shape=(100,)) + var.validate_value(torch.randn(100)) # Should not raise + + def 
test_4d_shape(self): + """Test 4D shape variable (e.g., video data).""" + var = TorchNDVariable(name="test", shape=(10, 3, 64, 64)) + var.validate_value(torch.randn(10, 3, 64, 64)) # Should not raise + + def test_nested_batch_dimensions(self): + """Test multiple batch dimensions.""" + var = TorchNDVariable(name="test", shape=(10, 20)) + # Shape (2, 3, 10, 20) means batch_size_1=2, batch_size_2=3 + var.validate_value(torch.randn(2, 3, 10, 20)) # Should not raise + + def test_model_dump_nd_variable(self): + """Test model_dump for TorchNDVariable.""" + var = TorchNDVariable(name="test", shape=(10, 20), unit="pixels") + dump = var.model_dump() + assert dump["variable_class"] == "TorchNDVariable" + assert dump["name"] == "test" + assert dump["shape"] == (10, 20) + assert dump["unit"] == "pixels" + + def test_default_value_wrong_shape_raises(self): + """Test that default value with wrong shape raises error.""" + with pytest.raises(ValueError, match="Expected last"): + TorchNDVariable( + name="test", shape=(10, 20), default_value=torch.randn(10, 30) + ) + + def test_default_value_wrong_dtype_raises(self): + """Test that default value with wrong dtype raises error.""" + with pytest.raises(ValueError, match="Expected dtype"): + TorchNDVariable( + name="test", + shape=(10,), + dtype=torch.float32, + default_value=torch.randn(10, dtype=torch.float64), + ) + + def test_validate_value_config_parameter(self): + """Test validate_value with config parameter.""" + var = TorchNDVariable(name="test", shape=(10, 20)) + # Should work with config parameter (even though optional validation is not implemented) + var.validate_value(torch.randn(10, 20), config="error") + var.validate_value(torch.randn(10, 20), config="warn") + var.validate_value(torch.randn(10, 20), config="none") + + +class TestDistributionVariableEdgeCases: + """Additional edge case tests for DistributionVariable.""" + + def test_read_only_attribute(self): + """Test read_only attribute on distribution variable.""" + var = 
DistributionVariable(name="test", read_only=True) + assert var.read_only is True + + def test_default_validation_config(self): + """Test default_validation_config attribute.""" + var = DistributionVariable(name="test", default_validation_config="warn") + assert var.default_validation_config == ConfigEnum.WARN + + def test_model_dump(self): + """Test model_dump for DistributionVariable.""" + var = DistributionVariable(name="test", unit="meters") + dump = var.model_dump() + assert dump["variable_class"] == "DistributionVariable" + assert dump["name"] == "test" + assert dump["unit"] == "meters" + + def test_validate_with_config_parameter(self): + """Test validate_value with config parameter.""" + var = DistributionVariable(name="test") + dist = Normal(loc=0.0, scale=1.0) + var.validate_value(dist, config="error") # Should not raise + var.validate_value(dist, config=ConfigEnum.WARN) # Should not raise + + def test_validate_batched_distribution(self): + """Test validation of batched distribution.""" + var = DistributionVariable(name="test") + # Batched normal distribution + dist = Normal(loc=torch.zeros(5), scale=torch.ones(5)) + var.validate_value(dist) # Should not raise + + def test_validate_multivariate_distribution(self): + """Test validation of multivariate distribution.""" + from torch.distributions import MultivariateNormal + + var = DistributionVariable(name="test") + dist = MultivariateNormal( + loc=torch.zeros(3), + covariance_matrix=torch.eye(3), + ) + var.validate_value(dist) # Should not raise