Skip to content

Commit 8b813f2

Browse files
authored
finish cuda opt (#86)
1 parent 22dc84f commit 8b813f2

File tree

11 files changed

+826
-87
lines changed

11 files changed

+826
-87
lines changed

examples/cuda_matmul_opt.py

Lines changed: 650 additions & 35 deletions
Large diffs are not rendered by default.

examples/mlir_python_extras.ipynb

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
"from mlir.extras.dialects.ext.arith import constant\n",
5757
"from mlir.extras.dialects.ext.memref import S\n",
5858
"from mlir.extras.dialects.ext.func import func\n",
59-
"from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_ as range\n",
59+
"from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_\n",
6060
"from mlir.extras.runtime.passes import Pipeline, run_pipeline\n",
6161
"from mlir.extras.runtime.refbackend import LLVMJITBackend\n",
6262
"from mlir.ir import StridedLayoutAttr\n",
@@ -102,8 +102,8 @@
102102
" if one > two:\n",
103103
" C[0, 0] = constant(3, T.i64())\n",
104104
" else:\n",
105-
" for i in range(0, K):\n",
106-
" for j in range(0, K):\n",
105+
" for i in range_(0, K):\n",
106+
" for j in range_(0, K):\n",
107107
" C[i, j] = A[i, j] * B[i, j]"
108108
]
109109
},
@@ -457,17 +457,17 @@
457457
"def tile(\n",
458458
" A: ranked_memref_dxd_f32, B: ranked_memref_dxd_f32, C: ranked_memref_dxd_f32\n",
459459
"):\n",
460-
" for i in range(0, D):\n",
461-
" for j in range(0, D):\n",
460+
" for i in range_(0, D):\n",
461+
" for j in range_(0, D):\n",
462462
" C[i, j] = A[i, j] + B[i, j]\n",
463463
"\n",
464464
"@func(emit=True)\n",
465465
"@canonicalize(using=scf)\n",
466466
"def tiled_memfoo(\n",
467467
" A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32\n",
468468
"):\n",
469-
" for i in range(0, F):\n",
470-
" for j in range(0, F):\n",
469+
" for i in range_(0, F):\n",
470+
" for j in range_(0, F):\n",
471471
" l = lambda l: l * D\n",
472472
" r = lambda r: (r + 1) * D\n",
473473
" a, b, c = (\n",
@@ -797,8 +797,8 @@
797797
"def linalg_memfoo(\n",
798798
" A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32\n",
799799
"):\n",
800-
" for i in range(0, F):\n",
801-
" for j in range(0, F):\n",
800+
" for i in range_(0, F):\n",
801+
" for j in range_(0, F):\n",
802802
" l = lambda l: l * D\n",
803803
" r = lambda r: (r + 1) * D\n",
804804
" a, b, c = (\n",

mlir/extras/ast/canonicalize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def transform_ast(
116116
max([l for _, l in line_starts]) - min([l for _, l in line_starts]) + 1
117117
> n_lines
118118
) or (f.__code__.co_firstlineno != min([l for _, l in line_starts])):
119-
warnings.warn(
119+
logger.debug(
120120
"something went wrong with the line numbers for the rewritten/canonicalized function"
121121
)
122122
f.__code__ = new_f_code_o

mlir/extras/ast/util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ def copy_func(f, new_closure: Dict = None):
143143

144144
def append_hidden_node(node_body, new_node):
145145
last_statement = node_body[-1]
146+
assert (
147+
last_statement.end_lineno is not None
148+
), f"last_statement {ast.unparse(last_statement)} must have end_lineno"
146149
new_node = ast.fix_missing_locations(
147150
set_lineno(new_node, last_statement.end_lineno)
148151
)

mlir/extras/dialects/ext/arith.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import copy
23
import operator
34
from abc import abstractmethod
45
from copy import deepcopy
@@ -513,17 +514,20 @@ def visit_AugAssign(
513514
and isinstance(updated_node.value, ast.BinOp)
514515
and isinstance(updated_node.value.op, ast.Mult)
515516
):
517+
target = copy.deepcopy(updated_node.target)
518+
target.ctx = ast.Load()
516519
updated_node = ast.Assign(
517520
targets=[updated_node.target],
518521
value=ast_call(
519522
_FMA_BUILDER_NAME,
520523
[
521524
updated_node.value.left,
522525
updated_node.value.right,
523-
ast.Name(updated_node.target.id, ast.Load()),
526+
target,
524527
],
525528
),
526529
)
530+
updated_node = ast.fix_missing_locations(updated_node)
527531

528532
return updated_node
529533

mlir/extras/dialects/ext/gpu.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from functools import partial
33
from typing import Any, List, Optional, Tuple, Union
44

5+
from mlir.dialects._gpu_enum_gen import AddressSpace
6+
57
from .arith import constant
68
from .func import FuncBase
79
from ... import types as T
@@ -117,32 +119,39 @@ def get_device_mapping_array_attr(
117119
return ArrayAttr.get(mapping, context=context)
118120

119121

120-
def device_mapping_attr(mnemonic, mapping_id_enum: MappingId):
122+
def gpu_attr(mnemonic, mapping_id_enum: MappingId):
121123
return Attribute.parse(f"#gpu.{mnemonic}<{mapping_id_enum}>")
122124

123125

124126
def thread_attr(thread):
125-
return device_mapping_attr("thread", thread)
127+
return gpu_attr("thread", thread)
126128

127129

128130
def block_attr(block):
129-
return device_mapping_attr("block", block)
131+
return gpu_attr("block", block)
130132

131133

132134
def warp_attr(warp):
133-
return device_mapping_attr("warp", warp)
135+
return gpu_attr("warp", warp)
134136

135137

136138
def warpgroup_attr(warpgroup):
137-
return device_mapping_attr("warpgroup", warpgroup)
139+
return gpu_attr("warpgroup", warpgroup)
138140

139141

140142
def address_space_attr(address_space: AddressSpace):
141-
return device_mapping_attr("address_space", address_space)
143+
return gpu_attr("address_space", address_space)
144+
145+
146+
_int = int
147+
142148

149+
def smem_space(int=False):
150+
a = AddressSpace.Workgroup
151+
if int:
152+
return _int(a)
143153

144-
def smem_space():
145-
return address_space_attr(AddressSpace.Workgroup)
154+
return address_space_attr(a)
146155

147156

148157
@_cext.register_operation(_Dialect, replace=True)
@@ -577,13 +586,29 @@ def printf(format, *args):
577586
_dynamic_shared_memory = dynamic_shared_memory
578587

579588

580-
def dynamic_shared_memory(*, loc=None, ip=None):
589+
def dynamic_shared_memory(*, int=False, loc=None, ip=None):
581590
return _dynamic_shared_memory(
582591
T.memref(
583592
ShapedType.get_dynamic_size(),
584593
element_type=T.i8(),
585-
memory_space=smem_space(),
594+
memory_space=smem_space(int),
586595
),
587596
loc=loc,
588597
ip=ip,
589598
)
599+
600+
601+
_memset = memset
602+
603+
604+
def memset(dst, value, async_dependencies=None, *, loc=None, ip=None):
605+
if loc is None:
606+
loc = get_user_code_loc()
607+
if async_dependencies is None:
608+
async_dependencies = []
609+
async_token = None
610+
if len(async_dependencies):
611+
async_token = gpu_async_token()
612+
if isinstance(value, (int, float, bool)):
613+
value = constant(value, type=dst.type.element_type)
614+
return _memset(async_token, async_dependencies, dst, value, loc=loc, ip=ip)

mlir/extras/dialects/ext/llvm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55

66
def llvm_ptr_t():
77
return Type.parse("!llvm.ptr")
8+

mlir/extras/dialects/ext/memref.py

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def _alloc(
3333
sizes: Sequence[Union[int, Value]],
3434
element_type: Type,
3535
memory_space=None,
36+
alignment=None,
3637
loc=None,
3738
ip=None,
3839
):
@@ -52,21 +53,56 @@ def _alloc(
5253

5354
symbol_operands = []
5455
return get_op_result_or_op_results(
55-
op_ctor(result_type, dynamic_sizes, symbol_operands, loc=loc, ip=ip)
56+
op_ctor(
57+
result_type,
58+
dynamic_sizes,
59+
symbol_operands,
60+
alignment=alignment,
61+
loc=loc,
62+
ip=ip,
63+
)
5664
)
5765

5866

59-
def alloc(sizes: Union[int, Value], element_type: Type = None, memory_space=None):
60-
loc = get_user_code_loc()
67+
def alloc(
68+
sizes: Union[int, Value],
69+
element_type: Type = None,
70+
memory_space=None,
71+
alignment=None,
72+
loc=None,
73+
ip=None,
74+
):
75+
if loc is None:
76+
loc = get_user_code_loc()
6177
return _alloc(
62-
AllocOp, sizes, element_type, memory_space=memory_space, loc=loc, ip=None
78+
AllocOp,
79+
sizes,
80+
element_type,
81+
memory_space=memory_space,
82+
alignment=alignment,
83+
loc=loc,
84+
ip=ip,
6385
)
6486

6587

66-
def alloca(sizes: Union[int, Value], element_type: Type = None, memory_space=None):
67-
loc = get_user_code_loc()
88+
def alloca(
89+
sizes: Union[int, Value],
90+
element_type: Type = None,
91+
memory_space=None,
92+
alignment=None,
93+
loc=None,
94+
ip=None,
95+
):
96+
if loc is None:
97+
loc = get_user_code_loc()
6898
return _alloc(
69-
AllocaOp, sizes, element_type, memory_space=memory_space, loc=loc, ip=None
99+
AllocaOp,
100+
sizes,
101+
element_type,
102+
memory_space=memory_space,
103+
alignment=alignment,
104+
loc=loc,
105+
ip=ip,
70106
)
71107

72108

@@ -113,8 +149,9 @@ def __getitem__(self, idx: tuple) -> "MemRef":
113149
if idx is None:
114150
return expand_shape(self, (0,), loc=loc)
115151

116-
idx = list((idx,) if isinstance(idx, (int, slice)) else idx)
152+
idx = list((idx,) if isinstance(idx, (int, Scalar, slice)) else idx)
117153
for i, d in enumerate(idx):
154+
# TODO(max): rethink this, since subview etc. probably take constant attributes?
118155
if isinstance(d, int):
119156
idx[i] = constant(d, index=True, loc=loc)
120157

@@ -123,7 +160,7 @@ def __getitem__(self, idx: tuple) -> "MemRef":
123160
else:
124161
return _subview(self, tuple(idx), loc=loc)
125162

126-
def __setitem__(self, idx, source):
163+
def __setitem__(self, idx, val):
127164
loc = get_user_code_loc()
128165

129166
if not self.has_rank():
@@ -135,12 +172,10 @@ def __setitem__(self, idx, source):
135172
idx[i] = constant(d, index=True, loc=loc)
136173

137174
if all(isinstance(d, Scalar) for d in idx) and len(idx) == len(self.shape):
138-
assert isinstance(
139-
source, Scalar
140-
), "coordinate insert requires scalar element"
141-
store(source, self, idx, loc=loc)
175+
assert isinstance(val, Scalar), "coordinate insert requires scalar element"
176+
store(val, self, idx, loc=loc)
142177
else:
143-
_copy_to_subview(self, source, tuple(idx), loc=loc)
178+
_copy_to_subview(self, val, tuple(idx), loc=loc)
144179

145180

146181
def expand_shape(

mlir/extras/dialects/ext/scf.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
from contextlib import contextmanager
44
from copy import deepcopy
5-
from typing import List
5+
from typing import List, Union, Optional, Sequence
66

77
from bytecode import ConcreteBytecode
88

@@ -18,6 +18,7 @@
1818
get_op_result_or_op_results,
1919
)
2020
from ....dialects.linalg.opdsl.lang.emitter import _is_index_type
21+
2122
# gotta come first
2223
from ....dialects.scf import *
2324
from ....dialects.scf import _Dialect, yield_ as yield__
@@ -432,13 +433,18 @@ def visit_If(self, updated_node: ast.If) -> ast.If:
432433
updated_node.orelse, deepcopy(new_yield)
433434
)
434435

436+
updated_node = ast.fix_missing_locations(updated_node)
435437
return updated_node
436438

437439
def visit_For(self, updated_node: ast.For) -> ast.For:
438-
updated_node = self.generic_visit(updated_node)
439-
new_yield = ast.Expr(ast.Yield(value=None))
440-
if not is_yield(updated_node.body[-1]):
441-
updated_node.body = append_hidden_node(updated_node.body, new_yield)
440+
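# TODO(max): this isn't robust at all: it substring-matches "range_"/"for_" in the dumped AST of the loop iterator's callee, and assumes the iterator is a call.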
441+
line = ast.dump(updated_node.iter.func)
442+
if "range_" in line or "for_" in line:
443+
updated_node = self.generic_visit(updated_node)
444+
new_yield = ast.Expr(ast.Yield(value=None))
445+
if not is_yield(updated_node.body[-1]):
446+
updated_node.body = append_hidden_node(updated_node.body, new_yield)
447+
updated_node = ast.fix_missing_locations(updated_node)
442448
return updated_node
443449

444450

@@ -480,6 +486,7 @@ def visit_If(self, updated_node: ast.If) -> ast.If:
480486

481487
if needs_forward(updated_node.orelse):
482488
updated_node.orelse = forward_yield_from_nested_if(updated_node.orelse)
489+
updated_node = ast.fix_missing_locations(updated_node)
483490
return updated_node
484491

485492

@@ -515,6 +522,10 @@ def visit_While(self, updated_node: ast.While) -> List[ast.AST]:
515522
)
516523
new_test = ast.copy_location(new_test, updated_node)
517524
updated_node.test = new_test
525+
526+
updated_node = ast.fix_missing_locations(updated_node)
527+
assign = ast.fix_missing_locations(assign)
528+
518529
return [assign, updated_node]
519530

520531

mlir/extras/dialects/ext/tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def insert_slice(
102102
)
103103

104104

105+
# TODO(max): unify vector/memref/tensor
105106
@register_value_caster(RankedTensorType.static_typeid)
106107
class Tensor(ShapedValue, ArithValue):
107108
def __getitem__(self, idx: tuple) -> "Tensor":

0 commit comments

Comments
 (0)