File tree Expand file tree Collapse file tree 1 file changed +0
-2
lines changed Expand file tree Collapse file tree 1 file changed +0
-2
lines changed Original file line number Diff line number Diff line change @@ -139,15 +139,13 @@ def init_hook(self, module: torch.nn.Module):
139139 module .to (dtype = self .storage_dtype )
140140 return module
141141
142- @torch ._dynamo .disable (recursive = False )
143142 def pre_forward (self , module : torch .nn .Module , * args , ** kwargs ):
144143 module .to (dtype = self .compute_dtype )
145144 # How do we account for LongTensor, BoolTensor, etc.?
146145 # args = tuple(align_maybe_tensor_dtype(arg, self.compute_dtype) for arg in args)
147146 # kwargs = {k: align_maybe_tensor_dtype(v, self.compute_dtype) for k, v in kwargs.items()}
148147 return args , kwargs
149148
150- @torch ._dynamo .disable (recursive = False )
151149 def post_forward (self , module : torch .nn .Module , output ):
152150 module .to (dtype = self .storage_dtype )
153151 return output
You can’t perform that action at this time.
0 commit comments