Skip to content

Commit 4cd5119

Browse files
authored
generics (#80)
1 parent aa08c7b commit 4cd5119

File tree

15 files changed

+814
-125
lines changed

15 files changed

+814
-125
lines changed

examples/cuda_matmul_opt.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from __future__ import annotations
2+
3+
import contextlib
4+
import math
5+
6+
import cupy as cp
7+
import mlir.extras.types as T
8+
import numpy as np
9+
from cupy.cuda import Module
10+
11+
from mlir.extras.ast.canonicalize import canonicalize
12+
from mlir.extras.context import (
13+
mlir_mod_ctx,
14+
MLIRContext,
15+
)
16+
from mlir.extras.dialects.ext import arith, memref, gpu, scf
17+
from mlir.extras.dialects.ext.gpu import (
18+
block_id,
19+
thread_id,
20+
block_dim,
21+
get_compile_object_bytes,
22+
)
23+
from mlir.extras.dialects.ext.scf import range_
24+
from mlir.extras.runtime.passes import Pipeline, run_pipeline
25+
26+
# noinspection PyUnresolvedReferences
27+
from mlir.extras.util import find_ops, enable_debug as enable_debug
28+
29+
# just so it doesn't get DCE'd by black/reformat
30+
_ = memref
31+
32+
33+
def build_cuda_func(compiled_module, kernel_name="mat_product_kernel"):
34+
ptx = get_compile_object_bytes(compiled_module)
35+
mod = Module()
36+
mod.load(ptx)
37+
return mod.get_function(kernel_name)
38+
39+
40+
@contextlib.contextmanager
41+
def time_cuda():
42+
start_gpu = cp.cuda.Event()
43+
end_gpu = cp.cuda.Event()
44+
45+
start_gpu.record()
46+
yield start_gpu, end_gpu
47+
end_gpu.record()
48+
end_gpu.synchronize()
49+
50+
51+
@gpu.func
52+
@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
53+
def mat_product_kernel[
54+
M, K, N, dtype
55+
](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
56+
x = block_dim.x * block_id.x + thread_id.x
57+
y = block_dim.y * block_id.y + thread_id.y
58+
59+
one = arith.constant(1.0, type=dtype)
60+
tmp = arith.constant(0, type=dtype)
61+
for k, tmp in range_(K, iter_args=[tmp]):
62+
tmp += A[x, k] * B[k, y]
63+
tmp = yield tmp
64+
C[x, y] = tmp + one
65+
66+
67+
def main(ctx: MLIRContext, M, K, N, BLOCK_SIZE=32, repeat_times=50):
68+
dtype = T.f32()
69+
npy_dtype = np.float32
70+
71+
gpu.set_container_module(ctx.module)
72+
73+
@gpu.module("naive", ["#nvvm.target"])
74+
def _():
75+
mat_product_kernel[M, K, N, dtype].emit()
76+
77+
# print(ctx.module)
78+
ctx.module.operation.verify()
79+
80+
compiled_module = run_pipeline(
81+
ctx.module,
82+
Pipeline().add_pass(
83+
"gpu-lower-to-nvvm-pipeline",
84+
# https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18
85+
**{
86+
"cubin-chip": "sm_80",
87+
"cubin-features": "+ptx83",
88+
"cubin-format": "isa",
89+
"kernel-bare-ptr-calling-convention": "1",
90+
# "cubin-format": "fatbin",
91+
# "cubin-format": "bin",
92+
},
93+
),
94+
)
95+
cuda_func = build_cuda_func(compiled_module)
96+
# print(compiled_module)
97+
# print_ptx(compiled_module)
98+
99+
A = np.random.randint(0, 10, (M, K)).astype(npy_dtype)
100+
B = np.random.randint(0, 10, (K, N)).astype(npy_dtype)
101+
C = np.zeros((M, N)).astype(npy_dtype)
102+
103+
dA = cp.asarray(A)
104+
dB = cp.asarray(B)
105+
dC = cp.asarray(C)
106+
107+
with time_cuda() as (start_gpu, end_gpu):
108+
for _ in range(repeat_times):
109+
cuda_func(
110+
(math.ceil(M / BLOCK_SIZE), math.ceil(N / BLOCK_SIZE), 1),
111+
(BLOCK_SIZE, BLOCK_SIZE, 1),
112+
(dA.data.ptr, dB.data.ptr, dC.data.ptr),
113+
)
114+
115+
t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu)
116+
117+
print(f"t_gpu={t_gpu / repeat_times:.6f} ms")
118+
119+
if not cp.array_equal(dC, dA @ dB + 1):
120+
print(dA @ dB + 1)
121+
print(dC)
122+
123+
124+
for s in [128, 256, 512, 1024]:
125+
with (
126+
mlir_mod_ctx() as ctx,
127+
# enable_debug()
128+
):
129+
main(ctx, s, s, s)

mlir/extras/ast/canonicalize.py

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@
44
import inspect
55
import logging
66
import types
7+
import warnings
78
from abc import ABC, abstractmethod
89
from dis import findlinestarts
910
from opcode import opmap
1011
from types import CodeType
11-
from typing import List, Union
12+
from typing import List, Union, Sequence
1213

1314
import astunparse
1415
from bytecode import ConcreteBytecode
1516

16-
from ..ast.util import get_module_cst, copy_func
17+
from ..ast.util import get_module_cst
1718

1819
logger = logging.getLogger(__name__)
1920

@@ -59,28 +60,63 @@ def transform_func(f, *transformer_ctors: type(Transformer)):
5960
return module
6061

6162

63+
# TODO(max): unify with `replace_closure` in ast/utils.py
64+
def insert_closed_vars(f, module):
65+
enclosing_mod = ast.FunctionDef(
66+
name="enclosing_mod",
67+
args=ast.arguments(
68+
posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]
69+
),
70+
body=[],
71+
decorator_list=[],
72+
)
73+
for var in f.__code__.co_freevars:
74+
enclosing_mod.body.append(
75+
ast.Assign(
76+
targets=[ast.Name(var, ctx=ast.Store())],
77+
value=ast.Constant(None, kind="None"),
78+
)
79+
)
80+
enclosing_mod.body.extend(module.body)
81+
module.body = [enclosing_mod]
82+
return module
83+
84+
85+
def find_func_in_code_object(co, func_name):
86+
for c in co.co_consts:
87+
if type(c) is CodeType:
88+
if c.co_name == func_name:
89+
return c
90+
else:
91+
f = find_func_in_code_object(c, func_name)
92+
if f is not None:
93+
return f
94+
95+
6296
def transform_ast(
6397
f, transformers: List[Union[type(Transformer), type(StrictTransformer)]] = None
6498
):
6599
if transformers is None:
66100
return f
67101

