Skip to content

Commit 5024195

Browse files
ulysseBcopybara-github
authored andcommitted
Fix FreeOp shape.
FreeOp type parsing and printing was inconsistent. FreeOp shape was invalid in cases where the operand mapping was not the identity. Change FreeOp so that: * shape() returns `mapping(operand_shape)`, previously was `inverse_mapping(operand_shape)` * the printed type is `mapping(operand_type)`, previously was `operand_type` * parsed operand type remains `inverse_mapping(printed_type)`. PiperOrigin-RevId: 361498431
1 parent f91ac0e commit 5024195

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

sair_ops.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,8 @@ static void Print(SairFreeOp op, mlir::OpAsmPrinter &printer) {
781781
PrintValueAccess(op.Value(), printer);
782782
printer.printOptionalAttrDict(op->getAttrs(),
783783
{SairDialect::kMappingAttrName});
784-
printer << " : " << op.value().getType();
784+
mlir::Type element_type = op.Value().GetType().ElementType();
785+
printer << " : " << ValueType::get(op.shape(), element_type);
785786
}
786787

787788
mlir::LogicalResult Verify(SairFromScalarOp op) {
@@ -1648,8 +1649,8 @@ llvm::SmallVector<int, 2> SairFreeOp::SubDomains() {
16481649
}
16491650

16501651
DomainShapeAttr SairFreeOp::shape() {
1651-
return value().getType().cast<ValueType>().Shape().AccessedShape(
1652-
mapping_array()[0].cast<MappingAttr>());
1652+
ValueOperand value = Value();
1653+
return value.GetType().Shape().AccessedShape(value.Mapping().Inverse());
16531654
}
16541655

16551656
// Takes a mapping `lhs` and an array of mappings `rhs_array`. Returns a new

test/roundtrip.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,3 +656,17 @@ func @placeholder_with_loop_nest(%arg0: f32) {
656656
}
657657
return
658658
}
659+
660+
// CHECK-LABEL: @free_with_mapping
661+
func @free_with_mapping() {
662+
sair.program {
663+
%0 = sair.static_range 8 : !sair.range
664+
%1 = sair.alloc[d0:%0] : !sair.value<d0:range, memref<f32>>
665+
%2 = sair.placeholder : !sair.range
666+
%3 = sair.placeholder[d0:%2] : !sair.range<d0:range>
667+
sair.free[d0:%2, d1:%3] %1(unstripe(d0, d1, [4]))
668+
: !sair.value<d0:range x d1:range(d0), memref<f32>>
669+
sair.exit
670+
}
671+
return
672+
}

0 commit comments

Comments
 (0)