Skip to content

Commit 74e2bab

Browse files
authored
cuda_opt_v4_tensor_cores (#87)
1 parent 8b813f2 commit 74e2bab

File tree

9 files changed

+876
-125
lines changed

9 files changed

+876
-125
lines changed

examples/cuda_matmul_opt.py

Lines changed: 333 additions & 74 deletions
Large diffs are not rendered by default.

mlir/extras/ast/canonicalize.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,15 @@
44
import inspect
55
import logging
66
import types
7-
import warnings
87
from abc import ABC, abstractmethod
98
from dis import findlinestarts
109
from opcode import opmap
11-
from types import CodeType
1210
from typing import List, Union, Sequence
1311

1412
import astunparse
1513
from bytecode import ConcreteBytecode
1614

17-
from ..ast.util import get_module_cst, set_lineno
15+
from ..ast.util import get_module_cst, set_lineno, find_func_in_code_object
1816

1917
logger = logging.getLogger(__name__)
2018

@@ -86,17 +84,6 @@ def insert_closed_vars(f, module):
8684
return module
8785

8886

89-
def find_func_in_code_object(co, func_name):
90-
for c in co.co_consts:
91-
if type(c) is CodeType:
92-
if c.co_name == func_name:
93-
return c
94-
else:
95-
f = find_func_in_code_object(c, func_name)
96-
if f is not None:
97-
return f
98-
99-
10087
def transform_ast(
10188
f, transformers: List[Union[type(Transformer), type(StrictTransformer)]] = None
10289
):

mlir/extras/ast/util.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
import functools
33
import inspect
44
import types
5-
from itertools import dropwhile
65
from opcode import opmap
76
from textwrap import dedent
8-
from typing import Dict
7+
from typing import Dict, TypeVar
98

109
from bytecode import ConcreteBytecode
1110
from cloudpickle import cloudpickle
11+
from einspect import ptr
12+
from einspect.structs import PyVarObject, PyObject
1213

1314

1415
def set_lineno(node, n=1):
@@ -101,7 +102,7 @@ def replace_closure(code, new_closure: Dict):
101102
new_code = ConcreteBytecode.from_code(new_code)
102103
# update how many closure vars are loaded from frame
103104
# see https://github.com/python/cpython/blob/6078f2033ea15a16cf52fe8d644a95a3be72d2e3/Python/bytecodes.c#L1571
104-
assert new_code[0].opcode == COPY_FREE_VARS
105+
assert new_code[0].opcode == COPY_FREE_VARS, f"{new_code[0].opcode=}"
105106
new_code[0].arg = len(closure)
106107

107108
# map orig localsplus arg_i to new localplus position/arg_i
@@ -121,16 +122,17 @@ def copy_func(f, new_closure: Dict = None):
121122
if new_closure is not None:
122123
code, closure = replace_closure(f.__code__, new_closure)
123124
else:
124-
code, closure = f.__code__, f.__closure__
125+
# see https://github.com/cloudpipe/cloudpickle/blob/f111f7ab6d302e9b1e2a568d0e4c574895db6a6e/cloudpickle/cloudpickle.py#L813
126+
# for how this trick is accomplished (dill and pickle both fail to pickle eg generic typevars)
127+
closure = cloudpickle.loads(cloudpickle.dumps(f.__closure__))
128+
code = f.__code__
125129

126130
g = types.FunctionType(
127131
code=code,
128132
globals=f.__globals__,
129133
name=f.__name__,
130134
argdefs=f.__defaults__,
131-
# see https://github.com/cloudpipe/cloudpickle/blob/f111f7ab6d302e9b1e2a568d0e4c574895db6a6e/cloudpickle/cloudpickle.py#L813
132-
# for how this trick is accomplished (dill and pickle both fail to pickle eg generic typevars)
133-
closure=cloudpickle.loads(cloudpickle.dumps(closure)),
135+
closure=closure,
134136
)
135137
g.__kwdefaults__ = f.__kwdefaults__
136138
g.__dict__.update(f.__dict__)
@@ -151,3 +153,24 @@ def append_hidden_node(node_body, new_node):
151153
)
152154
node_body.append(new_node)
153155
return node_body
156+
157+
158+
def find_func_in_code_object(co, func_name):
159+
for c in co.co_consts:
160+
if type(c) is types.CodeType:
161+
if c.co_name == func_name:
162+
return c
163+
else:
164+
f = find_func_in_code_object(c, func_name)
165+
if f is not None:
166+
return f
167+
168+
169+
_T = TypeVar("_T")
170+
171+
172+
# https://github.com/python/cpython/blob/809aa9a682fc865f7502e7421da0a74d204aab6d/Objects/typevarobject.c#L29
173+
class PyTypeVarObject(PyVarObject[_T, None, None]):
174+
name: ptr[PyObject]
175+
# not sure why but this is the only thing that works but that's fine because it's the only thing we need
176+
bound: ptr[PyObject]

mlir/extras/dialects/ext/func.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import copy
12
import inspect
23
import sys
34
from dataclasses import dataclass
45
from functools import update_wrapper
56
from typing import Optional, List, Union, TypeVar
67