68102
module = transform_func(f, *transformers)
103+
if f.__closure__:
104+
module = insert_closed_vars(f, module)
69105
module = ast.fix_missing_locations(module)
70106
module = ast.increment_lineno(module, f.__code__.co_firstlineno - 1)
71107
module_code_o = compile(module, f.__code__.co_filename, "exec")
72-
new_f_code_o = next(
73-
c
74-
for c in module_code_o.co_consts
75-
if type(c) is CodeType and c.co_name == f.__name__
76-
)
108+
new_f_code_o = find_func_in_code_object(module_code_o, f.__name__)
77109
n_lines = len(inspect.getsource(f).splitlines())
78110
line_starts = list(findlinestarts(new_f_code_o))
79-
assert (
111+
if (
80112
max([l for _, l in line_starts]) - min([l for _, l in line_starts]) + 1
81-
<= n_lines
82-
), f"something went wrong with the line numbers for the rewritten/canonicalized function"
83-
return copy_func(f, new_f_code_o)
113+
> n_lines
114+
) or (f.__code__.co_firstlineno != min([l for _, l in line_starts])):
115+
warnings.warn(
116+
"something went wrong with the line numbers for the rewritten/canonicalized function"
117+
)
118+
f.__code__ = new_f_code_o
119+
return f
84120

