Skip to content

Commit fa503f1

Browse files
committed
adjust based on lume-base changes to NDVariable
1 parent 154f000 commit fa503f1

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

lume_torch/variables.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import logging
88
import warnings
9-
from typing import Optional, Type, Union
9+
from typing import Optional, Type, Union, ClassVar
1010

1111
import torch
1212
from torch import Tensor
@@ -257,11 +257,26 @@ class TorchNDVariable(NDVariable):
257257
dtype : torch.dtype
258258
Data type of the tensor. Defaults to torch.float32.
259259
260+
Examples
261+
--------
262+
>>> import torch
263+
>>> from lume_torch.variables import TorchNDVariable
264+
>>>
265+
>>> var = TorchNDVariable(
266+
... name="my_tensor",
267+
... shape=(3, 4),
268+
... dtype=torch.float32,
269+
... unit="m"
270+
... )
271+
>>>
272+
>>> tensor = torch.rand(3, 4)
273+
>>> var.validate_value(tensor, config="error") # Passes
274+
260275
"""
261276

262277
default_value: Optional[Tensor] = None
263278
dtype: torch.dtype = torch.float32
264-
array_type: type = torch.Tensor
279+
array_type: ClassVar[type] = Tensor
265280

266281
def _validate_read_only(self, value: Tensor) -> None:
267282
"""Validates that read-only ND-variables match their default value.

0 commit comments

Comments
 (0)