@@ -1292,6 +1292,185 @@ def fn(x, y, z):
 
     self.assertEqual(ref, res)
 
+    @torch._inductor.config.patch(emulate_precision_casts=True)
+    def test_dont_inplace_disjoint_accesses(self):
+        # TODO - would not need mms if we could annotate donated buffer..
+        def forward(  # noqa: F821, F722
+            arg0_1: "bf16[2048, 2048][2048, 1]cuda:0",  # noqa: F821, F722
+            arg1_1: "bf16[8, 4096, 2048][8388608, 2048, 1]cuda:0",  # noqa: F821, F722
+            arg2_1: "bf16[2048, 2048][2048, 1]cuda:0",  # noqa: F821, F722
+            arg3_1: "bf16[2048, 2048][2048, 1]cuda:0",  # noqa: F821, F722
+            arg4_1: "bf16[2048][1]cuda:0",  # noqa: F821, F722
+            arg5_1: "bf16[2048][1]cuda:0",  # noqa: F821, F722
+            arg6_1: "f32[4096, 128][128, 1]cuda:0",  # noqa: F821, F722
+            arg7_1: "f32[4096, 128][128, 1]cuda:0",  # noqa: F821, F722
+        ):
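+            # Project the same flattened input through three weight matrices
+            # (a q/k/v-style set of matmuls); these produce the disjoint buffers
+            # the test cares about.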
+            permute = torch.ops.aten.permute.default(arg0_1, [1, 0])
+            arg0_1 = None
+            view = torch.ops.aten.view.default(arg1_1, [32768, 2048])
+            mm = torch.ops.aten.mm.default(view, permute)
+            view = permute = None
+            view_1 = torch.ops.aten.view.default(mm, [8, 4096, 2048])
+            mm = None
+            permute_1 = torch.ops.aten.permute.default(arg2_1, [1, 0])
+            arg2_1 = None
+            view_2 = torch.ops.aten.view.default(arg1_1, [32768, 2048])
+            mm_1 = torch.ops.aten.mm.default(view_2, permute_1)
+            view_2 = permute_1 = None
+            view_3 = torch.ops.aten.view.default(mm_1, [8, 4096, 2048])
+            mm_1 = None
+            permute_2 = torch.ops.aten.permute.default(arg3_1, [1, 0])
+            arg3_1 = None
+            view_4 = torch.ops.aten.view.default(arg1_1, [32768, 2048])
+            arg1_1 = None
+            mm_2 = torch.ops.aten.mm.default(view_4, permute_2)
+            view_4 = permute_2 = None
+            view_5 = torch.ops.aten.view.default(mm_2, [8, 4096, 2048])
+            mm_2 = None
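+            # RMSNorm-style normalization of the first two projections in fp32,
+            # scaled by arg4_1 / arg5_1 and cast back to bf16.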
+            convert_element_type_6 = torch.ops.prims.convert_element_type.default(
+                view_1, torch.float32
+            )
+            view_1 = None
+            pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_6, 2)
+            mean = torch.ops.aten.mean.dim(pow_1, [-1], True)
+            pow_1 = None
+            add = torch.ops.aten.add.Tensor(mean, 1e-06)
+            mean = None
+            rsqrt = torch.ops.aten.rsqrt.default(add)
+            add = None
+            mul = torch.ops.aten.mul.Tensor(convert_element_type_6, rsqrt)
+            convert_element_type_6 = rsqrt = None
+            convert_element_type_7 = torch.ops.prims.convert_element_type.default(
+                arg4_1, torch.float32
+            )
+            arg4_1 = None
+            mul_1 = torch.ops.aten.mul.Tensor(convert_element_type_7, mul)
+            convert_element_type_7 = mul = None
+            convert_element_type_8 = torch.ops.prims.convert_element_type.default(
+                mul_1, torch.bfloat16
+            )
+            mul_1 = None
+            convert_element_type_9 = torch.ops.prims.convert_element_type.default(
+                view_3, torch.float32
+            )
+            view_3 = None
+            pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_9, 2)
+            mean_1 = torch.ops.aten.mean.dim(pow_2, [-1], True)
+            pow_2 = None
+            add_1 = torch.ops.aten.add.Tensor(mean_1, 1e-06)
+            mean_1 = None
+            rsqrt_1 = torch.ops.aten.rsqrt.default(add_1)
+            add_1 = None
+            mul_2 = torch.ops.aten.mul.Tensor(convert_element_type_9, rsqrt_1)
+            convert_element_type_9 = rsqrt_1 = None
+            convert_element_type_10 = torch.ops.prims.convert_element_type.default(
+                arg5_1, torch.float32
+            )
+            arg5_1 = None
+            mul_3 = torch.ops.aten.mul.Tensor(convert_element_type_10, mul_2)
+            convert_element_type_10 = mul_2 = None
+            convert_element_type_11 = torch.ops.prims.convert_element_type.default(
+                mul_3, torch.bfloat16
+            )
+            mul_3 = None
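+            # Split the hidden dimension into heads of size 128.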
+            view_6 = torch.ops.aten.view.default(
+                convert_element_type_8, [8, 4096, -1, 128]
+            )
+            convert_element_type_8 = None
+            view_7 = torch.ops.aten.view.default(
+                convert_element_type_11, [8, 4096, -1, 128]
+            )
+            convert_element_type_11 = None
+            view_8 = torch.ops.aten.view.default(view_5, [8, 4096, -1, 128])
+            view_5 = None
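+            # Rotary-embedding-style rotation of the two normalized projections,
+            # using arg6_1 / arg7_1 as the per-position rotation tables.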
+            convert_element_type_12 = torch.ops.prims.convert_element_type.default(
+                view_6, torch.float32
+            )
+            view_6 = None
+            convert_element_type_13 = torch.ops.prims.convert_element_type.default(
+                view_7, torch.float32
+            )
+            view_7 = None
+            unsqueeze = torch.ops.aten.unsqueeze.default(arg6_1, 0)
+            unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 2)
+            unsqueeze = None
+            unsqueeze_2 = torch.ops.aten.unsqueeze.default(arg7_1, 0)
+            unsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2)
+            unsqueeze_2 = None
+            mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_12, unsqueeze_3)
+            unsqueeze_3 = None
+            view_9 = torch.ops.aten.view.default(
+                convert_element_type_12, [8, 4096, 16, 2, 64]
+            )
+            convert_element_type_12 = None
+            unbind = torch.ops.aten.unbind.int(view_9, -2)
+            view_9 = None
+            getitem = unbind[0]
+            getitem_1 = unbind[1]
+            unbind = None
+            neg = torch.ops.aten.neg.default(getitem_1)
+            getitem_1 = None
+            cat = torch.ops.aten.cat.default([neg, getitem], -1)
+            neg = getitem = None
+            mul_5 = torch.ops.aten.mul.Tensor(cat, unsqueeze_1)
+            cat = unsqueeze_1 = None
+            add_2 = torch.ops.aten.add.Tensor(mul_4, mul_5)
+            mul_4 = mul_5 = None
+            unsqueeze_4 = torch.ops.aten.unsqueeze.default(arg6_1, 0)
+            arg6_1 = None
+            unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 2)
+            unsqueeze_4 = None
+            unsqueeze_6 = torch.ops.aten.unsqueeze.default(arg7_1, 0)
+            arg7_1 = None
+            unsqueeze_7 = torch.ops.aten.unsqueeze.default(unsqueeze_6, 2)
+            unsqueeze_6 = None
+            mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_13, unsqueeze_7)
+            unsqueeze_7 = None
+            view_10 = torch.ops.aten.view.default(
+                convert_element_type_13, [8, 4096, 16, 2, 64]
+            )
+            convert_element_type_13 = None
+            unbind_1 = torch.ops.aten.unbind.int(view_10, -2)
+            view_10 = None
+            getitem_2 = unbind_1[0]
+            getitem_3 = unbind_1[1]
+            unbind_1 = None
+            neg_1 = torch.ops.aten.neg.default(getitem_3)
+            getitem_3 = None
+            cat_1 = torch.ops.aten.cat.default([neg_1, getitem_2], -1)
+            neg_1 = getitem_2 = None
+            mul_7 = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_5)
+            cat_1 = unsqueeze_5 = None
+            add_3 = torch.ops.aten.add.Tensor(mul_6, mul_7)
+            mul_6 = mul_7 = None
+            convert_element_type_14 = torch.ops.prims.convert_element_type.default(
+                add_2, torch.bfloat16
+            )
+            add_2 = None
+            convert_element_type_15 = torch.ops.prims.convert_element_type.default(
+                add_3, torch.bfloat16
+            )
+            add_3 = None
+            permute_3 = torch.ops.aten.permute.default(
+                convert_element_type_14, [0, 2, 1, 3]
+            )
+            convert_element_type_14 = None
+            permute_4 = torch.ops.aten.permute.default(
+                convert_element_type_15, [0, 2, 1, 3]
+            )
+            convert_element_type_15 = None
+            permute_5 = torch.ops.aten.permute.default(view_8, [0, 2, 1, 3])
+            view_8 = None
+            return (permute_3, permute_4, permute_5)
+
+        from torch._dynamo.debug_utils import aot_graph_input_parser
+
+        kwargs = aot_graph_input_parser(forward)
+        out, code = run_and_get_code(torch.compile(forward), **kwargs)
+        # ignore tiny values; prior to this fix the absolute error was ~28
+        self.assertEqual(forward(**kwargs), out, atol=0.01, rtol=2)
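+        # The generated kernels should not reuse any buffer in place
+        # (in-place buffers surface as "in_out_ptr" kernel arguments).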
+        FileCheck().check_not("in_out").run(code[0])
+
     # https://github.com/pytorch/pytorch/issues/104937
     def test_linear_with_zero_infeature_size(self):
         m = nn.Linear(in_features=0, out_features=0, bias=True).to("cuda")