5151from types import TracebackType
5252from typing import Any , Callable , Generator , Optional , Union
5353
54+ import numpy as np
55+
5456import torch
5557from torch import Size , SymBool , SymInt , Tensor
5658from torch ._C import DispatchKey , DispatchKeySet , ScriptObject
7072from . import _c10d
7173
7274
73- def _int_on_rank (i : "LocalIntNode | ConstantIntNode" , r : int ) -> int :
75+ def _int_on_rank (i : "int | LocalIntNode | ConstantIntNode" , r : int ) -> int :
7476 if isinstance (i , LocalIntNode ):
7577 return i ._local_ints [r ]
7678 elif isinstance (i , ConstantIntNode ):
7779 return i .val
80+ elif isinstance (i , int ):
81+ return i
7882 else :
7983 raise AssertionError (type (i ))
8084
@@ -216,7 +220,7 @@ def is_constant(self) -> bool:
216220 return False
217221
218222 def sym_max (
219- self , other : "LocalIntNode | ConstantIntNode"
223+ self , other : "int | LocalIntNode | ConstantIntNode"
220224 ) -> "LocalIntNode | ConstantIntNode" :
221225 return LocalIntNode (
222226 {
@@ -226,36 +230,50 @@ def sym_max(
226230 )
227231
228232 def add (
229- self , other : "LocalIntNode | ConstantIntNode"
233+ self , other : "int | LocalIntNode | ConstantIntNode"
230234 ) -> "LocalIntNode | ConstantIntNode" :
231235 return LocalIntNode (
232236 {r : self ._local_ints [r ] + _int_on_rank (other , r ) for r in self ._local_ints }
233237 )
234238
235239 def sub (
236- self , other : "LocalIntNode | ConstantIntNode"
240+ self , other : "int | LocalIntNode | ConstantIntNode"
237241 ) -> "LocalIntNode | ConstantIntNode" :
238242 return LocalIntNode (
239243 {r : self ._local_ints [r ] - _int_on_rank (other , r ) for r in self ._local_ints }
240244 )
241245
242246 def mul (
243- self , other : "LocalIntNode | ConstantIntNode"
247+ self , other : "int | LocalIntNode | ConstantIntNode"
244248 ) -> "LocalIntNode | ConstantIntNode" :
245249 return LocalIntNode (
246250 {r : self ._local_ints [r ] * _int_on_rank (other , r ) for r in self ._local_ints }
247251 )
248252
249- def eq (self , other : "LocalIntNode | ConstantIntNode" ) -> bool | SymBool :
253+ def mod (
254+ self , other : "int | LocalIntNode | ConstantIntNode"
255+ ) -> "LocalIntNode | ConstantIntNode" :
256+ return LocalIntNode (
257+ {r : self ._local_ints [r ] % _int_on_rank (other , r ) for r in self ._local_ints }
258+ )
259+
260+ def int_floordiv (
261+ self , other : "int | LocalIntNode | ConstantIntNode"
262+ ) -> "LocalIntNode | ConstantIntNode" :
263+ return LocalIntNode (
264+ {r : self ._local_ints [r ] // _int_on_rank (other , r ) for r in self ._local_ints }
265+ )
266+
267+ def eq (self , other : "int | LocalIntNode | ConstantIntNode" ) -> bool | SymBool :
250268 r = {self ._local_ints [r ] == _int_on_rank (other , r ) for r in self ._local_ints }
251269 return torch ._C ._get_constant_bool_symnode (len (r ) == 1 and next (iter (r )))
252270
253- def gt (self , other : "LocalIntNode | ConstantIntNode" ) -> bool | SymBool :
271+ def gt (self , other : "int | LocalIntNode | ConstantIntNode" ) -> bool | SymBool :
254272 r = {self ._local_ints [r ] > _int_on_rank (other , r ) for r in self ._local_ints }
255273 assert len (r ) == 1 , (self , other )
256274 return torch ._C ._get_constant_bool_symnode (next (iter (r )))
257275
258- def lt (self , other : "LocalIntNode | ConstantIntNode" ) -> bool | SymBool :
276+ def lt (self , other : "int | LocalIntNode | ConstantIntNode" ) -> bool | SymBool :
259277 r = {self ._local_ints [r ] < _int_on_rank (other , r ) for r in self ._local_ints }
260278 assert len (r ) == 1 , (self , other )
261279 return torch ._C ._get_constant_bool_symnode (next (iter (r )))
@@ -437,6 +455,27 @@ def __torch_dispatch__( # type: ignore[override]
437455 with LocalTensorMode (local_tensor ._ranks ):
438456 return func (* args , ** kwargs )
439457
458+ def numpy (self , * , force : bool = False ) -> np .ndarray :
459+ return self .reconcile ().numpy (force = force )
460+
461+ def __lt__ (
462+ self , other : torch .Tensor | bool | complex | float | int
463+ ) -> torch .Tensor :
464+ self_rec = self .reconcile ()
465+ other_rec = other
466+ if isinstance (other , LocalTensor ):
467+ other_rec = other .reconcile ()
468+ return self_rec < other_rec
469+
470+ def __gt__ (
471+ self , other : torch .Tensor | bool | complex | float | int
472+ ) -> torch .Tensor :
473+ self_rec = self .reconcile ()
474+ other_rec = other
475+ if isinstance (other , LocalTensor ):
476+ other_rec = other .reconcile ()
477+ return self_rec > other_rec
478+
440479 def tolist (self ) -> list [Any ]:
441480 """
442481 Reconcile and convert result to list.
0 commit comments