
Adding TorchScalarVariable and TorchNDVariable #140

Open
pluflou wants to merge 29 commits into main from ndvariable

Conversation


@pluflou commented Feb 4, 2026

This pull request introduces several improvements and refactors across the codebase to better support array and tensor variables in LUME models, enhance input/output validation, and simplify utility functions. The main focus is on expanding variable support beyond scalars, improving tensor handling, and streamlining validation logic for model inputs and outputs.

Expanded Variable Support and Validation

  • Updated input_variables and output_variables in LUMEBaseModel to support TorchNDVariable and DistributionVariable, allowing models to handle arrays/tensors and distributions as inputs/outputs.
  • Refactored input validation logic in LUMEBaseModel and derived models to properly validate and handle TorchNDVariable and DistributionVariable instances, raising errors for unknown input names.
  • Improved output validation in TorchModel to differentiate between scalar and tensor outputs, ensuring correct validation for each variable type.

Utility and Model Refactoring

  • Enhanced the itemize_dict utility to support flattening and itemizing both numpy arrays and torch tensors, making it more robust for different input types.
  • Refactored the _arrange_inputs method in TorchModel to support batching and stacking of tensor inputs, handle default values, and enforce consistent input shapes, with clear error handling for mixed variable types.
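
For illustration, a minimal numpy-only sketch of what a flatten-and-itemize utility like itemize_dict might look like (hypothetical, not the actual lume-model implementation; torch tensors expose the same .flatten()/.item() methods, so the same duck-typed path would cover them):

```python
import numpy as np

def itemize_dict(d: dict) -> dict:
    """Flatten array-like values and convert every element to a Python scalar.

    Hypothetical sketch: numpy arrays (and, by the same interface, torch
    tensors) are flattened element by element; plain scalars are wrapped
    in a single-element list.
    """
    out = {}
    for name, value in d.items():
        if hasattr(value, "flatten"):
            out[name] = [el.item() for el in value.flatten()]
        else:
            out[name] = [value]
    return out

itemized = itemize_dict({"a": np.array([[1.0, 2.0], [3.0, 4.0]]), "b": 5.0})
# itemized["a"] -> [1.0, 2.0, 3.0, 4.0]; itemized["b"] -> [5.0]
```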

Miscellaneous Improvements

  • Added a _tkwargs property to ProbModelBase for consistent tensor device/dtype handling, and removed redundant code from GPModel.
  • Improved handling of default values for tensor inputs and outputs, ensuring proper cloning and detaching to avoid unwanted side effects.
  • Updated DistributionVariable.validate_value to correctly use ConfigEnum.NULL for validation configuration, improving clarity and correctness.

These changes collectively make the codebase more flexible for machine learning workflows that require complex input/output types, improve validation reliability, and simplify the handling of tensors and arrays throughout the model lifecycle.


def _arrange_inputs(
    self, formatted_inputs: dict[str, torch.Tensor]
) -> torch.Tensor:
@pluflou (Collaborator Author):

@roussel-ryan can you review this function?
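
For context, a hypothetical numpy analogue of what such a batching-and-stacking step could do (a sketch under stated assumptions, not the function under review; the torch version would use torch.stack the same way): broadcast scalar inputs to the common batch size, then stack features along the last axis.

```python
import numpy as np

def arrange_inputs(formatted_inputs: dict, input_order: list) -> np.ndarray:
    """Stack per-feature inputs into a single (batch_size, n_features) array.

    Hypothetical sketch: each entry is a 0-d scalar or a 1D batch; scalars
    are broadcast to the common batch size, and inconsistent batch sizes or
    higher-dimensional inputs raise a clear error.
    """
    arrays = [np.asarray(formatted_inputs[name]) for name in input_order]
    batch_size = max((a.shape[0] for a in arrays if a.ndim == 1), default=1)
    columns = []
    for name, a in zip(input_order, arrays):
        if a.ndim == 0:
            columns.append(np.full(batch_size, a.item()))  # broadcast scalar
        elif a.ndim == 1:
            if a.shape[0] != batch_size:
                raise ValueError(
                    f"'{name}' has batch size {a.shape[0]}, expected {batch_size}"
                )
            columns.append(a)
        else:
            raise ValueError(f"'{name}' must be 0D or 1D, got ndim={a.ndim}")
    return np.stack(columns, axis=-1)

batch = arrange_inputs({"x": 1.0, "y": np.array([2.0, 3.0])}, ["x", "y"])
# batch.shape -> (2, 2)
```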

if value.ndim == 0:
    pass  # scalar tensor, valid
elif value.ndim == 1:
    pass  # 1D tensor (single scalar or batch of scalars), valid
@roussel-ryan (Collaborator):

should this be valid for a ScalarVariable type?

@pluflou (Collaborator Author):

Currently we pass tensors of 1 dim like torch.tensor([1]) or batches torch.tensor([1, 2, 3]) and we don't want to treat these as NDVariables. Do you have other suggestions on how to validate this?

@roussel-ryan (Collaborator):

I see; in this case we should either check that the last dimension is 1 or that ndim == 0. I don't think a shape (N,) should work. Also, the comment below should read:
# Batched scalars with shape (batch_size, 1), valid
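
Sketching the check suggested here (an illustration only, assuming a numpy-like .ndim/.shape interface; torch tensors behave the same way):

```python
import numpy as np

def is_valid_scalar_shape(value) -> bool:
    """Accept 0-d values, or arrays whose last dimension is 1.

    A bare (N,) shape with N > 1 is rejected, so 1D arrays are not
    silently treated as batches of scalars; batched scalars must use
    shape (batch_size, 1).
    """
    return value.ndim == 0 or value.shape[-1] == 1
```

Usage: is_valid_scalar_shape(np.array(1.0)) and is_valid_scalar_shape(np.ones((5, 1))) pass, while is_valid_scalar_shape(np.ones(3)) fails.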

if expected_dtype and value.dtype != expected_dtype:
    raise ValueError(f"Expected dtype {expected_dtype}, got {value.dtype}")

def _get_image_shape_for_validation(self, value: Tensor) -> Tuple[int, ...]:
@roussel-ryan (Collaborator):

I think this is too restrictive for NDVariable, if we want an ImageVariable subclass then this would be more relevant

@pluflou (Collaborator Author):

So do we want both
NDVariable -> ImageVariable(NDVariable) -> TorchImageVariable(ImageVariable)
and
NDVariable -> TorchNDVariable(NDVariable),
and similar implementations for numpy under lume-base?

@roussel-ryan (Collaborator):

so what does the ImageVariable class add that we can't do with the shape argument in NDVariable? Is it just a convenience wrapper?

@roussel-ryan (Collaborator):

If so I think we would want NDVariable -> TorchNDVariable(NDVariable) -> TorchImageVariable(TorchNDVariable)
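
As an illustration, the hierarchy proposed above might be sketched like this (names, fields, and validation are hypothetical, not the actual lume-model classes):

```python
class NDVariable:
    """Base variable holding an N-dimensional array value (sketch)."""
    def __init__(self, name: str, shape: tuple = None):
        self.name = name
        self.shape = shape  # optional expected shape, e.g. (64, 64)

class TorchNDVariable(NDVariable):
    """NDVariable whose values are torch tensors (sketch)."""

class TorchImageVariable(TorchNDVariable):
    """Convenience wrapper constraining the expected shape to 2D or 3D images."""
    def __init__(self, name: str, shape: tuple):
        if len(shape) not in (2, 3):
            raise ValueError(f"image shape must be 2D or 3D, got {len(shape)}D")
        super().__init__(name, shape)

image = TorchImageVariable("camera", (64, 64))
# isinstance(image, NDVariable) -> True
```

On this reading, TorchImageVariable adds only shape validation on top of TorchNDVariable, which is the "convenience wrapper" question raised above.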

@pluflou (Collaborator Author) commented Feb 12, 2026:

So the main reasons I added specific image type validation were:

  • to validate that the NDVariable is either 2D or 3D only
  • to validate that the torch images are being correctly defined vs let's say numpy images, since torch expects [Channels, Height, Width] but numpy expects [Height, Width, Channels].

Both of these can be removed and we can just use NDVariable, as long as we assume the user is defining images correctly for each case. We expect to be using numpy images on the LCLS side AFAIK.

Or I can add an image subclass as discussed above. Any preference?

@roussel-ryan (Collaborator) commented Feb 12, 2026:

Why do we have channels at all? We deal with greyscale images.

@pluflou (Collaborator Author):

To keep it general. For now maybe it's best to remove image specific checks and keep it as a general NDVariable class, and implement features/validation as needed if specific image use cases come up.

@roussel-ryan (Collaborator):

Yes, I think that makes sense.

@pluflou deleted the branch main February 12, 2026 19:42
@pluflou closed this Feb 12, 2026
@pluflou reopened this Feb 12, 2026
@pluflou changed the base branch from lume-torch to main February 18, 2026 01:03
