@@ -71,6 +71,8 @@ def __init__(
         self.name = name
         self.numpy_dtype = np.dtype(self.dtype)
         self.filter_checks_isfinite = False
+        # broadcastable is here just for code that would work fine with XTensorType but checks for it
+        self.broadcastable = (False,) * self.ndim

     def clone(
         self,
@@ -93,6 +95,10 @@ def filter(self, value, strict=False, allow_downcast=None):
             self, value, strict=strict, allow_downcast=allow_downcast
         )

+    @staticmethod
+    def may_share_memory(a, b):
+        return TensorType.may_share_memory(a, b)
+
     def filter_variable(self, other, allow_convert=True):
         if not isinstance(other, Variable):
             # The value is not a Variable: we cast it into
@@ -160,7 +166,7 @@ def convert_variable(self, var):
         return None

     def __repr__(self):
-        return f"XTensorType({self.dtype}, {self.dims}, {self.shape})"
+        return f"XTensorType({self.dtype}, shape={self.shape}, dims={self.dims})"

     def __hash__(self):
         return hash((type(self), self.dtype, self.shape, self.dims))
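A rough sketch of what the two additions enable, for illustration only. The `XTensorType` constructor arguments are assumed from the `__repr__` shown above and may not match the actual signature; `may_share_memory` simply delegates to `TensorType`, as the diff shows.

```python
# Hypothetical usage sketch; constructor kwargs (dims=, shape=) are assumptions
# inferred from the __repr__ in this diff, not a verified API.
import numpy as np

x_type = XTensorType("float64", dims=("a", "b"), shape=(2, 3))

# The new broadcastable attribute mirrors TensorType's convention:
# one False entry per dimension, so code that checks .broadcastable keeps working.
assert x_type.broadcastable == (False, False)

# may_share_memory delegates to TensorType's implementation.
a = np.zeros((2, 3))
assert XTensorType.may_share_memory(a, a[:1])          # a view shares memory
assert not XTensorType.may_share_memory(a, np.zeros((2, 3)))  # a fresh array does not
```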