@@ -65,6 +65,8 @@ def __init__(
6565 )
6666 self .ndim = len (self .dims )
6767 self .name = name
68+ self .numpy_dtype = np .dtype (self .dtype )
69+ self .filter_checks_isfinite = False
6870
6971 def clone (
7072 self ,
@@ -82,8 +84,9 @@ def clone(
8284 return type (self )(dtype = dtype , shape = shape , dims = dims , ** kwargs )
8385
8486 def filter (self , value , strict = False , allow_downcast = None ):
85- # TODO implement this
86- return value
87+ return TensorType .filter (
88+ self , value , strict = strict , allow_downcast = allow_downcast
89+ )
8790
8891 def convert_variable (self , var ):
8992 # TODO: Implement this
@@ -689,17 +692,20 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None):
689692 if isinstance (x .type , XTensorType ):
690693 return x
691694 if isinstance (x .type , TensorType ):
692- if x .type .ndim > 0 and dims is None :
693- raise TypeError (
694- "non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
695- )
696- return px .basic .xtensor_from_tensor (x , dims )
695+ if dims is None :
696+ if x .type .ndim == 0 :
697+ dims = ()
698+ else :
699+ raise TypeError (
700+ "non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
701+ )
702+ return px .basic .xtensor_from_tensor (x , dims = dims , name = name )
697703 else :
698704 raise TypeError (
699705 "Variable with type {x.type} cannot be converted to XTensorVariable."
700706 )
701707 try :
702- return xtensor_constant (x , name = name , dims = dims )
708+ return xtensor_constant (x , dims = dims , name = name )
703709 except TypeError as err :
704710 raise TypeError (f"Cannot convert { x } to XTensorType { type (x )} " ) from err
705711
0 commit comments