 
 import logging
 from typing import Any, Callable
-from collections import deque
 
 import numpy as np
 from numpy.typing import DTypeLike
@@ -74,20 +73,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
     _tensor_type: type
     _meta: Any
     _data: Any | None
-    _lazy: deque[LazyBase]  # shared within a graph, to avoid deep recursion when making eager
     _args: tuple
-    _func: Callable[[tuple], Any] | None
+    _kwargs: dict[str, Any]
+    _func: Callable[[Any], Any] | None
 
-    def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
+    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
         super().__init__()
         self._meta = meta
         self._data = data
-        self._lazy = lazy if lazy is not None else deque()
         self._args = args
+        self._kwargs = kwargs if kwargs is not None else {}
         self._func = func
         assert self._func is not None or self._data is not None
-        if self._data is None:
-            self._lazy.append(self)
 
     def __init_subclass__(cls) -> None:
         if "_tensor_type" not in cls.__dict__:
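
With the deque gone, a lazy node is fully described by the keyword-only fields (meta, data, args, kwargs, func). A hedged sketch of the new constructor contract, using LazyNumpyTensor (the concrete subclass defined later in this file); the call site and import path are illustrative, not taken from the diff:

    # Illustrative only: a node is either already eager (data=...) or
    # computable (func=...), as enforced by the assert in __init__.
    import numpy as np
    from gguf.lazy import LazyNumpyTensor  # import path assumed

    meta = LazyNumpyTensor.meta_with_dtype_and_shape(np.float32, (2, 3))
    node = LazyNumpyTensor(
        meta=meta,
        args=(np.ones((2, 3), dtype=np.float32),),  # eager args pass through _recurse_apply untouched
        kwargs={},
        func=lambda a: a * 2,  # later invoked as func(*args, **kwargs)
    )
    print(LazyNumpyTensor.to_eager(node))  # materializes a 2x3 array of 2.0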
@@ -117,6 +114,7 @@ def wrapped_fn(*args, **kwargs):
             args = ((use_self,) if use_self is not None else ()) + args
 
             meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
+            # TODO: maybe handle tensors in kwargs too
 
             if isinstance(meta_noop, bool) and not meta_noop:
                 try:
@@ -140,23 +138,7 @@ def wrapped_fn(*args, **kwargs):
                         res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
 
             if isinstance(res, cls._tensor_type):
-                class CollectSharedLazy:
-                    # emulating a static variable
-                    shared_lazy: None | deque[LazyBase] = None
-
-                    @staticmethod
-                    def collect_replace(t: LazyBase):
-                        if CollectSharedLazy.shared_lazy is None:
-                            CollectSharedLazy.shared_lazy = t._lazy
-                        else:
-                            CollectSharedLazy.shared_lazy.extend(t._lazy)
-                        t._lazy = CollectSharedLazy.shared_lazy
-
-                LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)
-
-                shared_lazy = CollectSharedLazy.shared_lazy
-
-                return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
+                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
             else:
                 del res  # not needed
                 # non-tensor return likely relies on the contents of the args
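
The old closure func=lambda a: fn(*a, **kwargs) froze the keyword arguments into the function object; the replacement stores fn, args, and kwargs as separate fields on the node and applies them at evaluation time. A minimal equivalence sketch with hypothetical names:

    # Both forms defer the same call; the second keeps fn/args/kwargs as
    # plain data on the node instead of hiding kwargs inside a closure.
    import numpy as np

    fn, args, kwargs = np.sum, (np.ones((2, 2)),), {"axis": 0}

    old_style = lambda a: fn(*a, **kwargs)  # kwargs captured by the closure
    new_style = (fn, args, kwargs)          # stored as _func, _args, _kwargs

    assert np.array_equal(old_style(args), new_style[0](*new_style[1], **new_style[2]))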
@@ -168,26 +150,18 @@ def collect_replace(t: LazyBase):
     @classmethod
     def to_eager(cls, t: Any) -> Any:
         def simple_to_eager(_t: LazyBase) -> Any:
-            def already_eager_to_eager(_t: LazyBase) -> Any:
-                assert _t._data is not None
+            if _t._data is not None:
                 return _t._data
 
-            while _t._data is None:
-                lt = _t._lazy.popleft()
-                if lt._data is not None:
-                    # Lazy tensor did not belong in the lazy queue.
-                    # Weirdly only happens with Bloom models...
-                    # likely because tensors aren't unique in the queue.
-                    # The final output is still the same as in eager mode,
-                    # so it's safe to ignore this.
-                    continue
-                assert lt._func is not None
-                lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
-                lt._data = lt._func(lt._args)
-                # sanity check
-                assert lt._data is not None
-                assert lt._data.dtype == lt._meta.dtype
-                assert lt._data.shape == lt._meta.shape
+            # NOTE: there's a recursion limit in Python (usually 1000)
+
+            assert _t._func is not None
+            _t._args = cls._recurse_apply(_t._args, simple_to_eager)
+            _t._data = _t._func(*_t._args, **_t._kwargs)
+            # sanity check
+            assert _t._data is not None
+            assert _t._data.dtype == _t._meta.dtype
+            assert _t._data.shape == _t._meta.shape
 
             return _t._data
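
The NOTE about the recursion limit is the trade-off of this hunk: simple_to_eager now materializes the graph by direct recursion, one stack frame per pending node, whereas the old code drained a shared deque iteratively. A hypothetical stress sketch of the failure mode (toy dict nodes, not the real class):

    # Hypothetical toy model of the new evaluation: recursion consumes one
    # Python stack frame per pending node, so a chain deeper than the
    # interpreter's limit (sys.getrecursionlimit(), 1000 by default) would
    # raise RecursionError, where an iterative queue would not.
    import numpy as np

    def make_chain(depth):
        node = {"data": np.zeros(1, dtype=np.float32), "func": None, "args": ()}
        for _ in range(depth):
            node = {"data": None, "func": lambda a: a + 1, "args": (node,)}
        return node

    def to_eager(node):
        if node["data"] is not None:  # already eager
            return node["data"]
        node["data"] = node["func"](*(to_eager(a) for a in node["args"]))
        return node["data"]

    print(to_eager(make_chain(50)))  # fine: [50.]
    # to_eager(make_chain(2000)) would overflow the default stack here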
@@ -206,7 +180,7 @@ def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass
     @classmethod
     def from_eager(cls, t: Any) -> Any:
         if type(t) is cls:
-            # already eager
+            # already lazy
             return t
         elif isinstance(t, cls._tensor_type):
             return cls(meta=cls.eager_to_meta(t), data=t)
@@ -228,8 +202,7 @@ def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) ->
     def astype(self, dtype, *args, **kwargs):
         meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
         full_args = (self, dtype,) + args
-        # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
-        return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
+        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
 
     def tofile(self, *args, **kwargs):
         eager = LazyNumpyTensor.to_eager(self)
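
End to end, the reworked astype stays deferred until to_eager runs the stored call. A hedged usage sketch, assuming the gguf.lazy import path from gguf-py:

    # Nothing is converted at astype() time; only the meta dtype changes.
    import numpy as np
    from gguf.lazy import LazyNumpyTensor  # import path assumed

    arr = np.arange(6, dtype=np.float32).reshape(2, 3)
    lazy = LazyNumpyTensor.from_eager(arr)
    half = lazy.astype(np.float16)        # deferred: records (self, dtype) plus kwargs
    out = LazyNumpyTensor.to_eager(half)  # the actual arr.astype runs here
    assert out.dtype == np.float16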