|
1 | 1 | # RUN: %PYTHON %s | FileCheck %s |
2 | 2 |
|
3 | | -from mlir.dialects import arith, func, linalg, tensor, memref |
| 3 | +from mlir.dialects import arith, func, linalg, tensor, memref, builtin |
4 | 4 | from mlir.dialects.linalg.opdsl.lang import * |
| 5 | +from mlir.extras import types as T |
5 | 6 | from mlir.ir import * |
6 | 7 |
|
7 | 8 |
|
@@ -857,3 +858,76 @@ def elementwise_op( |
857 | 858 | ) |
858 | 859 |
|
859 | 860 | print(module) |
| 861 | + |
| 862 | + |
@run
def testReduceOp():
    """Build and print a module exercising the `linalg.reduce` decorator form.

    Constructs `func @reduce_op(tensor<10xf32>) -> f32` that multiplies all
    elements together: a splat-of-1.0 rank-0 tensor is the init value, the
    reduction runs over dimension 0, and the scalar is pulled back out with
    `tensor.extract`. The printed module is checked by the FileCheck lines
    that follow this function.
    """
    with Context(), Location.unknown():
        f32 = T.f32()
        tensor_type = T.tensor(10, f32)

        @builtin.module
        def module():
            @func.func(tensor_type)
            def reduce_op(input):
                c1 = arith.constant(f32, 1.0)
                # Fix: `from mlir.ir import *` binds RankedTensorType and
                # DenseI64ArrayAttr directly; the module name `ir` is never
                # imported in this file, so the previous `ir.RankedTensorType`
                # / `ir.DenseI64ArrayAttr` raised NameError at run time.
                single_result = RankedTensorType.get((), f32)
                dims = DenseI64ArrayAttr.get([0])
                # Rank-0 init tensor holding the multiplicative identity.
                init = tensor.splat(single_result, c1, [])

                @linalg.reduce(
                    result=[single_result],
                    inputs=[input],
                    inits=[init],
                    dimensions=dims,
                )
                def reduced(element: f32, acc: f32):
                    # Combiner region: running product of the elements.
                    return arith.mulf(acc, element)

                return tensor.extract(reduced, [])

        # Verify before printing, consistent with testMapOp.
        module.verify()
        print(module)
| 890 | + |
| 891 | + |
| 892 | +# CHECK-LABEL: func.func @reduce_op( |
| 893 | +# CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32>) -> f32 { |
| 894 | +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.000000e+00 : f32 |
| 895 | +# CHECK: %[[SPLAT_0:.*]] = tensor.splat %[[CONSTANT_0]] : tensor<f32> |
| 896 | +# CHECK: %[[REDUCE_0:.*]] = linalg.reduce { arith.mulf } ins(%[[ARG0]] : tensor<10xf32>) outs(%[[SPLAT_0]] : tensor<f32>) dimensions = [0] |
| 897 | +# CHECK: %[[EXTRACT_0:.*]] = tensor.extract %[[REDUCE_0]][] : tensor<f32> |
| 898 | +# CHECK: return %[[EXTRACT_0]] : f32 |
| 899 | +# CHECK: } |
| 900 | + |
| 901 | + |
@run
def testMapOp():
    """Build and print a module exercising the `linalg.map` decorator form.

    Constructs `func @map_op(tensor<10xf32>) -> tensor<10xf32>` that adds the
    argument to itself elementwise, writing into a `tensor.empty` destination
    (destination-passing style). The printed module is checked by the
    FileCheck lines that follow this function.
    """
    with Context(), Location.unknown():
        elem_ty = T.f32()
        vec10 = T.tensor(10, elem_ty)

        @builtin.module
        def module():
            @func.func(vec10)
            def map_op(input):
                # Destination tensor the mapped values are written into.
                out_buf = tensor.empty(vec10.shape, elem_ty)

                @linalg.map(
                    result=[vec10],
                    inputs=[input, input],
                    init=out_buf,
                )
                def body(lhs: elem_ty, rhs: elem_ty, out: elem_ty):
                    # Elementwise payload: sum of the two input operands.
                    return arith.addf(lhs, rhs)

                return body

        module.verify()
        print(module)
| 926 | + |
| 927 | + |
| 928 | +# CHECK-LABEL: func.func @map_op( |
| 929 | +# CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32>) -> tensor<10xf32> { |
| 930 | +# CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<10xf32> |
| 931 | +# CHECK: %[[MAP_0:.*]] = linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG0]] : tensor<10xf32>, tensor<10xf32>) outs(%[[EMPTY_0]] : tensor<10xf32>) |
| 932 | +# CHECK: return %[[MAP_0]] : tensor<10xf32> |
| 933 | +# CHECK: } |
0 commit comments