85121

86122
# this is like this because i couldn't figure out how to subclass
@@ -117,7 +153,8 @@ def patch_bytecode(f, patchers: List[type(BytecodePatcher)] = None):
117153
for patcher in patchers:
118154
code = patcher(context).patch_bytecode(code, f)
119155

120-
return copy_func(f, code.to_code())
156+
f.__code__ = code.to_code()
157+
return f
121158

122159

123160
class Canonicalizer(ABC):
@@ -132,10 +169,18 @@ def bytecode_patchers(self) -> List[BytecodePatcher]:
132169
pass
133170

134171

135-
def canonicalize(*, using: Canonicalizer):
172+
def canonicalize(*, using: Union[Canonicalizer, Sequence[Canonicalizer]]):
173+
if not isinstance(using, Sequence):
174+
using = [using]
175+
cst_transformers = []
176+
bytecode_patchers = []
177+
for u in using:
178+
cst_transformers.extend(u.cst_transformers)
179+
bytecode_patchers.extend(u.bytecode_patchers)
180+
136181
def wrapper(f):
137-
f = transform_ast(f, using.cst_transformers)
138-
f = patch_bytecode(f, using.bytecode_patchers)
182+
f = transform_ast(f, cst_transformers)
183+
f = patch_bytecode(f, bytecode_patchers)
139184
return f
140185

141186
return wrapper

mlir/extras/ast/util.py

Lines changed: 97 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,13 @@
22
import functools
33
import inspect
44
import types
5+
from itertools import dropwhile
6+
from opcode import opmap
57
from textwrap import dedent
8+
from typing import Dict
9+
10+
from bytecode import ConcreteBytecode
11+
from cloudpickle import cloudpickle
612

713

814
def set_lineno(node, n=1):
@@ -26,8 +32,8 @@ def ast_call(name, args=None, keywords=None):
2632

2733

