@@ -49,6 +49,8 @@ def __init__(
4949 self .shape = tuple (shape )
5050 self .ndim = len (self .dims )
5151 self .name = name
52+ self .numpy_dtype = np .dtype (self .dtype )
53+ self .filter_checks_isfinite = False
5254
5355 def clone (
5456 self ,
@@ -66,8 +68,9 @@ def clone(
6668 return type (self )(dtype = dtype , shape = shape , dims = dims , ** kwargs )
6769
6870 def filter (self , value , strict = False , allow_downcast = None ):
69- # TODO implement this
70- return value
71+ return TensorType .filter (
72+ self , value , strict = strict , allow_downcast = allow_downcast
73+ )
7174
7275 def convert_variable (self , var ):
7376 # TODO: Implement this
@@ -530,16 +533,19 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None):
530533 if isinstance (x .type , XTensorType ):
531534 return x
532535 if isinstance (x .type , TensorType ):
533- if x .type .ndim > 0 and dims is None :
534- raise TypeError (
535- "non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
536- )
537- return px .basic .xtensor_from_tensor (x , dims )
536+ if dims is None :
537+ if x .type .ndim == 0 :
538+ dims = ()
539+ else :
540+ raise TypeError (
541+ "non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
542+ )
543+ return px .basic .xtensor_from_tensor (x , dims = dims , name = name )
538544 else :
539545 raise TypeError (
540546 "Variable with type {x.type} cannot be converted to XTensorVariable."
541547 )
542548 try :
543- return xtensor_constant (x , name = name , dims = dims )
549+ return xtensor_constant (x , dims = dims , name = name )
544550 except TypeError as err :
545551 raise TypeError (f"Cannot convert { x } to XTensorType { type (x )} " ) from err
0 commit comments