|  | 
|  | 1 | +# RUN: python %s | FileCheck %s | 
|  | 2 | +from unittest import result | 
|  | 3 | +from mlir.ir import ( | 
|  | 4 | +    Context, | 
|  | 5 | +    FunctionType, | 
|  | 6 | +    Location, | 
|  | 7 | +    Module, | 
|  | 8 | +    InsertionPoint, | 
|  | 9 | +    IntegerType, | 
|  | 10 | +    IndexType, | 
|  | 11 | +    MemRefType, | 
|  | 12 | +    F32Type, | 
|  | 13 | +    Block, | 
|  | 14 | +    ArrayAttr, | 
|  | 15 | +    Attribute, | 
|  | 16 | +    UnitAttr, | 
|  | 17 | +    StringAttr, | 
|  | 18 | +    DenseI32ArrayAttr, | 
|  | 19 | +    ShapedType, | 
|  | 20 | +) | 
|  | 21 | +from mlir.dialects import openacc, func, arith, memref | 
|  | 22 | +from mlir.extras import types | 
|  | 23 | + | 
|  | 24 | + | 
|  | 25 | +def run(f): | 
|  | 26 | +    print("\n// TEST:", f.__name__) | 
|  | 27 | +    with Context(), Location.unknown(): | 
|  | 28 | +        f() | 
|  | 29 | +    return f | 
|  | 30 | + | 
|  | 31 | + | 
|  | 32 | +@run | 
|  | 33 | +def testParallelMemcpy(): | 
|  | 34 | +    module = Module.create() | 
|  | 35 | + | 
|  | 36 | +    dynamic = ShapedType.get_dynamic_size() | 
|  | 37 | +    memref_f32_1d_any = MemRefType.get([dynamic], types.f32()) | 
|  | 38 | + | 
|  | 39 | +    with InsertionPoint(module.body): | 
|  | 40 | +        function_type = FunctionType.get( | 
|  | 41 | +            [memref_f32_1d_any, memref_f32_1d_any, types.i64()], [] | 
|  | 42 | +        ) | 
|  | 43 | +        f = func.FuncOp( | 
|  | 44 | +            type=function_type, | 
|  | 45 | +            name="memcpy_idiom", | 
|  | 46 | +        ) | 
|  | 47 | +        f.attributes["sym_visibility"] = StringAttr.get("public") | 
|  | 48 | + | 
|  | 49 | +    with InsertionPoint(f.add_entry_block()): | 
|  | 50 | +        c1024 = arith.ConstantOp(types.i32(), 1024) | 
|  | 51 | +        c128 = arith.ConstantOp(types.i32(), 128) | 
|  | 52 | + | 
|  | 53 | +        arg0, arg1, arg2 = f.arguments | 
|  | 54 | + | 
|  | 55 | +        copied = openacc.copyin( | 
|  | 56 | +            acc_var=arg0.type, | 
|  | 57 | +            var=arg0, | 
|  | 58 | +            var_type=types.f32(), | 
|  | 59 | +            bounds=[], | 
|  | 60 | +            async_operands=[], | 
|  | 61 | +            implicit=False, | 
|  | 62 | +            structured=True, | 
|  | 63 | +        ) | 
|  | 64 | +        created = openacc.create_( | 
|  | 65 | +            acc_var=arg1.type, | 
|  | 66 | +            var=arg1, | 
|  | 67 | +            var_type=types.f32(), | 
|  | 68 | +            bounds=[], | 
|  | 69 | +            async_operands=[], | 
|  | 70 | +            implicit=False, | 
|  | 71 | +            structured=True, | 
|  | 72 | +        ) | 
|  | 73 | + | 
|  | 74 | +        parallel_op = openacc.ParallelOp( | 
|  | 75 | +            asyncOperands=[], | 
|  | 76 | +            waitOperands=[], | 
|  | 77 | +            numGangs=[c1024], | 
|  | 78 | +            numWorkers=[], | 
|  | 79 | +            vectorLength=[c128], | 
|  | 80 | +            reductionOperands=[], | 
|  | 81 | +            privateOperands=[], | 
|  | 82 | +            firstprivateOperands=[], | 
|  | 83 | +            dataClauseOperands=[], | 
|  | 84 | +        ) | 
|  | 85 | + | 
|  | 86 | +        # Set required device_type and segment attributes to satisfy verifier | 
|  | 87 | +        acc_device_none = ArrayAttr.get([Attribute.parse("#acc.device_type<none>")]) | 
|  | 88 | +        parallel_op.numGangsDeviceType = acc_device_none | 
|  | 89 | +        parallel_op.numGangsSegments = DenseI32ArrayAttr.get([1]) | 
|  | 90 | +        parallel_op.vectorLengthDeviceType = acc_device_none | 
|  | 91 | + | 
|  | 92 | +        parallel_block = Block.create_at_start(parent=parallel_op.region, arg_types=[]) | 
|  | 93 | + | 
|  | 94 | +        with InsertionPoint(parallel_block): | 
|  | 95 | +            c0 = arith.ConstantOp(types.i64(), 0) | 
|  | 96 | +            c1 = arith.ConstantOp(types.i64(), 1) | 
|  | 97 | + | 
|  | 98 | +            loop_op = openacc.LoopOp( | 
|  | 99 | +                results_=[], | 
|  | 100 | +                lowerbound=[c0], | 
|  | 101 | +                upperbound=[f.arguments[2]], | 
|  | 102 | +                step=[c1], | 
|  | 103 | +                gangOperands=[], | 
|  | 104 | +                workerNumOperands=[], | 
|  | 105 | +                vectorOperands=[], | 
|  | 106 | +                tileOperands=[], | 
|  | 107 | +                cacheOperands=[], | 
|  | 108 | +                privateOperands=[], | 
|  | 109 | +                reductionOperands=[], | 
|  | 110 | +                firstprivateOperands=[], | 
|  | 111 | +            ) | 
|  | 112 | + | 
|  | 113 | +            # Set loop attributes: gang and independent on device_type<none> | 
|  | 114 | +            acc_device_none = ArrayAttr.get([Attribute.parse("#acc.device_type<none>")]) | 
|  | 115 | +            loop_op.gang = acc_device_none | 
|  | 116 | +            loop_op.independent = acc_device_none | 
|  | 117 | + | 
|  | 118 | +            loop_block = Block.create_at_start( | 
|  | 119 | +                parent=loop_op.region, arg_types=[types.i64()] | 
|  | 120 | +            ) | 
|  | 121 | + | 
|  | 122 | +            with InsertionPoint(loop_block): | 
|  | 123 | +                idx = arith.index_cast(out=IndexType.get(), in_=loop_block.arguments[0]) | 
|  | 124 | +                val = memref.load(memref=copied, indices=[idx]) | 
|  | 125 | +                memref.store(value=val, memref=created, indices=[idx]) | 
|  | 126 | +                openacc.YieldOp([]) | 
|  | 127 | + | 
|  | 128 | +            openacc.YieldOp([]) | 
|  | 129 | + | 
|  | 130 | +        deleted = openacc.delete( | 
|  | 131 | +            acc_var=copied, | 
|  | 132 | +            bounds=[], | 
|  | 133 | +            async_operands=[], | 
|  | 134 | +            implicit=False, | 
|  | 135 | +            structured=True, | 
|  | 136 | +        ) | 
|  | 137 | +        copied = openacc.copyout( | 
|  | 138 | +            acc_var=created, | 
|  | 139 | +            var=arg1, | 
|  | 140 | +            var_type=types.f32(), | 
|  | 141 | +            bounds=[], | 
|  | 142 | +            async_operands=[], | 
|  | 143 | +            implicit=False, | 
|  | 144 | +            structured=True, | 
|  | 145 | +        ) | 
|  | 146 | +        func.ReturnOp([]) | 
|  | 147 | + | 
|  | 148 | +    print(module) | 
|  | 149 | + | 
|  | 150 | +    # CHECK: TEST: testParallelMemcpy | 
|  | 151 | +    # CHECK-LABEL:   func.func public @memcpy_idiom( | 
|  | 152 | +    # CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: i64) { | 
|  | 153 | +    # CHECK:           %[[CONSTANT_0:.*]] = arith.constant 1024 : i32 | 
|  | 154 | +    # CHECK:           %[[CONSTANT_1:.*]] = arith.constant 128 : i32 | 
|  | 155 | +    # CHECK:           %[[COPYIN_0:.*]] = acc.copyin varPtr(%[[ARG0]] : memref<?xf32>) -> memref<?xf32> | 
|  | 156 | +    # CHECK:           %[[CREATE_0:.*]] = acc.create varPtr(%[[ARG1]] : memref<?xf32>) -> memref<?xf32> | 
|  | 157 | +    # CHECK:           acc.parallel num_gangs({%[[CONSTANT_0]] : i32}) vector_length(%[[CONSTANT_1]] : i32) { | 
|  | 158 | +    # CHECK:             %[[CONSTANT_2:.*]] = arith.constant 0 : i64 | 
|  | 159 | +    # CHECK:             %[[CONSTANT_3:.*]] = arith.constant 1 : i64 | 
|  | 160 | +    # CHECK:             acc.loop gang control(%[[VAL_0:.*]] : i64) = (%[[CONSTANT_2]] : i64) to (%[[ARG2]] : i64)  step (%[[CONSTANT_3]] : i64) { | 
|  | 161 | +    # CHECK:               %[[INDEX_CAST_0:.*]] = arith.index_cast %[[VAL_0]] : i64 to index | 
|  | 162 | +    # CHECK:               %[[LOAD_0:.*]] = memref.load %[[COPYIN_0]]{{\[}}%[[INDEX_CAST_0]]] : memref<?xf32> | 
|  | 163 | +    # CHECK:               memref.store %[[LOAD_0]], %[[CREATE_0]]{{\[}}%[[INDEX_CAST_0]]] : memref<?xf32> | 
|  | 164 | +    # CHECK:               acc.yield | 
|  | 165 | +    # CHECK:             } attributes {independent = [#acc.device_type<none>]} | 
|  | 166 | +    # CHECK:             acc.yield | 
|  | 167 | +    # CHECK:           } | 
|  | 168 | +    # CHECK:           acc.delete accPtr(%[[COPYIN_0]] : memref<?xf32>) | 
|  | 169 | +    # CHECK:           acc.copyout accPtr(%[[CREATE_0]] : memref<?xf32>) to varPtr(%[[ARG1]] : memref<?xf32>) | 
|  | 170 | +    # CHECK:           return | 
|  | 171 | +    # CHECK:         } | 
0 commit comments