-
Notifications
You must be signed in to change notification settings - Fork 7
[transform] Basic example of applying a schedule to a payload #12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
2de74a3
3a51e04
19920ba
971e7ef
08ee4a1
e94f333
d037047
6f29ea9
40f97cc
abc7403
fe03285
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| import tempfile | ||
rengolin marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| import subprocess | ||
|
|
||
| from mlir import ir | ||
| from mlir.dialects import arith, func, linalg, tensor, transform | ||
| from mlir.dialects.transform import structured | ||
|
|
||
| from lighthouse import transform as lh_transform | ||
rolfmorel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def example_payload() -> ir.Module: | ||
| payload = ir.Module.create() | ||
| with ir.InsertionPoint(payload.body): | ||
rengolin marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| matrixType = ir.RankedTensorType.get([16, 16], ir.F32Type.get()) | ||
|
|
||
| @func.func(matrixType, matrixType) | ||
| def fold_add_on_two_matmuls(matrixA, matrixB): | ||
rolfmorel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| splat_float = ir.FloatAttr.get(ir.F32Type.get(), 1.111111) | ||
| splat_attr = ir.DenseElementsAttr.get_splat(matrixType, splat_float) | ||
| weights = arith.constant(matrixType, splat_attr) | ||
| c0 = arith.constant(ir.F32Type.get(), 0.0) | ||
| empty = tensor.empty(matrixType.shape, matrixType.element_type) | ||
| zero_init = linalg.fill(c0, outs=[empty]) | ||
| A_x_weights = linalg.matmul(matrixA, weights, outs=[zero_init]) | ||
| empty2 = tensor.empty(matrixType.shape, matrixType.element_type) | ||
| zero_init2 = linalg.fill(c0, outs=[empty2]) | ||
| B_x_weights = linalg.matmul(matrixB, weights, outs=[zero_init2]) | ||
| added = linalg.add(A_x_weights, B_x_weights, outs=[empty]) | ||
| return added | ||
|
|
||
| return payload | ||
|
|
||
|
|
||
| def example_schedule() -> ir.Module: | ||
| schedule = ir.Module.create() | ||
| schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() | ||
|
||
| with ir.InsertionPoint(schedule.body): | ||
| named_seq = transform.NamedSequenceOp( # TODO: fix snake_case wrapper upstream | ||
| sym_name="__transform_main", | ||
| input_types=[transform.any_op_t()], | ||
| result_types=[], | ||
| arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], | ||
| ) | ||
|
|
||
| with ir.InsertionPoint(named_seq.body): | ||
| func = structured.MatchOp.match_op_names( | ||
| named_seq.bodyTarget, ["func.func"] | ||
| ) # TODO: fix syntax upstream | ||
| with ir.InsertionPoint( | ||
| transform.apply_patterns(func).patterns.blocks.append() | ||
| ): # TODO: fix snake_case wrapper upstream | ||
| ir.Operation.create( | ||
| "transform.apply_patterns.linalg.fold_add_into_dest" | ||
| ) # TODO: expose dedicated builder upstream | ||
| transform.yield_([]) | ||
| return schedule | ||
|
|
||
|
|
||
| with ir.Context(), ir.Location.unknown(): | ||
| payload_module = example_payload() | ||
| print("NOTE: example payload module:") | ||
| print(payload_module) | ||
| schedule_module = example_schedule() | ||
| print("NOTE: example schedule module:") | ||
| print(schedule_module) | ||
|
|
||
| print("NOTE: output of applying schedule to payload directly within Python process:") | ||
| schedule = schedule_module.body.operations[0] | ||
| lh_transform.apply(schedule, payload_module) | ||
| print(payload_module) | ||
|
|
||
| with tempfile.NamedTemporaryFile( | ||
| "w", prefix="payload_" | ||
| ) as payload_file, tempfile.NamedTemporaryFile( | ||
| "w", prefix="schedule_" | ||
| ) as schedule_file: | ||
adam-smnk marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| print(payload_module, file=payload_file, flush=True) | ||
| print("NOTE: Have dumped payload to temp file:", payload_file.name) | ||
| print(schedule_module, file=schedule_file, flush=True) | ||
| print("NOTE: Have dumped schedule to temp file:", schedule_file.name) | ||
|
|
||
| cmdline = [ | ||
| "python", | ||
| "-m", | ||
| "lighthouse.transform", | ||
| schedule_file.name, | ||
| payload_file.name, | ||
| ] | ||
| print("NOTE: output of applying schedule to payload from commandline:", *cmdline) | ||
| subprocess.run(cmdline) | ||
adam-smnk marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| print( | ||
| f"NOTE: cleaning-up temp files: {payload_file.name}, {schedule_file.name}" | ||
| ) | ||
rengolin marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from .main import apply |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| import argparse | ||
| import sys | ||
|
|
||
| from mlir import ir | ||
|
|
||
| from .. import transform as lh_transform | ||
rolfmorel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| ArgParser = argparse.ArgumentParser() | ||
| ArgParser.add_argument("schedule") | ||
rolfmorel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ArgParser.add_argument("payload") | ||
| args = ArgParser.parse_args(sys.argv[1:]) | ||
|
|
||
| with ir.Context(), ir.Location.unknown(): | ||
| with open(args.schedule) as f: | ||
| schedule_module = ir.Module.parse(f.read()) | ||
| with open(args.payload) as f: | ||
| payload_module = ir.Module.parse(f.read()) | ||
|
|
||
| schedule = schedule_module.body.operations[0] | ||
| lh_transform.apply(schedule, payload_module) | ||
|
|
||
| print(payload_module) | ||
adam-smnk marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| from mlir import ir | ||
| from mlir.dialects.transform import interpreter as transform_interpreter | ||
|
|
||
|
|
||
| def apply(schedule: ir.Operation | ir.OpView, payload: ir.Module) -> None: | ||
| assert schedule.parent and "transform.with_named_sequence" in schedule.parent.attributes | ||
| assert "transform.with_named_sequence" in schedule.parent.attributes | ||
adam-smnk marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| assert isinstance(schedule.parent.attributes["transform.with_named_sequence"], ir.UnitAttr) | ||
|
|
||
| transform_interpreter.apply_named_sequence( | ||
| payload_root=payload, | ||
| transform_root=schedule, | ||
| transform_module=schedule.parent, | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.