diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 24e0878..158cfdb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -75,6 +75,7 @@ jobs: python examples/mwe.py python examples/flash_attention.py + python examples/liveness_analysis.py test-other-host-bindings: diff --git a/examples/liveness_analysis.py b/examples/liveness_analysis.py new file mode 100644 index 0000000..96a6d36 --- /dev/null +++ b/examples/liveness_analysis.py @@ -0,0 +1,189 @@ +from mlir import ir +from pathlib import Path + +import mlir.extras.types as T +import numpy as np +from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr + +from mlir.extras.ast.canonicalize import canonicalize +from mlir.extras.context import RAIIMLIRContextModule +from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm +from mlir.dialects import math + +# noinspection PyUnresolvedReferences +from mlir.extras.dialects.ext.gpu import ( + block_idx, + thread_idx, + grid_dim, + func as gpu_func, + set_container_module, + module, + get_compile_object_bytes, +) +from mlir.extras.runtime.passes import run_pipeline, Pipeline +from mlir.extras.util import find_ops, walk_blocks_in_operation, walk_operations +from mlir.extras.util.liveness import ( + BlockInfoBuilder, + Liveness, + LiveInterval, + linear_scan_register_allocation, +) + +# just so it doesn't get DCE'd by black/reformat +# TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable +_ = memref + +ctx = RAIIMLIRContextModule() +set_container_module(ctx.module) + + +# just a default attr - actual target is set blow +@module("kernels", [f'#rocdl.target']) +def gpu_module(): + pass + + +ip = InsertionPoint.at_block_begin(gpu_module.regions[0].blocks[0]) +ip.__enter__() + +Bc = 32 +Br = 32 + +B = 16 +nh = 12 +N = 128 +d = 128 + +softmax_scale = 1.0 / float(np.sqrt(d)) + + +rank_reduce = memref.rank_reduce + + +# https://github.com/tspeterkim/flash-attention-minimal/blob/main/flash.cu +@gpu_func(emit=True) +@canonicalize(using=[scf.canonicalizer, arith.canonicalizer]) +def flash_attention( + Q: T.memref(B, nh, N, d, T.f32()), + K: T.memref(B, nh, N, d, T.f32()), + V: T.memref(B, nh, N, d, T.f32()), + l: T.memref(B, nh, N, T.f32()), + m: T.memref(B, nh, N, T.f32()), + O: T.memref(B, nh, N, d, T.f32()), +): + tx = thread_idx.x + # batch idx, head_idx + bx, by = block_idx.x, block_idx.y + # gpu.printf("bx %ld, by %ld\n", bx, by) + + # Offset into Q,K,V,O,l,m - different for each batch and head + K = K[bx, by, :, :, rank_reduce] + V = V[bx, by, :, :, rank_reduce] + Q = Q[bx, by, :, :, rank_reduce] + O = O[bx, by, :, :, rank_reduce] + l = l[bx, by, :, rank_reduce] + m = m[bx, by, :, rank_reduce] + + # Define SRAM for Q,K,V,S + sram = gpu.dynamic_shared_memory() + Qi = memref.view(sram, (Br, d), dtype=T.f32()) + Kj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements) + Vj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements) + S = memref.view( + sram, + (Br, Bc), + dtype=T.f32(), + shift=Qi.n_elements + Kj.n_elements + Vj.n_elements, + ) + + for bc in scf.range_(0, N, Bc): + # Load Kj, Vj to SRAM + K_ = K[bc : bc + 1, :] + V_ = V[bc : bc + 1, :] + for x in scf.range_(0, d): + Kj[tx, x] = K_[tx, x] + Vj[tx, x] = V_[tx, x] + + for br in scf.range_(0, N, Br): + # Load Qi to SRAM, l and m to registers + Q_ = Q[br : br + 1, :] + for x in scf.range_(0, d): + Qi[tx, x] = Q_[tx, x] + + l_ = l[br : br + 1] + m_ = m[br : br + 1] + row_l_prev = l_[tx] + row_m_prev = m_[tx] + + # S = QK^T, row_m = rowmax(S) + row_m: T.f32() = float(np.finfo(np.float32).min) + for y, row_m, _ in scf.range_(0, Bc, iter_args=[row_m]): + sum: T.f32() = 0.0 + for x, sum, _ in scf.range_(0, d, iter_args=[sum]): + sum += Qi[tx, x] * Kj[y, x] + sum = yield sum + + sum *= softmax_scale + S[tx, y] = sum + + if sum > row_m: + row_m_ = yield sum + else: + row_m_ = yield row_m + + row_m = yield row_m_ + + # P = exp(S - row_m), row_l = rowsum(P) + row_l: T.f32() = 0.0 + for y, row_l, _ in scf.range_(0, Bc, iter_args=[row_l]): + S[tx, y] = math.exp(S[tx, y] - row_m) + row_l += S[tx, y] + row_l = yield row_l + + # Compute new m and l + row_m_new = arith.maximumf(row_m_prev, row_m) + row_l_new = ( + math.exp(row_m_prev - row_m_new) * row_l_prev + + math.exp(row_m - row_m_new) * row_l + ) + div = 1.0 / row_l_new + f1 = row_l_prev * math.exp(row_m_prev - row_m_new) + f2 = math.exp(row_m - row_m_new) + + # Write O, l, m to HBM + O_ = O[br : br + 1, :] + for x in scf.range_(0, d): + pv: T.f32() = 0.0 # Pij * Vj + for y, pv, _ in scf.range_(0, Bc, iter_args=[pv]): + pv += S[tx, y] * Vj[y, x] + pv = yield pv + + O_[tx, x] = div * (f1 * O_[tx, x] + f2 * pv) + + l_[tx] = row_l_new + m_[tx] = row_m_new + + gpu.barrier() + + +ip.__exit__(None, None, None) + +assert gpu_module.operation.verify() +# l = Liveness(gpu_module) +# print(l) + + +# https://langdev.stackexchange.com/questions/4325/how-do-modern-compilers-choose-which-variables-to-put-in-registers +x = LiveInterval(1, 3, "x") +t1 = LiveInterval(1, 2, "t1") +y = LiveInterval(2, 5, "y") +z = LiveInterval(3, 4, "z") +t2 = LiveInterval(4, 5, "t2") +y2 = LiveInterval(5, 6, "y2") + +register, location = linear_scan_register_allocation([x, t1, y, z, t2, y2], 2) + +for v, r in register.items(): + print(v, r) +for v, l in location.items(): + print(v, l) diff --git a/mlir/extras/util/__init__.py b/mlir/extras/util/__init__.py new file mode 100644 index 0000000..027aa23 --- /dev/null +++ b/mlir/extras/util/__init__.py @@ -0,0 +1,7 @@ +from .util import * +from .util import ( + _get_previous_frame_idents, + _get_sym_name, + _update_caller_vars, + _unpack_sizes_element_type, +) diff --git a/mlir/extras/util/liveness.py b/mlir/extras/util/liveness.py new file mode 100644 index 0000000..00a6248 --- /dev/null +++ b/mlir/extras/util/liveness.py @@ -0,0 +1,423 @@ +from collections import OrderedDict +from dataclasses import dataclass + +from sortedcontainers import SortedList + +from ...ir import ( + Block, + Value, + Operation, + OperationList, + BlockArgument, + OperationIterator, +) +from .util import ( + walk_blocks_in_operation, + walk_operations, + find_ancestor_block_in_region, + find_ancestor_op_in_block, +) + +# based on https://github.com/llvm/llvm-project/blob/07ae19c132e1b0adbdb3cc036b9f50624e2ed1b7/mlir/lib/Analysis/Liveness.cpp + + +def escapes_block(v: Value, b: Block): + # Check if value escapes, i.e., if there's a use + # which is in a block that is not "our" block + for use in v.uses: + user = use.owner + use_block = user.operation.block + # Find an owner block in the current region. Note that a value does not + # escape this block if it is used in a nested region. + use_block = find_ancestor_block_in_region(use_block) + if use_block != b: + return True + return False + + +class BlockInfoBuilder: + block = None + in_values = None + out_values = None + def_values = None + use_values = None + + def __init__(self, block): + self.block = block + self.in_values = set() + self.out_values = set() + self.def_values = set() + self.use_values = set() + + # Mark all block arguments (phis) as defined. + for arg in block.arguments: + self.def_values.add(arg) + # how the fuck can block args escape a block? + # answer: + # func.func @test(%arg0: i32, %arg1: i16) -> i16 { + # cf.br ^bb1(%arg1 : i16) + # ^bb1(%0: i16): // pred: ^bb0 + # cf.br ^bb2(%arg0 : i32) + # ^bb2(%1: i32): // pred: ^bb1 + # return %0 : i16 + # } + if escapes_block(arg, self.block): + self.out_values.add(arg) + + # Gather out values of all operations in the current block. + for op in block.operations: + for r in op.results: + if escapes_block(r, self.block): + self.out_values.add(r) + + # Mark all nested operation results as defined, and nested operation + # operands as used. All defined value will be removed from the used set + # at the end. + for op in block.operations: + for nested_op in walk_operations(op): + self.def_values |= set(nested_op.results) + self.use_values |= set(nested_op.operands) + for b in walk_blocks_in_operation(nested_op): + self.def_values |= set(b.arguments) + + self.use_values -= self.def_values + + # newIn = use U out - def + def update_livein(self): + new_in = (self.use_values | self.out_values) - self.def_values + # It is sufficient to check the set sizes (instead of their contents) since + # the live-in set can only grow monotonically during all update operations. + if len(new_in) == len(self.in_values): + return set() + self.in_values = new_in + return new_in + + # Updates live-out information of the current block. It iterates over all + # successors and unifies their live-in values with the current live-out + # values. + def update_liveout(self, builders): + for succ in self.block.successors: + self.out_values |= builders[succ].in_values + + +def build_block_mapping(op) -> dict[Block, BlockInfoBuilder]: + to_process = OrderedDict() + builders = {} + for b in walk_blocks_in_operation(op): + builder = builders[b] = BlockInfoBuilder(b) + if not builder.update_livein(): + continue + for p in b.predecessors: + if p in to_process: + continue + to_process[p] = True + + # Propagate the in and out-value sets (fixpoint iteration). + while to_process: + # Pairs are returned in LIFO order if last is true or FIFO order if false. + current, _ = to_process.popitem(last=True) + builder = builders[current] + builder.update_liveout(builders) + if not builder.update_livein(): + continue + for p in current.predecessors: + if p not in to_process: + to_process[p] = True + + return builders + + +class LivenessBlockInfo: + block: Block = None + in_values: set[Value] = None + out_values: set[Value] = None + + def __init__(self, block, in_values, out_values): + self.block = block + self.in_values = in_values + self.out_values = out_values + + def is_livein(self, v: Value): + return v in self.in_values + + def is_liveout(self, v: Value): + return v in self.out_values + + def get_start_operation(self, v: Value): + # The given value is either live-in or is defined + # in the scope of this block. + if self.is_livein(v) or isinstance(v, BlockArgument): + return self.block.operations[0] + return v.owner + + def get_end_operation(self, v: Value, start_op: Operation): + # The given value is either dying in this block or live-out. + if self.is_liveout(v): + return self.block.operations[-1] + # Resolve the last operation (must exist by definition). + end_op = start_op + for use in v.uses: + # Find the associated operation in the current block (if any). + # Check whether the use is in our block and after the current end + # operation. + if ( + use_op := find_ancestor_op_in_block(self.block, use.owner) + ) and end_op.is_before_in_block(use_op): + end_op = use_op + return end_op + + def currently_live_values(self, op: Operation): + live_set = set() + + # Given a value, check which ops are within its live range. For each of + # those ops, add the value to the set of live values as-of that op. + def add_value_to_currently_live_sets(value): + # Determine the live range of this value inside this block. + end_of_live_range = None + # If it's a livein or a block argument, then the start is the beginning + # of the block. + if self.is_livein(value) or isinstance(value, BlockArgument): + start_of_live_range = self.block.operations[0] + else: + start_of_live_range = find_ancestor_op_in_block(self.block, value.owner) + + # If it's a liveout, then the end is the back of the block. + if self.is_liveout(value): + end_of_live_range = self.block.operations[-1] + + # We must have at least a startOfLiveRange at this point. Given this, we + # can use the existing getEndOperation to find the end of the live range. + if start_of_live_range is not None and end_of_live_range is None: + end_of_live_range = self.get_end_operation(value, start_of_live_range) + + assert end_of_live_range, "Must have end_of_live_range at this point!" + # If this op is within the live range, insert the value into the set. + if not ( + op.is_before_in_block(start_of_live_range) + or end_of_live_range.is_before_in_block(op) + ): + live_set.add(value) + + for arg in self.block.arguments: + add_value_to_currently_live_sets(arg) + + # Handle live-ins. Between the live ins and all the op results that gives us + # every value in the block. + for value in self.in_values: + add_value_to_currently_live_sets(value) + + # Now walk the block and handle all values used in the block and values + # defined by the block. + for bop in self.block.operations: + for r in bop.results: + add_value_to_currently_live_sets(r) + if bop == op: + break + + return live_set + + +class Liveness: + operation = None + block_mapping = None + + def __init__(self, op): + self.operation = op + self.block_mapping = {} + + builders = build_block_mapping(self.operation) + for block, builder in builders.items(): + assert block == builder.block + self.block_mapping[block] = LivenessBlockInfo( + builder.block, builder.in_values, builder.out_values + ) + + def resolve_liveness(self, value: Value) -> OperationList: + result = [] + to_process = OrderedDict() + + # Start with the defining block + if isinstance(value, BlockArgument): + current_block = value.owner + else: + current_block = value.owner.operation.block + to_process[current_block] = True + + # Start with all associated blocks + for use in value.uses: + user = use.owner + use_block = user.operation.block + if use_block not in to_process: + to_process[use_block] = True + + while to_process: + current_block, _ = to_process.popitem(last=True) + block_info = self.block_mapping[current_block] + # Note that start and end will be in the same block. + start = block_info.get_start_operation(value) + end = block_info.get_end_operation(value, start) + + for op in OperationIterator(start.parent, start): + if start == end: + break + result.append(op) + for succ in current_block.successors: + if self.get_liveness(succ).is_livein(value) and succ not in to_process: + to_process[succ] = True + + return result + + def get_liveness(self, block: Block) -> LivenessBlockInfo | None: + if liveness := self.block_mapping.get(block): + return liveness + + def get_livein(self, block: Block) -> set[Value] | None: + if liveness := self.get_liveness(block): + return liveness.in_values + + def get_liveout(self, block: Block) -> set[Value] | None: + if liveness := self.get_liveness(block): + return liveness.out_values + + def is_dead_after(self, value: Value, op: Operation) -> bool: + block_info = self.get_liveness(op.operation.block) + if block_info.is_liveout(value): + return False + end_op = block_info.get_end_operation(value, op) + # If the operation is a real user of `value` the first check is sufficient. + # If not, we will have to test whether the end operation is executed before + # the given operation in the block. + return end_op == op or end_op.is_before_in_block(op) + + def __str__(self): + print("// ---- Liveness -----") + + # Builds unique block/value mappings for testing purposes. + block_ids: dict[Block, int] = {} + operation_ids: dict[Operation, int] = {} + value_ids: dict[Value, int] = {} + for block in walk_blocks_in_operation(self.operation): + block_ids[block] = len(block_ids) + for argument in block.arguments: + value_ids[argument] = len(value_ids) + for operation in block.operations: + operation_ids[operation] = len(operation_ids) + for result in operation.results: + value_ids[result] = len(value_ids) + + # Local printing helpers + def print_value_ref(value): + if isinstance(value, BlockArgument): + print(f"arg{value.arg_number}@{block_ids[value.owner]}", end=" ") + else: + print(f"val_{value_ids[value]}", end=" ") + + def print_value_refs(values: set[Value]): + ordered_values = sorted(list(values), key=lambda v: value_ids[v]) + for value in ordered_values: + print_value_ref(value) + + # Dump information about in and out values. + for block in walk_blocks_in_operation(self.operation): + print(f"// - Block: {block_ids[block]}") + liveness = self.get_liveness(block) + print("// --- LiveIn: ", end="") + print_value_refs(liveness.in_values) + print("\n// --- LiveOut: ", end="") + print_value_refs(liveness.out_values) + print() + + # Print liveness intervals. + print("// --- BeginLivenessIntervals", end="") + for op in block.operations: + if not op.results: + continue + print() + for result in op.results: + print("// ", end="") + print_value_ref(result) + print(":", end="") + live_operations = sorted( + list(self.resolve_liveness(result)), + key=lambda v: operation_ids[v], + ) + for operation in live_operations: + print("\n// ", end="") + print(operation) + print("\n// --- EndLivenessIntervals") + + # Print currently live values. + print("// --- BeginCurrentlyLive") + for op in block.operations: + currently_live = liveness.currently_live_values(op) + if not currently_live: + continue + print("// ", end="") + print(op) + print(" [", end="") + print_value_refs(currently_live) + print("\b]") + print("// --- EndCurrentlyLive") + + print("// -------------------") + + +@dataclass(frozen=True) +class LiveInterval: + start: int + end: int + name: str + + def __str__(self): + return f"{self.name}@[{self.start},{self.end}]" + + def __repr__(self): + return f"{self.name}@[{self.start},{self.end}]" + + +def linear_scan_register_allocation(intervals: list[LiveInterval], R: int): + active: list[LiveInterval] = SortedList(key=lambda i: i.end) + free_registers = set(range(R)) + register = {} + location = {} + + # expire intervals whose lifetimes have ended + # (remove from active and return the register to the free list) + def expire_old_intervals(i: LiveInterval): + for j in active: + if j.end > i.start: + return + # j.end < i.start + active.remove(j) + free_registers.add(register[j]) + + # spill either last ending + # or this interval + def spill_at_interval(i: LiveInterval): + spill = active[-1] + # if last ending ends after this interval + if spill.end > i.end: + # give its register to this interval + register[i] = register[spill] + # stack slot + location[spill] = len(location) + assert spill in active, "expected spill in active" + active.remove(spill) + active.add(i) + else: + # else this interval ends later so + # spill it (give it a stack slot) + location[i] = len(location) + + # sorted by start + intervals = sorted(intervals, key=lambda i: i.start) + for i in intervals: + expire_old_intervals(i) + # if max registers reached, spill + if len(active) == R: + spill_at_interval(i) + else: + register[i] = free_registers.pop() + active.add(i) + + return register, location \ No newline at end of file diff --git a/mlir/extras/util.py b/mlir/extras/util/util.py similarity index 90% rename from mlir/extras/util.py rename to mlir/extras/util/util.py index ba8a72f..eafbd6a 100644 --- a/mlir/extras/util.py +++ b/mlir/extras/util/util.py @@ -12,9 +12,9 @@ import numpy as np -from .meta import op_region_builder -from ..extras import types as T -from ..ir import ( +from ..meta import op_region_builder +from ...extras import types as T +from ...ir import ( Block, Context, F32Type, @@ -35,7 +35,7 @@ ) try: - from ..ir import TypeID + from ...ir import TypeID except ImportError: warnings.warn( f"TypeID not supported by host bindings; value casting won't work correctly" @@ -48,7 +48,7 @@ def is_relative_to(self, other): def get_user_code_loc(user_base: Optional[Path] = None): - from .. import extras + from ... import extras if Context.current is None: return @@ -105,23 +105,61 @@ def shlib_prefix(): return shlib_pref +def walk_blocks(block, pred=None): + if pred is None: + pred = lambda b: True + for op in block.operations: + for r in op.regions: + for b in r.blocks: + if pred(b): + yield b + yield from walk_blocks(b, pred) + + +def walk_blocks_in_operation(op, pred=None): + if pred is None: + pred = lambda b: True + for r in op.regions: + for b in r.blocks: + yield from walk_blocks(b, pred) + + +def walk_operations(op, pred=None): + if pred is None: + pred = lambda o: True + for r in op.regions: + for b in r.blocks: + for o in b.operations: + if pred(o): + yield o + yield from walk_operations(o, pred) + + +def find_ancestor_block_in_region(block: Block): + curr_block = block + while curr_block and curr_block.region != block.region: + parent_op = curr_block.owner + if not parent_op or not parent_op.operation.block: + return None + curr_block = parent_op.operation.block + return curr_block + + +def find_ancestor_op_in_block(block: Block, op: Operation): + curr_op = op + while curr_op and curr_op.operation.block != block: + curr_op = curr_op.parent + if not curr_op.parent: + return None + return curr_op + + def find_ops(op, pred: Callable[[OpView, Operation, Module], bool], single=False): if isinstance(op, (OpView, Module)): op = op.operation - matching = [] - - def find(op: Operation): - if single and len(matching): - return - for r in op.regions: - for b in r.blocks: - for o in b.operations: - if pred(o): - matching.append(o) - find(o) + matching = list(walk_operations(op, pred)) - find(op) if single and matching: matching = matching[0] return matching diff --git a/requirements-dev.txt b/requirements-dev.txt index 6ca87e7..f4f8d67 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,2 +1,3 @@ black inflection +pytest diff --git a/requirements.txt b/requirements.txt index 4098766..76a31e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ bytecode cloudpickle>=3.0.0 einspect @ git+https://github.com/makslevental/einspect@makslevental/bump-py.3.13 numpy>=1.19.5, <=2.1.2 +sortedcontainers==2.4.0 \ No newline at end of file