|
| 1 | +# Simply demonstrates applying a schedule to a payload. |
| 2 | +# To do so generates a basic payload and a basic schedule, purely as an example. |
| 3 | + |
| 4 | +from mlir.ir import Context, Location, InsertionPoint, Operation, Module |
| 5 | +from mlir.ir import RankedTensorType, F32Type, UnitAttr |
| 6 | +from mlir.dialects import arith, func, linalg, tensor, transform |
| 7 | +from mlir.dialects.transform import structured |
| 8 | + |
| 9 | + |
| 10 | +def example_payload() -> Module: |
| 11 | + """IR for: |
| 12 | + Zero = ... |
| 13 | + X = matmul(..., C=Zero) |
| 14 | + Y = matmul(..., C=Zero) |
| 15 | + Res = add(X, Y) |
| 16 | +
|
| 17 | + Can be re-written to: |
| 18 | + X = matmul(..., C=Zero) |
| 19 | + Res = matmul(..., C=X) |
| 20 | + """ |
| 21 | + |
| 22 | + print("NOTE: example payload module:") |
| 23 | + payload = Module.create() |
| 24 | + with InsertionPoint(payload.body): |
| 25 | + matrixType = RankedTensorType.get([16, 16], F32Type.get()) |
| 26 | + |
| 27 | + @func.func(matrixType, matrixType, matrixType) |
| 28 | + def fold_add_on_two_matmuls(matrixA, matrixB, weights): |
| 29 | + empty = tensor.empty(matrixType.shape, matrixType.element_type) |
| 30 | + c0 = arith.constant(F32Type.get(), 0.0) |
| 31 | + zero_init = linalg.fill(c0, outs=[empty]) |
| 32 | + A_x_weights = linalg.matmul(matrixA, weights, outs=[zero_init]) |
| 33 | + empty2 = tensor.empty(matrixType.shape, matrixType.element_type) |
| 34 | + zero_init2 = linalg.fill(c0, outs=[empty2]) |
| 35 | + B_x_weights = linalg.matmul(matrixB, weights, outs=[zero_init2]) |
| 36 | + added = linalg.add(A_x_weights, B_x_weights, outs=[empty]) |
| 37 | + return added |
| 38 | + |
| 39 | + print(payload) |
| 40 | + return payload |
| 41 | + |
| 42 | + |
| 43 | +def example_schedule() -> Module: |
| 44 | + """Basic schedule wrapping a single rewrite pattern.""" |
| 45 | + |
| 46 | + print("NOTE: example schedule module:") |
| 47 | + schedule_module = Module.create() |
| 48 | + schedule_module.operation.attributes["transform.with_named_sequence"] = ( |
| 49 | + UnitAttr.get() |
| 50 | + ) |
| 51 | + with InsertionPoint(schedule_module.body): |
| 52 | + named_seq = transform.named_sequence( |
| 53 | + "__transform_main", |
| 54 | + input_types=[transform.any_op_t()], |
| 55 | + result_types=[], |
| 56 | + arg_attrs=[{"transform.readonly": UnitAttr.get()}], |
| 57 | + ) |
| 58 | + |
| 59 | + with InsertionPoint(named_seq.body): |
| 60 | + func = structured.MatchOp.match_op_names( |
| 61 | + named_seq.bodyTarget, ["func.func"] |
| 62 | + ) # TODO: fix syntax upstream |
| 63 | + with InsertionPoint(transform.apply_patterns(func).patterns): |
| 64 | + Operation.create( |
| 65 | + "transform.apply_patterns.linalg.fold_add_into_dest" |
| 66 | + ) # TODO: expose dedicated builder upstream |
| 67 | + transform.yield_([]) |
| 68 | + |
| 69 | + print(schedule_module) |
| 70 | + return schedule_module |
| 71 | + |
| 72 | + |
| 73 | +with Context(), Location.unknown(): |
| 74 | + payload = example_payload() |
| 75 | + schedule_module = example_schedule() |
| 76 | + # Actual schedule is defined by the contained transform.named_sequence: |
| 77 | + schedule: transform.NamedSequenceOp = schedule_module.body.operations[0] |
| 78 | + |
| 79 | + schedule.apply(payload) # The actual transformation happens here. |
| 80 | + |
| 81 | + print("NOTE: result of applying schedule to payload:") |
| 82 | + print(payload) |
0 commit comments