2834
def get_module_cst(f):
29-
f_src = dedent(inspect.getsource(f))
30-
# tree = cst.parse_module(f_src)
35+
lines, _lnum = inspect.getsourcelines(f)
36+
f_src = dedent("".join(list(dropwhile(lambda l: l.startswith("@"), lines))))
3137
tree = ast.parse(f_src)
3238
assert isinstance(
3339
tree.body[0], ast.FunctionDef
@@ -43,21 +49,89 @@ def bind(func, instance, as_name=None):
4349
return bound_method
4450

4551

46-
def copy_func(f, new_code):
47-
"""Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)"""
52+
# based on https://github.com/python/cpython/blob/6078f2033ea15a16cf52fe8d644a95a3be72d2e3/Tools/build/deepfreeze.py#L48
53+
def get_localsplus_name_to_idx(code: types.CodeType):
54+
localsplus = code.co_varnames + code.co_cellvars + code.co_freevars
55+
return localsplus, {v: i for i, v in enumerate(localsplus)}
56+
57+
58+
class _empty_cell_value:
59+
"""Sentinel for empty closures."""
60+
61+
@classmethod
62+
def __reduce__(cls):
63+
return cls.__name__
64+
65+
66+
_empty_cell_value = _empty_cell_value()
67+
68+
69+
# based on https://github.com/cloudpipe/cloudpickle/blob/f111f7ab6d302e9b1e2a568d0e4c574895db6a6e/cloudpickle/cloudpickle.py#L513
70+
def make_empty_cell():
71+
if False:
72+
# trick the compiler into creating an empty cell in our lambda
73+
cell = None
74+
raise AssertionError("this route should not be executed")
75+
76+
return (lambda: cell).__closure__[0]
77+
78+
79+
def make_cell(value=_empty_cell_value):
80+
cell = make_empty_cell()
81+
if value is not _empty_cell_value:
82+
cell.cell_contents = value
83+
return cell
84+
85+
86+
# based on https://github.com/python/cpython/blob/a4b44d39cd6941cc03590fee7538776728bdfd0a/Lib/test/test_code.py#L197
87+
def replace_closure(code, new_closure: Dict):
88+
COPY_FREE_VARS = opmap["COPY_FREE_VARS"]
89+
LOAD_DEREF = opmap["LOAD_DEREF"]
90+
91+
# get the orig localplus that will be loaded from by the orig bytecode LOAD_DEREF arg_i
92+
localsplus, localsplus_name_to_idx = get_localsplus_name_to_idx(code)
93+
94+
# closure vars go into co_freevars
95+
new_code = code.replace(co_freevars=tuple(new_closure.keys()))
96+
# closure is a tuple of cells
97+
closure = tuple(
98+
make_cell(v) if not isinstance(v, types.CellType) else v
99+
for v in new_closure.values()
100+
)
101+
102+
new_code = ConcreteBytecode.from_code(new_code)
103+
# update how many closure vars are loaded from frame
104+
# see https://github.com/python/cpython/blob/6078f2033ea15a16cf52fe8d644a95a3be72d2e3/Python/bytecodes.c#L1571
105+
assert new_code[0].opcode == COPY_FREE_VARS
106+
new_code[0].arg = len(closure)
107+
108+
# map orig localsplus arg_i to new localplus position/arg_i
109+
new_localsplus = new_code.varnames + new_code.cellvars + new_code.freevars
110+
new_localsplus_name_to_idx = {v: i for i, v in enumerate(new_localsplus)}
111+
for c in new_code:
112+
if c.opcode == LOAD_DEREF and c.arg < len(localsplus):
113+
c.arg = new_localsplus_name_to_idx[localsplus[c.arg]]
114+
new_code = new_code.to_code()
115+
116+
return new_code, closure
117+
118+
119+
# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard);
120+
# potentially more complete approach https://stackoverflow.com/a/56901529/9045206
121+
def copy_func(f, new_closure: Dict = None):
122+
if new_closure is not None:
123+
code, closure = replace_closure(f.__code__, new_closure)
124+
else:
125+
code, closure = f.__code__, f.__closure__
126+
48127
g = types.FunctionType(
49-
code=new_code,
50-
globals={
51-
**f.__globals__,
52-
**{
53-
fr: f.__closure__[i].cell_contents
54-
for i, fr in enumerate(f.__code__.co_freevars)
55-
},
56-
},
128+
code=code,
129+
globals=f.__globals__,
57130
name=f.__name__,
58131
argdefs=f.__defaults__,
59-
# TODO(max): ValueError: foo requires closure of length 0, not 1
60-
# closure=f.__closure__,
132+
# see https://github.com/cloudpipe/cloudpickle/blob/f111f7ab6d302e9b1e2a568d0e4c574895db6a6e/cloudpickle/cloudpickle.py#L813
133+
# for how this trick is accomplished (dill and pickle both fail to pickle eg generic typevars)
134+
closure=cloudpickle.loads(cloudpickle.dumps(closure)),
61135
)
62136
g.__kwdefaults__ = f.__kwdefaults__
63137
g.__dict__.update(f.__dict__)
@@ -66,3 +140,12 @@ def copy_func(f, new_code):
66140
if inspect.ismethod(f):
67141
g = bind(g, f.__self__)
68142
return g
143+
144+
145+
def append_hidden_node(node_body, new_node):
146+
last_statement = node_body[-1]
147+
new_node = ast.fix_missing_locations(
148+
set_lineno(new_node, last_statement.end_lineno)
149+
)
150+
node_body.append(new_node)
151+
return node_body

0 commit comments

Comments
 (0)