
Improve concat like comms #1844

@wsmoses

Description


see:


    %21 = stablehlo.slice %15 [8:12, 6:6134, 12271:12272] : (tensor<20x6144x12272xf32>) -> tensor<4x6128x1xf32> loc(#loc2576)
    %31 = stablehlo.slice %15 [8:12, 6:6134, 0:12272] : (tensor<20x6144x12272xf32>) -> tensor<4x6128x12272xf32> loc(#loc4550)
    %723 = stablehlo.concatenate %21, %31, dim = 2 : (tensor<4x6128x1xf32>, tensor<4x6128x12272xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)


    %42 = stablehlo.slice %36 [8:12, 11:6139, 12271:12272] : (tensor<20x6144x12272xf32>) -> tensor<4x6128x1xf32> loc(#loc2576)
    %47 = stablehlo.slice %36 [8:12, 11:6139, 0:12272] : (tensor<20x6144x12272xf32>) -> tensor<4x6128x12272xf32> loc(#loc4550)
    %744 = stablehlo.concatenate %42, %47, dim = 2 : (tensor<4x6128x1xf32>, tensor<4x6128x12272xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)


    %895 = stablehlo.slice %827 [0:4, 0:1, 0:12272] : (tensor<4x6128x12272xf32>) -> tensor<4x1x12272xf32> loc(#loc2576)
    %1190 = stablehlo.slice %893 [0:4, 0:1, 0:12272] : (tensor<4x6126x12272xf32>) -> tensor<4x1x12272xf32> loc(#loc2576)
    %1189 = stablehlo.add %1149, %1188 : tensor<4x6124x12272xf32> loc(#loc2658)
    %894 = stablehlo.slice %893 [0:4, 6125:6126, 0:12272] : (tensor<4x6126x12272xf32>) -> tensor<4x1x12272xf32> loc(#loc2576)
    %896 = stablehlo.slice %827 [0:4, 6127:6128, 0:12272] : (tensor<4x6128x12272xf32>) -> tensor<4x1x12272xf32> loc(#loc2576)
    %1191 = stablehlo.concatenate %895, %1190, %1189, %894, %896, dim = 1 : (tensor<4x1x12272xf32>, tensor<4x1x12272xf32>, tensor<4x6124x12272xf32>, tensor<4x1x12272xf32>, tensor<4x1x12272xf32>) -> tensor<4x6128x12272xf32> loc(#loc2576)


    %1587 = stablehlo.add %1582, %1586 : tensor<4x6126x12272xf32> loc(#loc2432)
    %1588 = stablehlo.slice %1572 [0:4, 0:1, 0:12272] : (tensor<4x6128x12272xf32>) -> tensor<4x1x12272xf32> loc(#loc2576)
    %1589 = stablehlo.slice %1572 [0:4, 6127:6128, 0:12272] : (tensor<4x6128x12272xf32>) -> tensor<4x1x12272xf32> loc(#loc2576)
    %1590 = stablehlo.concatenate %1588, %1587, %1589, dim = 1 : (tensor<4x1x12272xf32>, tensor<4x6126x12272xf32>, tensor<4x1x12272xf32>) -> tensor<4x6128x12272xf32> loc(#loc2576)


    %1609 = stablehlo.slice %arg20 [8:12, 8:6136, 8:12278] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x12270xf32> loc(#loc2576)
    %1610 = stablehlo.slice %arg20 [8:12, 8:6136, 12277:12280] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x3xf32> loc(#loc2576)
    %1611 = stablehlo.concatenate %1610, %1609, dim = 2 : (tensor<4x6128x3xf32>, tensor<4x6128x12270xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)


    %1613 = stablehlo.slice %arg20 [8:12, 8:6136, 12278:12280] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x2xf32> loc(#loc2576)
    %1614 = stablehlo.slice %arg20 [8:12, 8:6136, 8:12279] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x12271xf32> loc(#loc2576)
    %1615 = stablehlo.concatenate %1613, %1614, dim = 2 : (tensor<4x6128x2xf32>, tensor<4x6128x12271xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)

    
    %1608 = stablehlo.slice %arg20 [8:12, 8:6136, 12279:12280] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x1xf32> loc(#loc2576)
    ...
    %1617 = stablehlo.concatenate %1608, %695, dim = 2 : (tensor<4x6128x1xf32>, tensor<4x6128x12272xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)


    ...
    %1619 = stablehlo.concatenate %695, %1241, dim = 2 : (tensor<4x6128x12272xf32>, tensor<4x6128x1xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)



    %1621 = stablehlo.slice %arg20 [8:12, 8:6136, 9:12280] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x12271xf32> loc(#loc2576)
    %1622 = stablehlo.slice %arg20 [8:12, 8:6136, 8:10] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x2xf32> loc(#loc2576)
    %1623 = stablehlo.concatenate %1621, %1622, dim = 2 : (tensor<4x6128x12271xf32>, tensor<4x6128x2xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)


    ...
    %1653 = stablehlo.slice %arg20 [8:12, 8:6136, 10:11] : (tensor<20x6144x12288xf32>) -> tensor<4x6128x1xf32> loc(#loc2576)
    %1654 = stablehlo.concatenate %1246, %1653, dim = 2 : (tensor<4x6128x12272xf32>, tensor<4x6128x1xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)



    ...
    %1795 = stablehlo.concatenate %26, %24, dim = 2 : (tensor<4x6128x1xf32>, tensor<4x6128x12272xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)


    ...
    %1798 = stablehlo.concatenate %46, %45, dim = 2 : (tensor<4x6128x1xf32>, tensor<4x6128x12272xf32>) -> tensor<4x6128x12273xf32> loc(#loc2576)


    ...
    %1825 = stablehlo.concatenate %95, %94, dim = 2 : (tensor<4x6128x12272xf32>, tensor<4x6128x1xf32>) -> tensor<4x6128x12273xf32> loc(#loc2100)

Two clear optimizations follow from here (sketches after the list):

  1. If we have a concat of two things, and one is a slice of the other (taken from the opposite side), we should pad the bigger one to the result size, rotate the padded tensor to place the corresponding slice at the right location, then dus (dynamic_update_slice) it in.

  2. If we have a concat, we should identify the operand with the biggest size along the concat dimension, pad it to the result shape, and then dus in the other operands.
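
For optimization 1, here is one possible reading as a hand-written StableHLO sketch, shaped after the %21/%31 → %723 example above. The SSA names (%big, %small, %zero, %c0, ...) are illustrative, and the "enzymexla.rotate" op name, its attributes, and the rotation-direction convention are assumptions, not verified against the dialect. The idea: pad the larger operand to the result size, rotate the padded tensor so the wrap-around column lands at position 0, and dus that single column in; the rotate should then lower to a cheap neighbor exchange instead of the reshard a slice + concat would force.

    // Input pattern: %small is the last column of %big, concatenated in front of it.
    %small = stablehlo.slice %big [0:4, 0:6128, 12271:12272] : (tensor<4x6128x12272xf32>) -> tensor<4x6128x1xf32>
    %cat = stablehlo.concatenate %small, %big, dim = 2 : (tensor<4x6128x1xf32>, tensor<4x6128x12272xf32>) -> tensor<4x6128x12273xf32>

    // Rewrite sketch:
    // 1) pad %big to the result size (one zero column in front),
    %zero = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %pad = stablehlo.pad %big, %zero, low = [0, 0, 1], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<4x6128x12272xf32>, tensor<f32>) -> tensor<4x6128x12273xf32>
    // 2) rotate the padded tensor so the wrap-around column (%big's last) sits at index 0
    //    (op name, attribute names, and amount/direction convention are assumed here),
    %rot = "enzymexla.rotate"(%pad) {amount = 12272 : i32, dimension = 2 : i32} : (tensor<4x6128x12273xf32>) -> tensor<4x6128x12273xf32>
    // 3) dus that single column into the padded tensor at offset 0.
    %col = stablehlo.slice %rot [0:4, 0:6128, 0:1] : (tensor<4x6128x12273xf32>) -> tensor<4x6128x1xf32>
    %c0 = stablehlo.constant dense<0> : tensor<i32>
    %res = stablehlo.dynamic_update_slice %pad, %col, %c0, %c0, %c0 : (tensor<4x6128x12273xf32>, tensor<4x6128x1xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<4x6128x12273xf32>

And a sketch for optimization 2, matching the %1588/%1587/%1589 → %1590 shape pattern above. Again the SSA names (%top, %mid, %bot, ...) are illustrative; only standard StableHLO ops are used: pad the widest operand along the concat dimension to the result shape, then dus the remaining operands in at their offsets.

    // Original: three-way concat along dim 1; %mid is the widest operand.
    %cat1 = stablehlo.concatenate %top, %mid, %bot, dim = 1 : (tensor<4x1x12272xf32>, tensor<4x6126x12272xf32>, tensor<4x1x12272xf32>) -> tensor<4x6128x12272xf32>

    // Rewrite sketch: pad %mid by one row on each side to reach the result shape ...
    %pad1 = stablehlo.pad %mid, %zero, low = [0, 1, 0], high = [0, 1, 0], interior = [0, 0, 0] : (tensor<4x6126x12272xf32>, tensor<f32>) -> tensor<4x6128x12272xf32>
    // ... then dus %top in at row 0 and %bot at row 6127.
    %c6127 = stablehlo.constant dense<6127> : tensor<i32>
    %upd = stablehlo.dynamic_update_slice %pad1, %top, %c0, %c0, %c0 : (tensor<4x6128x12272xf32>, tensor<4x1x12272xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<4x6128x12272xf32>
    %res1 = stablehlo.dynamic_update_slice %upd, %bot, %c0, %c6127, %c0 : (tensor<4x6128x12272xf32>, tensor<4x1x12272xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<4x6128x12272xf32>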

All of these should be added to https://github.com/EnzymeAD/Enzyme-JAX/blob/main/src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp

cc @avik-pal
