Skip to content

Commit 388d3e6

Browse files
authored
fix tensor pack -> linalg.pack (#123)
1 parent 15cb803 commit 388d3e6

File tree

3 files changed

+200
-41
lines changed

3 files changed

+200
-41
lines changed

mlir/extras/dialects/ext/linalg.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22
from ...util import get_user_code_loc
33
from ....dialects import linalg
44

5+
from ....dialects._ods_common import (
6+
_dispatch_mixed_values,
7+
_cext,
8+
get_op_results_or_values,
9+
get_default_loc_context,
10+
get_op_result_or_op_results,
12+
segmented_accessor,
13+
)
14+
515
# noinspection PyUnresolvedReferences
616
from ....dialects.linalg import *
717
from ....extras import types as T
@@ -524,3 +534,185 @@ def vecmat(y, A, x, *, loc=None, ip=None):
524534
if loc is None:
525535
loc = get_user_code_loc()
526536
return linalg.vecmat(y, A, loc=loc, ip=ip, outs=[x])
537+
538+
539+
@_cext.register_operation(linalg.Dialect)
class PackOp(ir.OpView):
    """Hand-written Python binding for the `linalg.pack` operation.

    Mirrors the structure of tblgen-generated op bindings (operand segments,
    attribute builders, segmented accessors). It exists because this project's
    bundled bindings predate the upstream move of `tensor.pack` into the
    linalg dialect, so no generated `linalg.pack` class is available.
    """

    OPERATION_NAME = "linalg.pack"

    # Operand segment sizes, in operand order:
    #   source (exactly 1), dest (exactly 1), padding_value (optional: 0),
    #   inner_tiles (variadic: -1).
    # NOTE(review): follows the generated-bindings convention where 0 marks an
    # optional operand and -1 a variadic group — confirm against _ods_common.
    _ODS_OPERAND_SEGMENTS = [
        1,
        1,
        0,
        -1,
    ]

    # Zero regions; second element per the generated-bindings convention.
    _ODS_REGIONS = (0, True)

    def __init__(
        self,
        source,
        dest,
        inner_dims_pos,
        inner_tiles,
        static_inner_tiles,
        *,
        padding_value=None,
        outer_dims_perm=None,
        loc=None,
        ip=None,
    ):
        """Build a `linalg.pack` op.

        Args:
            source: value to pack.
            dest: destination value the packed data is written into.
            inner_dims_pos: DenseI64ArrayAttr (or plain sequence) of dims
                being tiled.
            inner_tiles: dynamic (SSA-value) tile sizes only; static sizes go
                in `static_inner_tiles`.
            static_inner_tiles: DenseI64ArrayAttr (or plain sequence) of
                static tile sizes (with sentinels for the dynamic ones).
            padding_value: optional scalar used to pad partial tiles.
            outer_dims_perm: optional permutation of the outer dims.
            loc/ip: usual MLIR location / insertion point.
        """
        operands = []
        results = []
        attributes = {}
        regions = None
        # Operand order must match _ODS_OPERAND_SEGMENTS above; a None
        # padding_value is how the optional segment is expressed.
        operands.append(source)
        operands.append(dest)
        operands.append(padding_value)
        operands.append(get_op_results_or_values(inner_tiles))
        _ods_context = get_default_loc_context(loc)
        # For each array attribute: pass through if already an ir.Attribute
        # (or if no DenseI64ArrayAttr builder is registered), else build one
        # from the raw Python sequence.
        if outer_dims_perm is not None:
            attributes["outer_dims_perm"] = (
                outer_dims_perm
                if (
                    isinstance(outer_dims_perm, ir.Attribute)
                    or not ir.AttrBuilder.contains("DenseI64ArrayAttr")
                )
                else ir.AttrBuilder.get("DenseI64ArrayAttr")(
                    outer_dims_perm, context=_ods_context
                )
            )
        attributes["inner_dims_pos"] = (
            inner_dims_pos
            if (
                isinstance(inner_dims_pos, ir.Attribute)
                or not ir.AttrBuilder.contains("DenseI64ArrayAttr")
            )
            else ir.AttrBuilder.get("DenseI64ArrayAttr")(
                inner_dims_pos, context=_ods_context
            )
        )
        attributes["static_inner_tiles"] = (
            static_inner_tiles
            if (
                isinstance(static_inner_tiles, ir.Attribute)
                or not ir.AttrBuilder.contains("DenseI64ArrayAttr")
            )
            else ir.AttrBuilder.get("DenseI64ArrayAttr")(
                static_inner_tiles, context=_ods_context
            )
        )
        _ods_successors = None
        # _ODS_RESULT_SEGMENTS is inherited from ir.OpView (no result
        # segments for this op).
        super().__init__(
            self.OPERATION_NAME,
            self._ODS_REGIONS,
            self._ODS_OPERAND_SEGMENTS,
            self._ODS_RESULT_SEGMENTS,
            attributes=attributes,
            operands=operands,
            successors=_ods_successors,
            regions=regions,
            loc=loc,
            ip=ip,
        )

    @property
    def source(self):
        # Segment 0 of operandSegmentSizes: the packed source.
        operand_range = segmented_accessor(
            self.operation.operands, self.operation.attributes["operandSegmentSizes"], 0
        )
        return operand_range[0]

    @property
    def dest(self):
        # Segment 1: the destination tensor.
        operand_range = segmented_accessor(
            self.operation.operands, self.operation.attributes["operandSegmentSizes"], 1
        )
        return operand_range[0]

    @property
    def padding_value(self):
        # Segment 2 is optional, hence the empty-range check.
        operand_range = segmented_accessor(
            self.operation.operands, self.operation.attributes["operandSegmentSizes"], 2
        )
        return operand_range[0] if len(operand_range) > 0 else None

    @property
    def inner_tiles(self):
        # Segment 3: the variadic dynamic tile sizes (may be empty).
        operand_range = segmented_accessor(
            self.operation.operands, self.operation.attributes["operandSegmentSizes"], 3
        )
        return operand_range

    @property
    def outer_dims_perm(self):
        # Optional attribute: absent means "identity permutation".
        if "outer_dims_perm" not in self.operation.attributes:
            return None
        return self.operation.attributes["outer_dims_perm"]

    @outer_dims_perm.setter
    def outer_dims_perm(self, value):
        # Setting None clears the optional attribute instead of storing None.
        if value is not None:
            self.operation.attributes["outer_dims_perm"] = value
        elif "outer_dims_perm" in self.operation.attributes:
            del self.operation.attributes["outer_dims_perm"]

    @outer_dims_perm.deleter
    def outer_dims_perm(self):
        del self.operation.attributes["outer_dims_perm"]

    @property
    def inner_dims_pos(self):
        return self.operation.attributes["inner_dims_pos"]

    @inner_dims_pos.setter
    def inner_dims_pos(self, value):
        # Mandatory attribute: reject None rather than silently deleting.
        if value is None:
            raise ValueError("'None' not allowed as value for mandatory attributes")
        self.operation.attributes["inner_dims_pos"] = value

    @property
    def static_inner_tiles(self):
        return self.operation.attributes["static_inner_tiles"]

    @static_inner_tiles.setter
    def static_inner_tiles(self, value):
        # Mandatory attribute: reject None rather than silently deleting.
        if value is None:
            raise ValueError("'None' not allowed as value for mandatory attributes")
        self.operation.attributes["static_inner_tiles"] = value

    @property
    def result(self):
        # linalg.pack produces a single result: the packed tensor.
        return self.operation.results[0]
688+
689+
def pack(
    source,
    dest,
    inner_dims_pos,
    inner_tiles,
    *,
    padding_value=None,
    outer_dims_perm=None,
    loc=None,
    ip=None,
) -> ir.Value:
    """Create a `linalg.pack` op and return its packed result value.

    `inner_tiles` may freely mix static Python integers and dynamic SSA
    values; they are split into the variadic dynamic operands and the
    `static_inner_tiles` attribute that `PackOp` expects.
    """
    # Split the mixed static/dynamic tile sizes. The middle element of the
    # triple ("packed" values in the %1:2 results-packing sense) is unused.
    dynamic_tiles, _packed, static_tiles = _dispatch_mixed_values(inner_tiles)

    op = PackOp(
        source=source,
        dest=dest,
        inner_dims_pos=inner_dims_pos,
        inner_tiles=dynamic_tiles,
        static_inner_tiles=static_tiles,
        padding_value=padding_value,
        outer_dims_perm=outer_dims_perm,
        loc=loc,
        ip=ip,
    )
    return op.result

mlir/extras/dialects/ext/tensor.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -689,36 +689,3 @@ def pad_(
689689
generate = region_op(
690690
lambda result, dynamic_extents: tensor.GenerateOp(result, dynamic_extents)
691691
)
692-
693-
_pack = pack
694-
695-
696-
def pack(
697-
source,
698-
dest,
699-
inner_dims_pos,
700-
inner_tiles,
701-
*,
702-
padding_value=None,
703-
outer_dims_perm=None,
704-
loc=None,
705-
ip=None,
706-
):
707-
(
708-
dynamic_inner_tiles,
709-
# packed here means %1:2 packing (results packing)
710-
_inner_tiles,
711-
static_inner_tiles,
712-
) = _dispatch_mixed_values(inner_tiles)
713-
714-
return _pack(
715-
source=source,
716-
dest=dest,
717-
inner_dims_pos=inner_dims_pos,
718-
inner_tiles=dynamic_inner_tiles,
719-
static_inner_tiles=static_inner_tiles,
720-
padding_value=padding_value,
721-
outer_dims_perm=outer_dims_perm,
722-
loc=loc,
723-
ip=ip,
724-
)

tests/test_transform.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -990,22 +990,22 @@ def main(variant_op: any_op_t()):
990990
%c0_i32 = arith.constant 0 : i32
991991
%0 = tensor.empty() : tensor<16x256xi8>
992992
%1 = tensor.empty() : tensor<1x4x16x64xi8>
993-
%pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %1 : tensor<16x256xi8> -> tensor<1x4x16x64xi8>
993+
%pack = linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %1 : tensor<16x256xi8> -> tensor<1x4x16x64xi8>
994994
%2 = tensor.empty() : tensor<4x1x64x64xi8>
995995
%3 = tensor.empty() : tensor<1x1x16x64xi8>
996996
%4 = linalg.fill ins(%c0_i32 : i32) outs(%3 : tensor<1x1x16x64xi8>) -> tensor<1x1x16x64xi8>
997997
%5 = scf.forall (%arg2, %arg3) in (1, 4) shared_outs(%arg4 = %0) -> (tensor<16x256xi8>) {
998998
%6 = affine.apply #map(%arg3)
999999
%extracted_slice = tensor.extract_slice %arg1[0, %6] [256, 64] [1, 1] : tensor<256x256xi8> to tensor<256x64xi8>
10001000
%extracted_slice_0 = tensor.extract_slice %arg4[0, %6] [16, 64] [1, 1] : tensor<16x256xi8> to tensor<16x64xi8>
1001-
%pack_1 = tensor.pack %extracted_slice outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %2 : tensor<256x64xi8> -> tensor<4x1x64x64xi8>
1001+
%pack_1 = linalg.pack %extracted_slice outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %2 : tensor<256x64xi8> -> tensor<4x1x64x64xi8>
10021002
%7 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_1 : tensor<1x4x16x64xi8>, tensor<4x1x64x64xi8>) outs(%4 : tensor<1x1x16x64xi8>) {
10031003
^bb0(%in: i8, %in_2: i8, %out: i8):
10041004
%8 = arith.muli %in, %in_2 : i8
10051005
%9 = arith.addi %out, %8 : i8
10061006
linalg.yield %9 : i8
10071007
} -> tensor<1x1x16x64xi8>
1008-
%unpack = tensor.unpack %7 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %extracted_slice_0 : tensor<1x1x16x64xi8> -> tensor<16x64xi8>
1008+
%unpack = linalg.unpack %7 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %extracted_slice_0 : tensor<1x1x16x64xi8> -> tensor<16x64xi8>
10091009
scf.forall.in_parallel {
10101010
tensor.parallel_insert_slice %unpack into %arg4[0, %6] [16, 64] [1, 1] : tensor<16x64xi8> into tensor<16x256xi8>
10111011
}
@@ -1050,7 +1050,7 @@ def tensor_pack(
10501050
src: T.tensor(129, 47, 16, 16, T.f32()),
10511051
dst: T.tensor(17, 2, 16, 16, 32, 8, T.f32()),
10521052
):
1053-
return tensor.pack(
1053+
return linalg.pack(
10541054
src,
10551055
dst,
10561056
inner_dims_pos=[1, 0],
@@ -1068,8 +1068,8 @@ def mod_transform():
10681068
def main(variant_op: any_op_t()):
10691069
packed = match(
10701070
variant_op,
1071-
ops=["tensor.pack"],
1072-
matched_op=transform_op_t("tensor.pack"),
1071+
ops=["linalg.pack"],
1072+
matched_op=transform_op_t("linalg.pack"),
10731073
)
10741074
lowered_pack = transform.structured.lower_pack(packed)
10751075

@@ -1097,8 +1097,8 @@ def main(variant_op: any_op_t()):
10971097
}
10981098
module attributes {transform.with_named_sequence} {
10991099
transform.named_sequence @main(%arg0: !transform.any_op) {
1100-
%0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.op<"tensor.pack">
1101-
%pad_op, %expand_shape_op, %transpose_op = transform.structured.lower_pack %0 : (!transform.op<"tensor.pack">) -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
1100+
%0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.op<"linalg.pack">
1101+
%pad_op, %expand_shape_op, %transpose_op = transform.structured.lower_pack %0 : (!transform.op<"linalg.pack">) -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
11021102
transform.yield
11031103
}
11041104
}

0 commit comments

Comments
 (0)