@@ -2299,6 +2299,63 @@ def forward(self, x, w):
22992299 self .assertEqual (actual , expected , atol = atol , rtol = rtol )
23002300 self .assertEqual (counters ["inductor" ]["select_algorithm_autotune" ], 2 )
23012301
2302+ @patches
2303+ @torch .no_grad
2304+ @unittest .skipIf (not TEST_MKL , "Test requires MKL" )
2305+ @set_num_threads (1 ) # avoid k_slicing to make the test deterministic
2306+ @parametrize (
2307+ "out_features1" ,
2308+ (
2309+ 8 ,
2310+ 16 ,
2311+ 24 ,
2312+ 32 ,
2313+ 48 ,
2314+ ),
2315+ )
2316+ @dtypes (torch .float )
2317+ def test_local_and_global_accumulator (self , out_features1 , dtype ):
2318+ batch_size = 256
2319+ in_features = 64
2320+ out_features = 129
2321+ in_features1 = 128
2322+ bias = True
2323+ try :
2324+ try :
2325+ from . import test_aot_inductor_utils
2326+ except ImportError :
2327+ import test_aot_inductor_utils
2328+ except Exception :
2329+ # skip this UT if import failed
2330+ return
2331+
2332+ class M (torch .nn .Module ):
2333+ def __init__ (self ):
2334+ super ().__init__ ()
2335+
2336+ self .linear = torch .nn .Linear (in_features , out_features , bias )
2337+ self .linear1 = torch .nn .Linear (in_features1 , out_features1 , bias )
2338+
2339+ def forward (self , x ):
2340+ y = self .linear (x )
2341+ view = torch .ops .aten .view .default (y , [- 1 , in_features1 ])
2342+ return self .linear1 (view )
2343+
2344+ counters .clear ()
2345+ x = torch .randn (batch_size , in_features ).to (dtype = dtype )
2346+ mod = M ().to (dtype = dtype ).eval ()
2347+ with verify (dtype ) as (atol , rtol ), torch .no_grad ():
2348+ expected = mod (
2349+ x ,
2350+ )
2351+ actual = test_aot_inductor_utils .AOTIRunnerUtil .run (
2352+ "cpu" ,
2353+ mod ,
2354+ (x ,),
2355+ )
2356+ self .assertEqual (actual , expected , atol = atol , rtol = rtol )
2357+ self .assertEqual (counters ["inductor" ]["select_algorithm_autotune" ], 2 )
2358+
23022359
23032360@dynamo_config .patch ({"dynamic_shapes" : True , "assume_static_by_default" : False })
23042361class _DynamicShapesTestBase (BaseTestSelectAlgorithm ):
0 commit comments