55import contextlib
66import copy
77import dataclasses
8+ import gc
89import inspect
10+ import io
911import itertools
1012import pickle
1113import unittest
1214import weakref
1315from unittest .mock import patch
14- import io
15- import gc
1616
1717import numpy as np
18+
1819import torch
1920import torch ._dynamo
2021import torch ._functorch .config
2122import torch ._prims as prims
2223import torch .testing ._internal .optests as optests
2324import torch .utils ._pytree as pytree
24-
2525from torch import distributed as dist
2626from torch ._C ._functorch import _add_batch_dim , get_unwrapped , is_batchedtensor
2727from torch ._dispatch .python import enable_python_dispatcher
3232 _CacheKeyState ,
3333 DynamicOutputShapeException ,
3434 extract_tensor_metadata ,
35- MetadataMismatchError ,
3635 FakeTensor ,
3736 FakeTensorConverter ,
3837 FakeTensorMode ,
38+ MetadataMismatchError ,
3939 unset_fake_temporarily ,
4040 UnsupportedOperatorException ,
4141)
5656 OpDTypes ,
5757 ops ,
5858)
59+ from torch .testing ._internal .common_dtype import all_types_complex_float8_and
5960from torch .testing ._internal .common_utils import (
6061 instantiate_parametrized_tests ,
6162 parametrize ,
6869 TestCase ,
6970 xfailIfTorchDynamo ,
7071)
71- from torch .testing ._internal .common_dtype import all_types_complex_float8_and
7272from torch .testing ._internal .custom_op_db import custom_op_db
73-
7473from torch .testing ._internal .inductor_utils import GPU_TYPE
7574from torch .testing ._internal .jit_utils import RUN_CUDA
7675from torch .testing ._internal .two_tensor import TwoTensor
7776from torch .utils ._mode_utils import no_dispatch
7877from torch .utils ._python_dispatch import TorchDispatchMode
7978
79+
8080aten = torch .ops .aten
8181
8282torch ._dynamo .config .fake_tensor_cache_enabled = True
@@ -977,10 +977,12 @@ def test_fast_div(self):
977977 with mode :
978978 x = torch .empty (2 , 2 , device = "cpu" , dtype = torch .int32 )
979979 from torch ._subclasses .fake_impls import get_fast_op_impls
980+
980981 fast_div = get_fast_op_impls ()[torch .ops .aten .div .Tensor ]
981982 y = fast_div (mode , x , 2 )
982983 self .assertEqual (y .dtype , torch .float32 )
983984
985+
984986instantiate_parametrized_tests (FakeTensorTest )
985987
986988
@@ -1115,7 +1117,9 @@ def test_fake(self, device, dtype, op):
11151117make_propagate_real_tensors_cls (FakeTensorOpInfoTest )
11161118instantiate_device_type_tests (FakeTensorOpInfoTest , globals (), only_for = ("cpu" , "cuda" ))
11171119instantiate_device_type_tests (
1118- PropagateRealTensorsFakeTensorOpInfoTest , globals (), only_for = ("cpu" ,) # noqa: F821
1120+ PropagateRealTensorsFakeTensorOpInfoTest , # noqa: F821
1121+ globals (),
1122+ only_for = ("cpu" ,),
11191123)
11201124
11211125
@@ -1415,13 +1419,11 @@ def forward(self, arg1, arg2, arg3):
14151419 self .assertTrue ("output[0]" not in str (e ))
14161420 if self .__class__ .__name__ .startswith ("PropagateRealTensors" ):
14171421 self .assertTrue (
1418- "Real tensor propagation found a metadata mismatch"
1419- in str (e )
1422+ "Real tensor propagation found a metadata mismatch" in str (e )
14201423 )
14211424 else :
14221425 self .assertTrue (
1423- "found mismatched tensor metadata for output"
1424- in str (e )
1426+ "found mismatched tensor metadata for output" in str (e )
14251427 )
14261428
14271429 # IMPORTANT!!! Always run even if CUDA is not available
@@ -1623,61 +1625,74 @@ def test_nonzero_stride(self):
16231625 def test_torch_load_with_fake_mode (self ):
16241626 model = torch .nn .Linear (5 , 10 )
16251627 sd = model .state_dict ()
1626- sd ['tt' ] = TwoTensor (torch .randn (2 ), torch .randn (2 ))
1628+ sd ["tt" ] = TwoTensor (torch .randn (2 ), torch .randn (2 ))
16271629
16281630 def _read_tensor_and_check (key , sd_loaded , all_bytes , device ):
16291631 dtype = torch .float32
16301632 t = sd_loaded [key ]
16311633 self .assertEqual (t .device .type , device )
16321634 if isinstance (t , TwoTensor ):
1633- untyped_storage_a , untyped_storage_b = t .a .untyped_storage (), t .b .untyped_storage ()
1634- offset_a , offset_b = untyped_storage_a ._checkpoint_offset , untyped_storage_b ._checkpoint_offset
1635- nbytes_a , nbytes_b = untyped_storage_a .nbytes () // 4 , untyped_storage_b .nbytes () // 4
1636- result_a = torch .frombuffer (all_bytes , dtype = dtype , count = nbytes_a , offset = offset_a ).resize_ (t .a .size ())
1637- result_b = torch .frombuffer (all_bytes , dtype = dtype , count = nbytes_b , offset = offset_b ).resize_ (t .b .size ())
1635+ untyped_storage_a , untyped_storage_b = (
1636+ t .a .untyped_storage (),
1637+ t .b .untyped_storage (),
1638+ )
1639+ offset_a , offset_b = (
1640+ untyped_storage_a ._checkpoint_offset ,
1641+ untyped_storage_b ._checkpoint_offset ,
1642+ )
1643+ nbytes_a , nbytes_b = (
1644+ untyped_storage_a .nbytes () // 4 ,
1645+ untyped_storage_b .nbytes () // 4 ,
1646+ )
1647+ result_a = torch .frombuffer (
1648+ all_bytes , dtype = dtype , count = nbytes_a , offset = offset_a
1649+ ).resize_ (t .a .size ())
1650+ result_b = torch .frombuffer (
1651+ all_bytes , dtype = dtype , count = nbytes_b , offset = offset_b
1652+ ).resize_ (t .b .size ())
16381653 self .assertEqual (TwoTensor (result_a , result_b ), sd [key ])
16391654 else :
16401655 untyped_storage = t .untyped_storage ()
16411656 offset = untyped_storage ._checkpoint_offset
16421657 nbytes = untyped_storage .nbytes () // 4
1643- result = torch .frombuffer (all_bytes , dtype = dtype , count = nbytes , offset = offset ).resize_ (t .size ())
1658+ result = torch .frombuffer (
1659+ all_bytes , dtype = dtype , count = nbytes , offset = offset
1660+ ).resize_ (t .size ())
16441661 self .assertEqual (result , sd [key ])
16451662
1646-
16471663 with TemporaryFileName () as f , torch .serialization .safe_globals ([TwoTensor ]):
16481664 # Create state_dict to be loaded later
16491665 torch .save (sd , f )
1650- with open (f , 'rb' ) as g :
1666+ with open (f , "rb" ) as g :
16511667 all_bytes = g .read ()
16521668
16531669 fake_mode = FakeTensorMode ()
16541670 with fake_mode :
16551671 sd_loaded = torch .load (f )
16561672 for k in sd :
1657- _read_tensor_and_check (k , sd_loaded , all_bytes , ' cpu' )
1673+ _read_tensor_and_check (k , sd_loaded , all_bytes , " cpu" )
16581674 with fake_mode :
16591675 sd_loaded = torch .load (f , map_location = "cuda" )
16601676 for k in sd :
1661- _read_tensor_and_check (k , sd_loaded , all_bytes , 'cuda' )
1662-
1677+ _read_tensor_and_check (k , sd_loaded , all_bytes , "cuda" )
16631678
16641679 for k in sd .keys ():
1665- sd [k ] = sd [k ].to (' cuda' )
1680+ sd [k ] = sd [k ].to (" cuda" )
16661681
16671682 with TemporaryFileName () as f , torch .serialization .safe_globals ([TwoTensor ]):
16681683 torch .save (sd , f )
1669- with open (f , 'rb' ) as g :
1684+ with open (f , "rb" ) as g :
16701685 all_bytes = g .read ()
16711686
16721687 fake_mode = FakeTensorMode ()
16731688 with fake_mode :
16741689 sd_loaded = torch .load (f )
16751690 for k in sd :
1676- _read_tensor_and_check (k , sd_loaded , all_bytes , ' cuda' )
1691+ _read_tensor_and_check (k , sd_loaded , all_bytes , " cuda" )
16771692 with fake_mode :
16781693 sd_loaded = torch .load (f , map_location = "cpu" )
16791694 for k in sd :
1680- _read_tensor_and_check (k , sd_loaded , all_bytes , ' cpu' )
1695+ _read_tensor_and_check (k , sd_loaded , all_bytes , " cpu" )
16811696
16821697
16831698make_propagate_real_tensors_cls (FakeTensorPropTest )
@@ -1994,9 +2009,9 @@ def test_fft_hfft2_issue145522(self):
19942009 x = torch .randn (s0 , s1 , s2 )
19952010 out = torch .randn (s0 , s3 , s4 )
19962011 kwargs = {
1997- 's' : (s3 , s4 ),
1998- ' dim' : (1 , s5 ),
1999- ' norm' : ' ortho' ,
2012+ "s" : (s3 , s4 ),
2013+ " dim" : (1 , s5 ),
2014+ " norm" : " ortho" ,
20002015 }
20012016 r = torch ._C ._fft .fft_hfft2 (x , ** kwargs , out = out )
20022017 self .assertEqual (r .shape , out .shape )
@@ -2074,8 +2089,12 @@ def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
20742089 def __torch_dispatch__ (cls , func , types , args , kwargs ):
20752090 if kwargs is None :
20762091 kwargs = {}
2077- args = pytree .tree_map_only (DifferentDeviceTensor , lambda x : x .inner_tensor , args )
2078- kwargs = pytree .tree_map_only (DifferentDeviceTensor , lambda x : x .inner_tensor , kwargs )
2092+ args = pytree .tree_map_only (
2093+ DifferentDeviceTensor , lambda x : x .inner_tensor , args
2094+ )
2095+ kwargs = pytree .tree_map_only (
2096+ DifferentDeviceTensor , lambda x : x .inner_tensor , kwargs
2097+ )
20792098 # Returns unwrapped tensor
20802099 return func (* args , ** kwargs )
20812100
@@ -2098,7 +2117,7 @@ def f(x):
20982117 return torch .nn .functional .interpolate (
20992118 x ,
21002119 size = [256 , 256 ],
2101- mode = ' bilinear' ,
2120+ mode = " bilinear" ,
21022121 align_corners = False ,
21032122 antialias = True ,
21042123 )
@@ -2108,8 +2127,13 @@ def f(x):
21082127 x = fake_m .from_tensor (
21092128 torch .randn (1 , 3 , 2005 , 1920 , requires_grad = True ),
21102129 symbolic_context = StatelessSymbolicContext (
2111- dynamic_sizes = [DimDynamic .STATIC , DimDynamic .STATIC , DimDynamic .DYNAMIC , DimDynamic .DYNAMIC ],
2112- constraint_sizes = [None , None , None , None ]
2130+ dynamic_sizes = [
2131+ DimDynamic .STATIC ,
2132+ DimDynamic .STATIC ,
2133+ DimDynamic .DYNAMIC ,
2134+ DimDynamic .DYNAMIC ,
2135+ ],
2136+ constraint_sizes = [None , None , None , None ],
21132137 ),
21142138 )
21152139 with fake_m , enable_python_dispatcher ():
@@ -2126,14 +2150,14 @@ def test_from_buffer(self):
21262150
21272151 t = torch .ByteTensor (storage )
21282152 self .assertTrue (isinstance (t , FakeTensor ))
2129- self .assertEqual (t .device , torch .device (' cpu' ))
2153+ self .assertEqual (t .device , torch .device (" cpu" ))
21302154
21312155 def test_meta_tensor_to_fake_cpu (self ):
2132- x = torch .randn (4 , 4 , device = ' meta' )
2156+ x = torch .randn (4 , 4 , device = " meta" )
21332157 with FakeTensorMode (allow_non_fake_inputs = True ):
2134- x_cpu = x .to (device = ' cpu' )
2158+ x_cpu = x .to (device = " cpu" )
21352159 self .assertTrue (isinstance (x_cpu , FakeTensor ))
2136- self .assertEqual (x_cpu .device , torch .device (' cpu' ))
2160+ self .assertEqual (x_cpu .device , torch .device (" cpu" ))
21372161
21382162 def test_cache_tuple_outputs (self ):
21392163 """
@@ -2158,7 +2182,6 @@ def test_cache_tuple_outputs(self):
21582182 extract_tensor_metadata (b ),
21592183 )
21602184
2161-
21622185 def test_cache_aten_index (self ):
21632186 with FakeTensorMode ():
21642187 x = torch .randn (4 , 4 , 4 )
@@ -2178,10 +2201,16 @@ def test_cache_aten_index(self):
21782201 with FakeTensorMode ():
21792202 x = torch .randn (4 , 4 , 4 )
21802203 idx_tensor1 = torch .tensor ([True , True , False , True ])
2181- self .assertRaises (DynamicOutputShapeException , lambda : torch .ops .aten .index (x , [None , idx_tensor1 ]))
2204+ self .assertRaises (
2205+ DynamicOutputShapeException ,
2206+ lambda : torch .ops .aten .index (x , [None , idx_tensor1 ]),
2207+ )
21822208
21832209 idx_tensor1 = torch .tensor ([1 , - 2 , 3 , - 4 ], dtype = torch .int8 )
2184- self .assertRaises (DynamicOutputShapeException , lambda : torch .ops .aten .index (x , [None , idx_tensor1 ]))
2210+ self .assertRaises (
2211+ DynamicOutputShapeException ,
2212+ lambda : torch .ops .aten .index (x , [None , idx_tensor1 ]),
2213+ )
21852214
21862215 @skipIfTorchDynamo ("cache hit/miss changes with invoke_subgraph caching" )
21872216 def test_invoke_subgraph (self ):
@@ -2335,11 +2364,14 @@ def forward(
23352364 lengths = torch .tensor ([0 , 2 , 3 , 1 , 4 ])
23362365 indices = torch .tensor ([2 , 3 , 4 , 6 , 7 , 8 , 9 ])
23372366 offsets = torch .cumsum (lengths , 0 )
2338- ep = torch .export .export (LengthsGather (), (input , lengths , indices , offsets ), strict = False )
2367+ ep = torch .export .export (
2368+ LengthsGather (), (input , lengths , indices , offsets ), strict = False
2369+ )
23392370
23402371 FakeTensorMode .cache_clear ()
23412372 ep .run_decompositions ({})
23422373 self .assertBypasses ("unrepresented symbol in output" , 2 )
23432374
2375+
23442376if __name__ == "__main__" :
23452377 run_tests ()
0 commit comments