Adding TorchScalarVariable and TorchNDVariable (#140)
Conversation
```python
def _arrange_inputs(
    self, formatted_inputs: dict[str, torch.Tensor]
) -> torch.Tensor:
```
@roussel-ryan can you review this function?
lume_torch/variables.py (outdated diff):
```python
if value.ndim == 0:
    pass  # scalar tensor, valid
elif value.ndim == 1:
    pass  # 1D tensor (single scalar or batch of scalars), valid
```
should this be valid for a ScalarVariable type?
Currently we pass 1-D tensors like `torch.tensor([1])` or batches like `torch.tensor([1, 2, 3])`, and we don't want to treat these as NDVariables. Do you have other suggestions on how to validate this?
I see. In that case we should either check that the last dimension is 1 or that `ndim == 0`; I don't think a shape `(N,)` should work. Also, the comment below should read: `# Batched scalars with shape (batch_size, 1), valid`.
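The check suggested here could be sketched as follows (the function name `validate_scalar_shape` is illustrative, not the PR's actual code; note that this rule still admits shape `(1,)`, since its last dimension is 1):

```python
import torch


def validate_scalar_shape(value: torch.Tensor) -> torch.Tensor:
    # Accept 0-dim tensors, or tensors whose last dimension is 1
    # (e.g. batched scalars with shape (batch_size, 1)).
    # Reject shape (N,) for N > 1, so plain 1-D batches are not
    # silently treated as scalars.
    if value.ndim == 0 or value.shape[-1] == 1:
        return value
    raise ValueError(
        f"Expected a 0-dim scalar or shape (batch_size, 1), "
        f"got shape {tuple(value.shape)}"
    )
```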
lume_torch/variables.py (outdated diff):
```python
if expected_dtype and value.dtype != expected_dtype:
    raise ValueError(f"Expected dtype {expected_dtype}, got {value.dtype}")


def _get_image_shape_for_validation(self, value: Tensor) -> Tuple[int, ...]:
```
I think this is too restrictive for NDVariable; if we want an ImageVariable subclass, then this would be more relevant.
So do we want both

`NDVariable -> ImageVariable(NDVariable) -> TorchImageVariable(ImageVariable)`

and

`NDVariable -> TorchNDVariable(NDVariable)`

and similar implementations for numpy under lume-base?
So what does the `ImageVariable` class add that we can't do with the `shape` argument in `NDVariable`? Is it just a convenience wrapper?
If so, I think we would want `NDVariable -> TorchNDVariable(NDVariable) -> TorchImageVariable(TorchNDVariable)`.
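As a hypothetical sketch of that hierarchy (class bodies, constructor signatures, and the `validate_value` method are illustrative assumptions, not the PR's implementation):

```python
import torch


class NDVariable:
    """Base class: an N-dimensional variable with a declared shape."""

    def __init__(self, name: str, shape: tuple[int, ...]):
        self.name = name
        self.shape = shape


class TorchNDVariable(NDVariable):
    """Backs the variable with torch tensors and validates shapes."""

    def validate_value(self, value: torch.Tensor) -> torch.Tensor:
        if tuple(value.shape) != self.shape:
            raise ValueError(
                f"Expected shape {self.shape}, got {tuple(value.shape)}"
            )
        return value


class TorchImageVariable(TorchNDVariable):
    """Restricts the declared shape to 2D (greyscale)
    or 3D ([Channels, Height, Width]) images."""

    def __init__(self, name: str, shape: tuple[int, ...]):
        if len(shape) not in (2, 3):
            raise ValueError("Image shapes must be 2D or 3D")
        super().__init__(name, shape)
```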
So the main reasons I added specific image type validation were:
- to validate that the NDVariable is either 2D or 3D only;
- to validate that torch images are being correctly defined vs., say, numpy images, since torch expects `[Channels, Height, Width]` but numpy expects `[Height, Width, Channels]`.

Both of these can be removed and we can just use NDVariable, as long as we assume the user is defining images correctly for each case. We expect to be using numpy images on the LCLS side, AFAIK.
Or I can add an image subclass as discussed above. Any preference?
Why do we have channels at all? We deal with greyscale images.
To keep it general. For now, maybe it's best to remove image-specific checks and keep it as a general NDVariable class, and implement features/validation as needed if specific image use cases come up.
yes, I think that makes sense
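Under that conclusion, a general check driven by NDVariable's `shape` argument might look like this sketch (the helper name `validate_nd_shape` and the optional leading batch dimension are assumptions, not the PR's code):

```python
import torch


def validate_nd_shape(
    value: torch.Tensor, expected_shape: tuple[int, ...]
) -> torch.Tensor:
    """Validate `value` against `expected_shape`, allowing an
    optional leading batch dimension."""
    if tuple(value.shape) == expected_shape:
        return value  # unbatched value matching the declared shape
    if (
        value.ndim == len(expected_shape) + 1
        and tuple(value.shape[1:]) == expected_shape
    ):
        return value  # batched: (batch_size, *expected_shape)
    raise ValueError(
        f"Expected shape {expected_shape} (optionally batched), "
        f"got {tuple(value.shape)}"
    )
```

A greyscale image would then simply be declared with `shape=(height, width)`, with no image-specific subclass.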
This pull request introduces several improvements and refactors across the codebase to better support array and tensor variables in LUME models, enhance input/output validation, and simplify utility functions. The main focus is on expanding variable support beyond scalars, improving tensor handling, and streamlining validation logic for model inputs and outputs.
Expanded Variable Support and Validation
- Updated `input_variables` and `output_variables` in `LUMEBaseModel` to support `TorchNDVariable` and `DistributionVariable`, allowing models to handle arrays/tensors and distributions as inputs/outputs.
- Updated `LUMEBaseModel` and derived models to properly validate and handle `TorchNDVariable` and `DistributionVariable` instances, raising errors for unknown input names.
- Updated `TorchModel` to differentiate between scalar and tensor outputs, ensuring correct validation for each variable type.

Utility and Model Refactoring
- Refactored the `itemize_dict` utility to support flattening and itemizing both numpy arrays and torch tensors, making it more robust for different input types.
- Refactored the `_arrange_inputs` method in `TorchModel` to support batching and stacking of tensor inputs, handle default values, and enforce consistent input shapes, with clear error handling for mixed variable types.

Miscellaneous Improvements
- Added a `_tkwargs` property to `ProbModelBase` for consistent tensor device/dtype handling, and removed redundant code from `GPModel`.
- Updated `DistributionVariable.validate_value` to correctly use `ConfigEnum.NULL` for validation configuration, improving clarity and correctness.

These changes collectively make the codebase more flexible for machine learning workflows that require complex input/output types, improve validation reliability, and simplify the handling of tensors and arrays throughout the model lifecycle.