7-
from ...ast.util import copy_func
8+
from ...ast.util import copy_func, PyTypeVarObject
89
from ...meta import op_region_builder
10+
from ... import types as T
911
from ...util import get_user_code_loc, make_maybe_no_args_decorator
1012
from ....dialects._ods_common import get_op_result_or_op_results
1113
from ....dialects.func import *
@@ -105,17 +107,17 @@ def prep_func_types(sig, return_types):
105107
return_types = [return_types]
106108
return_types = list(return_types)
107109
assert all(
108-
isinstance(r, Type) for r in return_types
109-
), f"all return types must be mlir types {return_types=}"
110+
isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in return_types
111+
), f"all return types must be mlir types or strings or TypeVars or lambdas {return_types=}"
110112

111113
input_types = [
112114
p.annotation
113115
for p in sig.parameters.values()
114116
if not p.annotation is inspect.Signature.empty
115117
]
116118
assert all(
117-
isinstance(r, (str, Type)) or isalambda(r) for r in input_types
118-
), f"all input types must be mlir types {input_types=}"
119+
isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in input_types
120+
), f"all input types must be mlir types or strings or TypeVars or lambdas {input_types=}"
119121
user_loc = get_user_code_loc()
120122
# If ir.Context is none (like for deferred func emit)
121123
if user_loc is None:
@@ -205,13 +207,15 @@ def emit(self, *call_args, decl=False, force=False) -> FuncOp:
205207
if self._func_op is None or decl or force:
206208
if len(call_args) == 0:
207209
input_types = self.input_types[:]
208-
locals = {}
210+
locals = {"T": T}
209211
if self.generics is not None:
210212
for t in self.generics:
211213
if not isinstance(t, ReifiedTypeParams):
212214
raise RuntimeError(f"{t=} must reified")
213215
locals[t.name] = t.val
214216
for i, v in enumerate(input_types):
217+
if isinstance(v, TypeVar):
218+
v = v.__name__
215219
if isinstance(v, str):
216220
input_types[i] = Type(
217221
eval(v, self.body_builder.__globals__, locals)
@@ -274,12 +278,38 @@ def __getitem__(self, item):
274278
# this also copies the function so that the original body_builder remains "generic" (via its closure)
275279
body_builder = copy_func(self.body_builder)
276280
reified_type_params = []
277-
for i, t in enumerate(self.generics):
278-
if t.__bound__ is not None:
279-
r = ReifiedTypeParams(t.__name__, t.__bound__)
281+
# dumb but whatever
282+
already_reified_type_params = {}
283+
generics = copy.deepcopy(self.generics)
284+
for i, t in enumerate(generics):
285+
if sys.version_info >= (3, 12):
286+
type_var_bound = PyTypeVarObject.try_from(t).bound
287+
else:
288+
type_var_bound = t.__bound__
289+
if type_var_bound:
290+
# before 3.12 typevar was just a python class
291+
# https://github.com/python/cpython/blob/3.11/Lib/typing.py#L966
292+
if sys.version_info < (3, 12):
293+
type_var_bound = lambda: type_var_bound
294+
else:
295+
type_var_bound = type_var_bound.contents.into_object()
296+
cvrs = inspect.getclosurevars(type_var_bound).nonlocals
297+
if len(cvrs):
298+
for k, v in cvrs.items():
299+
if not isinstance(v, TypeVar):
300+
continue
301+
if k not in already_reified_type_params:
302+
raise RuntimeError(
303+
f"typevar {k} not reified prior to evaluating dependent typevar {t}"
304+
)
305+
cvrs[k] = already_reified_type_params[k]
306+
type_var_bound = copy_func(type_var_bound, cvrs)
307+
r = ReifiedTypeParams(t.__name__, type_var_bound())
280308
else:
281309
r = ReifiedTypeParams(t.__name__, item[i])
310+
282311
reified_type_params.append(r)
312+
already_reified_type_params[r.name] = r.val
283313

284314
if t.__name__ in body_builder.__globals__:
285315
body_builder.__globals__[t.__name__] = r.val
@@ -290,8 +320,6 @@ def __getitem__(self, item):
290320
), "typevars don't match"
291321
body_builder.__closure__[free_i].cell_contents = r.val
292322

293-
generics = reified_type_params
294-
295323
return FuncBase(
296324
body_builder,
297325
self.func_op_ctor,
@@ -302,7 +330,7 @@ def __getitem__(self, item):
302330
arg_attrs=self.arg_attrs,
303331
res_attrs=self.res_attrs,
304332
func_attrs=self.func_attrs,
305-
generics=generics,
333+
generics=reified_type_params,
306334
qualname=self.qualname,
307335
loc=self.loc,
308336
ip=self.ip,

mlir/extras/dialects/ext/gpu.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from functools import partial
33
from typing import Any, List, Optional, Tuple, Union
44

5-
from mlir.dialects._gpu_enum_gen import AddressSpace
65

76
from .arith import constant
87
from .func import FuncBase
@@ -119,8 +118,8 @@ def get_device_mapping_array_attr(
119118
return ArrayAttr.get(mapping, context=context)
120119

121120

122-
def gpu_attr(mnemonic, mapping_id_enum: MappingId):
123-
return Attribute.parse(f"#gpu.{mnemonic}<{mapping_id_enum}>")
121+
def gpu_attr(mnemonic, attr_value):
122+
return Attribute.parse(f"#gpu.{mnemonic}<{attr_value}>")
124123

125124

126125
def thread_attr(thread):

0 commit comments

Comments
 (0)