Skip to content

Commit efc51fa

Browse files
authored
minor fix (#62)
1 parent 45229fd commit efc51fa

File tree

2 files changed

+6
-17
lines changed

2 files changed

+6
-17
lines changed

mlir/extras/dialects/ext/tensor.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@
2626
S = ShapedType.get_dynamic_size()
2727

2828

29-
def empty(
30-
*sizes: Sequence[Union[int, Value]], element_type: Type = None, loc=None, ip=None
31-
):
29+
def empty(*sizes: Union[int, Value], element_type: Type = None, loc=None, ip=None):
3230
if loc is None:
3331
loc = get_user_code_loc()
3432
if element_type is None:
@@ -608,12 +606,7 @@ def _insert_slice(
608606
):
609607
if loc is None:
610608
loc = get_user_code_loc()
611-
612-
if isinstance(source, Scalar):
613-
source = expand_dims(source, (0,), loc=loc, ip=ip)
614-
615609
indexer = _indices_to_indexer(idx, dest.shape)
616-
617610
if indexer.is_constant():
618611
assert (
619612
indexer.static_sizes() == source.shape

mlir/extras/dialects/ext/transform.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,20 @@
88
from ...util import get_user_code_loc
99
from ....dialects import pdl
1010
from ....dialects import transform
11-
from ....dialects._ods_common import (
12-
_dispatch_mixed_values,
13-
)
14-
from ....dialects._ods_common import get_op_result_or_op_results
11+
from ....dialects._ods_common import _dispatch_mixed_values, get_op_result_or_op_results
1512
from ....dialects._structured_transform_ops_gen import (
1613
TileUsingForallOp,
1714
MatchOp,
1815
)
1916
from ....dialects.transform import *
17+
from ....dialects.transform import AnyOpType, AnyValueType, OperationType
2018
from ....dialects.transform.structured import TileUsingForOp
19+
from ....dialects.transform.loop import LoopUnrollOp
2120
from ....ir import Type, Operation, StringAttr, Attribute, Value
2221

2322
transform_fully_qualified_name = transform.__spec__.name
2423

2524

26-
def create_simple_namespace(name):
27-
return SimpleNamespace(__name__=name)
28-
29-
3025
# transform.apply_patterns is both a namespace and an op...
3126
delattr(transform, "apply_patterns")
3227

@@ -57,7 +52,8 @@ def create_simple_namespace(name):
5752

5853
for i, n in enumerate(namespaces[1:-1]):
5954
if not hasattr(simple_namespace, n):
60-
# dumb: without the prefix, this somehow always names the modules "mlir.dialect.module.transform.<n>" instead of suffixing
55+
# dumb: without the prefix, this somehow always names the modules
56+
# "mlir.dialect.module.transform.<n>" instead of suffixing
6157
setattr(
6258
simple_namespace,
6359
n,

0 commit comments

Comments
 (0)