|
| 1 | +// RUN: enzymexlamlir-opt --pass-pipeline="any(enzyme-hlo-generate-td{patterns=concat_insert_dim_elementwise},transform-interpreter,enzyme-hlo-remove-transform)" |
| 2 | + |
| 3 | +module { |
| 4 | + func.func @mapped_sub(%arg0: tensor<3x5x10xf32>, %arg1: tensor<3x5x10xf32>) -> (tensor<5x3x10xf32>, tensor<3x5x10xf32>, tensor<3x5x10xf32>) { |
| 5 | + %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x5x10xf32>) -> tensor<10x5x3xf32> |
| 6 | + %1 = stablehlo.transpose %arg1, dims = [2, 1, 0] : (tensor<3x5x10xf32>) -> tensor<10x5x3xf32> |
| 7 | + %2 = stablehlo.slice %0 [0:10, 0:1, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32> |
| 8 | + %3 = stablehlo.transpose %2, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32> |
| 9 | + %4 = stablehlo.reshape %3 : (tensor<3x1x10xf32>) -> tensor<3x10xf32> |
| 10 | + %5 = stablehlo.transpose %4, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32> |
| 11 | + %6 = stablehlo.convert %5 : tensor<10x3xf32> |
| 12 | + %7 = stablehlo.broadcast_in_dim %6, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 13 | + %8 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 14 | + %9 = stablehlo.slice %1 [0:10, 0:1, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32> |
| 15 | + %10 = stablehlo.transpose %9, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32> |
| 16 | + %11 = stablehlo.reshape %10 : (tensor<3x1x10xf32>) -> tensor<3x10xf32> |
| 17 | + %12 = stablehlo.transpose %11, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32> |
| 18 | + %13 = stablehlo.convert %12 : tensor<10x3xf32> |
| 19 | + %14 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 20 | + %15 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 21 | + %16 = stablehlo.subtract %8, %15 : tensor<10x3xf32> |
| 22 | + %17 = stablehlo.slice %0 [0:10, 1:2, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32> |
| 23 | + %18 = stablehlo.transpose %17, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32> |
| 24 | + %19 = stablehlo.reshape %18 : (tensor<3x1x10xf32>) -> tensor<3x10xf32> |
| 25 | + %20 = stablehlo.transpose %19, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32> |
| 26 | + %21 = stablehlo.convert %20 : tensor<10x3xf32> |
| 27 | + %22 = stablehlo.broadcast_in_dim %21, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 28 | + %23 = stablehlo.broadcast_in_dim %22, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 29 | + %24 = stablehlo.slice %1 [0:10, 1:2, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32> |
| 30 | + %25 = stablehlo.transpose %24, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32> |
| 31 | + %26 = stablehlo.reshape %25 : (tensor<3x1x10xf32>) -> tensor<3x10xf32> |
| 32 | + %27 = stablehlo.transpose %26, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32> |
| 33 | + %28 = stablehlo.convert %27 : tensor<10x3xf32> |
| 34 | + %29 = stablehlo.broadcast_in_dim %28, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 35 | + %30 = stablehlo.broadcast_in_dim %29, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 36 | + %31 = stablehlo.subtract %23, %30 : tensor<10x3xf32> |
| 37 | + %32 = stablehlo.slice %0 [0:10, 2:3, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32> |
| 38 | + %33 = stablehlo.transpose %32, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32> |
| 39 | + %34 = stablehlo.reshape %33 : (tensor<3x1x10xf32>) -> tensor<3x10xf32> |
| 40 | + %35 = stablehlo.transpose %34, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32> |
| 41 | + %36 = stablehlo.convert %35 : tensor<10x3xf32> |
| 42 | + %37 = stablehlo.broadcast_in_dim %36, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 43 | + %38 = stablehlo.broadcast_in_dim %37, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 44 | + %39 = stablehlo.slice %1 [0:10, 2:3, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32> |
| 45 | + %40 = stablehlo.transpose %39, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32> |
| 46 | + %41 = stablehlo.reshape %40 : (tensor<3x1x10xf32>) -> tensor<3x10xf32> |
| 47 | + %42 = stablehlo.transpose %41, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32> |
| 48 | + %43 = stablehlo.convert %42 : tensor<10x3xf32> |
| 49 | + %44 = stablehlo.broadcast_in_dim %43, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 50 | + %45 = stablehlo.broadcast_in_dim %44, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 51 | + %46 = stablehlo.subtract %38, %45 : tensor<10x3xf32> |
| 52 | + %47 = stablehlo.slice %0 [0:10, 3:4, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32> |
| 53 | + %48 = stablehlo.transpose %47, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32> |
| 54 | + %49 = stablehlo.reshape %48 : (tensor<3x1x10xf32>) -> tensor<3x10xf32> |
| 55 | + %50 = stablehlo.transpose %49, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32> |
| 56 | + %51 = stablehlo.convert %50 : tensor<10x3xf32> |
| 57 | + %52 = stablehlo.broadcast_in_dim %51, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 58 | + %53 = stablehlo.broadcast_in_dim %52, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 59 | + %54 = stablehlo.slice %1 [0:10, 3:4, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32> |
| 60 | + %55 = stablehlo.transpose %54, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32> |
| 61 | + %56 = stablehlo.reshape %55 : (tensor<3x1x10xf32>) -> tensor<3x10xf32> |
| 62 | + %57 = stablehlo.transpose %56, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32> |
| 63 | + %58 = stablehlo.convert %57 : tensor<10x3xf32> |
| 64 | + %59 = stablehlo.broadcast_in_dim %58, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 65 | + %60 = stablehlo.broadcast_in_dim %59, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 66 | + %61 = stablehlo.subtract %53, %60 : tensor<10x3xf32> |
| 67 | + %62 = stablehlo.slice %0 [0:10, 4:5, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32> |
| 68 | + %63 = stablehlo.transpose %62, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32> |
| 69 | + %64 = stablehlo.reshape %63 : (tensor<3x1x10xf32>) -> tensor<3x10xf32> |
| 70 | + %65 = stablehlo.transpose %64, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32> |
| 71 | + %66 = stablehlo.convert %65 : tensor<10x3xf32> |
| 72 | + %67 = stablehlo.broadcast_in_dim %66, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 73 | + %68 = stablehlo.broadcast_in_dim %67, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 74 | + %69 = stablehlo.slice %1 [0:10, 4:5, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32> |
| 75 | + %70 = stablehlo.transpose %69, dims = [2, 1, 0] : (tensor<10x1x3xf32>) -> tensor<3x1x10xf32> |
| 76 | + %71 = stablehlo.reshape %70 : (tensor<3x1x10xf32>) -> tensor<3x10xf32> |
| 77 | + %72 = stablehlo.transpose %71, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32> |
| 78 | + %73 = stablehlo.convert %72 : tensor<10x3xf32> |
| 79 | + %74 = stablehlo.broadcast_in_dim %73, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 80 | + %75 = stablehlo.broadcast_in_dim %74, dims = [0, 1] : (tensor<10x3xf32>) -> tensor<10x3xf32> |
| 81 | + %76 = stablehlo.subtract %68, %75 : tensor<10x3xf32> |
| 82 | + %77 = stablehlo.broadcast_in_dim %16, dims = [2, 1] : (tensor<10x3xf32>) -> tensor<1x3x10xf32> |
| 83 | + %78 = stablehlo.broadcast_in_dim %31, dims = [2, 1] : (tensor<10x3xf32>) -> tensor<1x3x10xf32> |
| 84 | + %79 = stablehlo.broadcast_in_dim %46, dims = [2, 1] : (tensor<10x3xf32>) -> tensor<1x3x10xf32> |
| 85 | + %80 = stablehlo.broadcast_in_dim %61, dims = [2, 1] : (tensor<10x3xf32>) -> tensor<1x3x10xf32> |
| 86 | + %81 = stablehlo.broadcast_in_dim %76, dims = [2, 1] : (tensor<10x3xf32>) -> tensor<1x3x10xf32> |
| 87 | + %82 = stablehlo.concatenate %77, %78, %79, %80, %81, dim = 0 : (tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>) -> tensor<5x3x10xf32> |
| 88 | + // CHECK: stablehlo.concatenate |
| 89 | + %83 = stablehlo.transpose %0, dims = [2, 1, 0] : (tensor<10x5x3xf32>) -> tensor<3x5x10xf32> |
| 90 | + %84 = stablehlo.transpose %1, dims = [2, 1, 0] : (tensor<10x5x3xf32>) -> tensor<3x5x10xf32> |
| 91 | + return %82, %83, %84 : tensor<5x3x10xf32>, tensor<3x5x10xf32>, tensor<3x5x10xf32> |
| 92 | + } |
| 93 | +} |
0 commit comments