Skip to content

dot_general(reshape(bcast), ...) simplify #1877

@avik-pal

Description

@avik-pal

If both operands of the dot_general has delete_dims, we can add a pass similar to reduce_delete_dims but for dot_general and incorporate those dims in the dot_general contracting dims. The following is an example where this would apply.

module @reactant_covariance attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<128x2048xf32> {enzymexla.memory_effects = []}) -> tensor<128x128xf32> attributes {enzymexla.memory_effects = []} {
    %cst = stablehlo.constant dense<4.88519785E-4> : tensor<128x128xf32>
    %cst_0 = stablehlo.constant dense<4.8828125E-4> : tensor<128x1xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.reshape %arg0 : (tensor<128x2048xf32>) -> tensor<128x2048x1xf32>
    %1 = stablehlo.reduce(%0 init: %cst_1) applies stablehlo.add across dimensions = [1] : (tensor<128x2048x1xf32>, tensor<f32>) -> tensor<128x1xf32>
    %2 = stablehlo.multiply %1, %cst_0 : tensor<128x1xf32>
    %3 = stablehlo.broadcast_in_dim %arg0, dims = [1, 0] : (tensor<128x2048xf32>) -> tensor<2048x128x1x1xf32>
    %4 = stablehlo.broadcast_in_dim %2, dims = [1, 2] : (tensor<128x1xf32>) -> tensor<2048x128x1x1xf32>
    %5 = stablehlo.subtract %3, %4 : tensor<2048x128x1x1xf32>
    %6 = stablehlo.broadcast_in_dim %5, dims = [2, 0, 3, 4] : (tensor<2048x128x1x1xf32>) -> tensor<128x128x2048x1x1xf32>
    %7 = stablehlo.broadcast_in_dim %5, dims = [2, 1, 3, 4] : (tensor<2048x128x1x1xf32>) -> tensor<128x128x2048x1x1xf32>
    %8 = stablehlo.reshape %6 : (tensor<128x128x2048x1x1xf32>) -> tensor<128x128x2048x1xf32>
    %9 = stablehlo.reshape %7 : (tensor<128x128x2048x1x1xf32>) -> tensor<128x128x2048x1xf32>
    %10 = stablehlo.dot_general %8, %9, batching_dims = [0, 1, 3] x [0, 1, 3], contracting_dims = [2] x [2] : (tensor<128x128x2048x1xf32>, tensor<128x128x2048x1xf32>) -> tensor<128x128x1xf32>
    %11 = stablehlo.transpose %10, dims = [1, 0, 2] : (tensor<128x128x1xf32>) -> tensor<128x128x1xf32>
    %12 = stablehlo.reshape %11 : (tensor<128x128x1xf32>) -> tensor<128x128xf32>
    %13 = stablehlo.multiply %12, %cst : tensor<128x128xf32>
    return %13 : tensor<128x128xf32>
  }
}

However, if there is a mismatch in delete dims in the operands, we should check if we remove delete dims from one then can we do better fusions (with broadcast and dot_general). If that is the case, then we can expand the dims of the other operand.

module @reactant_syrk attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<f32> {enzymexla.memory_effects = []}, %arg1: tensor<f32> {enzymexla.memory_effects = []}, %arg2: tensor<2048x2048xf32> {enzymexla.memory_effects = []}, %arg3: tensor<2048x2048xf32> {enzymexla.memory_effects = [], tf.aliasing_output = 0 : i32}) -> tensor<2048x2048xf32> attributes {enzymexla.memory_effects = []} {
    %0 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f32>) -> tensor<2048x2048x1x1xf32>
    %1 = stablehlo.reshape %arg2 : (tensor<2048x2048xf32>) -> tensor<2048x2048x1x1xf32>
    %2 = stablehlo.multiply %1, %0 : tensor<2048x2048x1x1xf32>
    %3 = stablehlo.broadcast_in_dim %2, dims = [1, 2, 3, 4] : (tensor<2048x2048x1x1xf32>) -> tensor<2048x2048x2048x1x1xf32>
    %4 = stablehlo.reshape %3 : (tensor<2048x2048x2048x1x1xf32>) -> tensor<2048x2048x2048x1xf32>
    %5 = stablehlo.dot_general %4, %arg2, batching_dims = [0] x [1], contracting_dims = [1] x [0] : (tensor<2048x2048x2048x1xf32>, tensor<2048x2048xf32>) -> tensor<2048x2048x1xf32>
    %6 = stablehlo.broadcast_in_dim %arg3, dims = [1, 0] : (tensor<2048x2048xf32>) -> tensor<2048x2048x1xf32>
    %7 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<f32>) -> tensor<2048x2048x1xf32>
    %8 = stablehlo.multiply %6, %7 : tensor<2048x2048x1xf32>
    %9 = stablehlo.add %5, %8 : tensor<2048x2048x1xf32>
    %10 = stablehlo.reshape %9 {enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>]} : (tensor<2048x2048x1xf32>) -> tensor<2048x2048xf32>
    %11 = stablehlo.transpose %10, dims = [1, 0] : (tensor<2048x2048xf32>) -> tensor<2048x2048xf32>
    return %11 : tensor<2048x2048xf32>
  }
}

Metadata

Metadata

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions