|
2 | 2 | from ...util import get_user_code_loc
|
3 | 3 | from ....dialects import linalg
|
4 | 4 |
|
| 5 | +from ....dialects._ods_common import ( |
| 6 | + _dispatch_mixed_values, |
| 7 | + _cext, |
| 8 | + get_op_results_or_values, |
| 9 | + get_default_loc_context, |
| 10 | + get_op_result_or_op_results, |
| 11 | + get_default_loc_context, |
| 12 | + segmented_accessor, |
| 13 | +) |
| 14 | + |
5 | 15 | # noinspection PyUnresolvedReferences
|
6 | 16 | from ....dialects.linalg import *
|
7 | 17 | from ....extras import types as T
|
@@ -524,3 +534,185 @@ def vecmat(y, A, x, *, loc=None, ip=None):
|
524 | 534 | if loc is None:
|
525 | 535 | loc = get_user_code_loc()
|
526 | 536 | return linalg.vecmat(y, A, loc=loc, ip=ip, outs=[x])
|
| 537 | + |
| 538 | + |
| 539 | +@_cext.register_operation(linalg.Dialect) |
| 540 | +class PackOp(ir.OpView): |
| 541 | + OPERATION_NAME = "linalg.pack" |
| 542 | + |
| 543 | + _ODS_OPERAND_SEGMENTS = [ |
| 544 | + 1, |
| 545 | + 1, |
| 546 | + 0, |
| 547 | + -1, |
| 548 | + ] |
| 549 | + |
| 550 | + _ODS_REGIONS = (0, True) |
| 551 | + |
| 552 | + def __init__( |
| 553 | + self, |
| 554 | + source, |
| 555 | + dest, |
| 556 | + inner_dims_pos, |
| 557 | + inner_tiles, |
| 558 | + static_inner_tiles, |
| 559 | + *, |
| 560 | + padding_value=None, |
| 561 | + outer_dims_perm=None, |
| 562 | + loc=None, |
| 563 | + ip=None, |
| 564 | + ): |
| 565 | + operands = [] |
| 566 | + results = [] |
| 567 | + attributes = {} |
| 568 | + regions = None |
| 569 | + operands.append(source) |
| 570 | + operands.append(dest) |
| 571 | + operands.append(padding_value) |
| 572 | + operands.append(get_op_results_or_values(inner_tiles)) |
| 573 | + _ods_context = get_default_loc_context(loc) |
| 574 | + if outer_dims_perm is not None: |
| 575 | + attributes["outer_dims_perm"] = ( |
| 576 | + outer_dims_perm |
| 577 | + if ( |
| 578 | + isinstance(outer_dims_perm, ir.Attribute) |
| 579 | + or not ir.AttrBuilder.contains("DenseI64ArrayAttr") |
| 580 | + ) |
| 581 | + else ir.AttrBuilder.get("DenseI64ArrayAttr")( |
| 582 | + outer_dims_perm, context=_ods_context |
| 583 | + ) |
| 584 | + ) |
| 585 | + attributes["inner_dims_pos"] = ( |
| 586 | + inner_dims_pos |
| 587 | + if ( |
| 588 | + isinstance(inner_dims_pos, ir.Attribute) |
| 589 | + or not ir.AttrBuilder.contains("DenseI64ArrayAttr") |
| 590 | + ) |
| 591 | + else ir.AttrBuilder.get("DenseI64ArrayAttr")( |
| 592 | + inner_dims_pos, context=_ods_context |
| 593 | + ) |
| 594 | + ) |
| 595 | + attributes["static_inner_tiles"] = ( |
| 596 | + static_inner_tiles |
| 597 | + if ( |
| 598 | + isinstance(static_inner_tiles, ir.Attribute) |
| 599 | + or not ir.AttrBuilder.contains("DenseI64ArrayAttr") |
| 600 | + ) |
| 601 | + else ir.AttrBuilder.get("DenseI64ArrayAttr")( |
| 602 | + static_inner_tiles, context=_ods_context |
| 603 | + ) |
| 604 | + ) |
| 605 | + _ods_successors = None |
| 606 | + super().__init__( |
| 607 | + self.OPERATION_NAME, |
| 608 | + self._ODS_REGIONS, |
| 609 | + self._ODS_OPERAND_SEGMENTS, |
| 610 | + self._ODS_RESULT_SEGMENTS, |
| 611 | + attributes=attributes, |
| 612 | + operands=operands, |
| 613 | + successors=_ods_successors, |
| 614 | + regions=regions, |
| 615 | + loc=loc, |
| 616 | + ip=ip, |
| 617 | + ) |
| 618 | + |
| 619 | + @property |
| 620 | + def source(self): |
| 621 | + operand_range = segmented_accessor( |
| 622 | + self.operation.operands, self.operation.attributes["operandSegmentSizes"], 0 |
| 623 | + ) |
| 624 | + return operand_range[0] |
| 625 | + |
| 626 | + @property |
| 627 | + def dest(self): |
| 628 | + operand_range = segmented_accessor( |
| 629 | + self.operation.operands, self.operation.attributes["operandSegmentSizes"], 1 |
| 630 | + ) |
| 631 | + return operand_range[0] |
| 632 | + |
| 633 | + @property |
| 634 | + def padding_value(self): |
| 635 | + operand_range = segmented_accessor( |
| 636 | + self.operation.operands, self.operation.attributes["operandSegmentSizes"], 2 |
| 637 | + ) |
| 638 | + return operand_range[0] if len(operand_range) > 0 else None |
| 639 | + |
| 640 | + @property |
| 641 | + def inner_tiles(self): |
| 642 | + operand_range = segmented_accessor( |
| 643 | + self.operation.operands, self.operation.attributes["operandSegmentSizes"], 3 |
| 644 | + ) |
| 645 | + return operand_range |
| 646 | + |
| 647 | + @property |
| 648 | + def outer_dims_perm(self): |
| 649 | + if "outer_dims_perm" not in self.operation.attributes: |
| 650 | + return None |
| 651 | + return self.operation.attributes["outer_dims_perm"] |
| 652 | + |
| 653 | + @outer_dims_perm.setter |
| 654 | + def outer_dims_perm(self, value): |
| 655 | + if value is not None: |
| 656 | + self.operation.attributes["outer_dims_perm"] = value |
| 657 | + elif "outer_dims_perm" in self.operation.attributes: |
| 658 | + del self.operation.attributes["outer_dims_perm"] |
| 659 | + |
| 660 | + @outer_dims_perm.deleter |
| 661 | + def outer_dims_perm(self): |
| 662 | + del self.operation.attributes["outer_dims_perm"] |
| 663 | + |
| 664 | + @property |
| 665 | + def inner_dims_pos(self): |
| 666 | + return self.operation.attributes["inner_dims_pos"] |
| 667 | + |
| 668 | + @inner_dims_pos.setter |
| 669 | + def inner_dims_pos(self, value): |
| 670 | + if value is None: |
| 671 | + raise ValueError("'None' not allowed as value for mandatory attributes") |
| 672 | + self.operation.attributes["inner_dims_pos"] = value |
| 673 | + |
| 674 | + @property |
| 675 | + def static_inner_tiles(self): |
| 676 | + return self.operation.attributes["static_inner_tiles"] |
| 677 | + |
| 678 | + @static_inner_tiles.setter |
| 679 | + def static_inner_tiles(self, value): |
| 680 | + if value is None: |
| 681 | + raise ValueError("'None' not allowed as value for mandatory attributes") |
| 682 | + self.operation.attributes["static_inner_tiles"] = value |
| 683 | + |
| 684 | + @property |
| 685 | + def result(self): |
| 686 | + return self.operation.results[0] |
| 687 | + |
| 688 | + |
| 689 | +def pack( |
| 690 | + source, |
| 691 | + dest, |
| 692 | + inner_dims_pos, |
| 693 | + inner_tiles, |
| 694 | + *, |
| 695 | + padding_value=None, |
| 696 | + outer_dims_perm=None, |
| 697 | + loc=None, |
| 698 | + ip=None, |
| 699 | +) -> ir.Value: |
| 700 | + |
| 701 | + ( |
| 702 | + dynamic_inner_tiles, |
| 703 | + # packed here means %1:2 packing (results packing) |
| 704 | + _inner_tiles, |
| 705 | + static_inner_tiles, |
| 706 | + ) = _dispatch_mixed_values(inner_tiles) |
| 707 | + |
| 708 | + return PackOp( |
| 709 | + source=source, |
| 710 | + dest=dest, |
| 711 | + inner_dims_pos=inner_dims_pos, |
| 712 | + inner_tiles=dynamic_inner_tiles, |
| 713 | + static_inner_tiles=static_inner_tiles, |
| 714 | + padding_value=padding_value, |
| 715 | + outer_dims_perm=outer_dims_perm, |
| 716 | + loc=loc, |
| 717 | + ip=ip, |
| 718 | + ).result |
0 commit comments