see:
```mlir
%21 = stablehlo.slice %15 [8:12, 6:6134, 12271:12272] : (tensor<20x6144x12272xf32>) -> tensor<4x6128x1xf32> loc(#loc2576)
%31 = stablehlo.slice %15 [8:12, 6:6134, 0:12272] : (tensor<20x6144x12272xf32>) -> tensor<4x6128x12272xf32> loc(#loc4550)
%723 = stablehlo.concatenate %21, %31, dim = 2 : (tensor<4x6128x1xf32>, tensor<4x6128x12272xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)
%42 = stablehlo.slice %36 [8:12, 11:6139, 12271:12272] : (tensor<20x6144x12272xf32>) -> tensor<4x6128x1xf32> loc(#loc2576)
%47 = stablehlo.slice %36 [8:12, 11:6139, 0:12272] : (tensor<20x6144x12272xf32>) -> tensor<4x6128x12272xf32> loc(#loc4550)
%744 = stablehlo.concatenate %42, %47, dim = 2 : (tensor<4x6128x1xf32>, tensor<4x6128x12272xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)
%895 = stablehlo.slice %827 [0:4, 0:1, 0:12272] : (tensor<4x6128x12272xf32>) -> tensor<4x1x12272xf32> loc(#loc2576)
%1190 = stablehlo.slice %893 [0:4, 0:1, 0:12272] : (tensor<4x6126x12272xf32>) -> tensor<4x1x12272xf32> loc(#loc2576)
%1189 = stablehlo.add %1149, %1188 : tensor<4x6124x12272xf32> loc(#loc2658)
%894 = stablehlo.slice %893 [0:4, 6125:6126, 0:12272] : (tensor<4x6126x12272xf32>) -> tensor<4x1x12272xf32> loc(#loc2576)
%896 = stablehlo.slice %827 [0:4, 6127:6128, 0:12272] : (tensor<4x6128x12272xf32>) -> tensor<4x1x12272xf32> loc(#loc2576)
%1191 = stablehlo.concatenate %895, %1190, %1189, %894, %896, dim = 1 : (tensor<4x1x12272xf32>, tensor<4x1x12272xf32>, tensor<4x6124x12272xf32>, tensor<4x1x12272xf32>, tensor<4x1x12272xf32>) -> tensor<4x6128x12272xf32> loc(#loc2576)
%1587 = stablehlo.add %1582, %1586 : tensor<4x6126x12272xf32> loc(#loc2432)
%1588 = stablehlo.slice %1572 [0:4, 0:1, 0:12272] : (tensor<4x6128x12272xf32>) -> tensor<4x1x12272xf32> loc(#loc2576)
%1589 = stablehlo.slice %1572 [0:4, 6127:6128, 0:12272] : (tensor<4x6128x12272xf32>) -> tensor<4x1x12272xf32> loc(#loc2576)
%1590 = stablehlo.concatenate %1588, %1587, %1589, dim = 1 : (tensor<4x1x12272xf32>, tensor<4x6126x12272xf32>, tensor<4x1x12272xf32>) -> tensor<4x6128x12272xf32> loc(#loc2576)
%1609 = stablehlo.slice %arg20 [8:12, 8:6136, 8:12278] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x12270xf32> loc(#loc2576)
%1610 = stablehlo.slice %arg20 [8:12, 8:6136, 12277:12280] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x3xf32> loc(#loc2576)
%1611 = stablehlo.concatenate %1610, %1609, dim = 2 : (tensor<4x6128x3xf32>, tensor<4x6128x12270xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)
%1613 = stablehlo.slice %arg20 [8:12, 8:6136, 12278:12280] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x2xf32> loc(#loc2576)
%1614 = stablehlo.slice %arg20 [8:12, 8:6136, 8:12279] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x12271xf32> loc(#loc2576)
%1615 = stablehlo.concatenate %1613, %1614, dim = 2 : (tensor<4x6128x2xf32>, tensor<4x6128x12271xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)
%1608 = stablehlo.slice %arg20 [8:12, 8:6136, 12279:12280] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x1xf32> loc(#loc2576)
...
%1617 = stablehlo.concatenate %1608, %695, dim = 2 : (tensor<4x6128x1xf32>, tensor<4x6128x12272xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)
...
%1619 = stablehlo.concatenate %695, %1241, dim = 2 : (tensor<4x6128x12272xf32>, tensor<4x6128x1xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)
%1621 = stablehlo.slice %arg20 [8:12, 8:6136, 9:12280] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x12271xf32> loc(#loc2576)
%1622 = stablehlo.slice %arg20 [8:12, 8:6136, 8:10] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x2xf32> loc(#loc2576)
%1623 = stablehlo.concatenate %1621, %1622, dim = 2 : (tensor<4x6128x12271xf32>, tensor<4x6128x2xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)
...
%1653 = stablehlo.slice %arg20 [8:12, 8:6136, 10:11] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x1xf32> loc(#loc2576)
%1654 = stablehlo.concatenate %1246, %1653, dim = 2 : (tensor<4x6128x12272xf32>, tensor<4x6128x1xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)
...
%1795 = stablehlo.concatenate %26, %24, dim = 2 : (tensor<4x6128x1xf32>, tensor<4x6128x12272xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)
...
%1798 = stablehlo.concatenate %46, %45, dim = 2 : (tensor<4x6128x1xf32>, tensor<4x6128x12272xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)
...
%1825 = stablehlo.concatenate %95, %94, dim = 2 : (tensor<4x6128x12272xf32>, tensor<4x6128x1xf32>) -> tensor<4x6128x12273xf32> loc(#loc2100)
```
Two clear optimizations follow from here:

- If we have a concatenate of two operands where one is a slice of the other (taken from the opposite side), we should pad the bigger operand to the result size, rotate so the corresponding slice lands at the right location, then dynamic_update_slice (DUS) the bigger operand back in.
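A minimal numpy sketch of the semantics of this rewrite (not the MLIR pattern itself; shapes and the use of `np.roll` to stand in for a rotate are illustrative assumptions):

```python
import numpy as np

x = np.random.rand(4, 6, 8).astype(np.float32)

# Original pattern: concat(slice(x, last column), x) along dim 2,
# mirroring %1617 = concatenate(%1608, %695) in the IR above.
ref = np.concatenate([x[:, :, -1:], x], axis=2)

# Rewrite: pad x on the low side of dim 2, rotate by one so the
# wrapped-around last column lands in the pad slot, then DUS x back in.
padded = np.pad(x, ((0, 0), (0, 0), (1, 0)))  # junk (zeros) in slot 0
rotated = np.roll(padded, shift=1, axis=2)    # x[:, :, -1] now at slot 0
rotated[:, :, 1:] = x                         # stands in for dynamic_update_slice
assert np.array_equal(rotated, ref)
```

The point is that in a distributed setting the rotate is a cheap neighbor shift, and the DUS overwrites everything except the one wrapped column, so no full-tensor concatenate materializes.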
- If we have a concatenate in general, we should identify the operand with the biggest size along the concat dimension, pad it to the result shape, and then DUS the other operands in at their offsets.
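A numpy sketch of this second rewrite (again just the semantics, with assumed example shapes mirroring the `dim = 1` concatenates like %1191 above):

```python
import numpy as np

a = np.random.rand(4, 1, 8).astype(np.float32)  # small edge slice
b = np.random.rand(4, 6, 8).astype(np.float32)  # biggest operand along dim 1
c = np.random.rand(4, 1, 8).astype(np.float32)

# Original pattern: concatenate(a, b, c) along dim 1.
ref = np.concatenate([a, b, c], axis=1)

# Rewrite: pad the biggest operand out to the result shape, then
# dynamic_update_slice the remaining operands at their offsets.
out = np.pad(b, ((0, 0), (1, 1), (0, 0)))  # b now occupies rows 1..6
out[:, 0:1, :] = a                          # DUS a at offset 0
out[:, 7:8, :] = c                          # DUS c at offset 7
assert np.array_equal(out, ref)
```

This trades one large concatenate for a pad plus small DUS ops, so only the small operands move.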
Both of these rewrites should be added to https://github.com/EnzymeAD/Enzyme-JAX/blob/main/src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp
cc @avik-pal