17 changes: 9 additions & 8 deletions mlir/docs/Dialects/Linalg/OpDSL.md
@@ -311,16 +311,17 @@ An example for a rank polymorphic operation is `fill`:

```python
@linalg_structured_op
def fill(value=ScalarDef(T1),
O=TensorDef(U, output=True)):
O[None] = TypeFn.cast_signed(U, value)
def fill(value=ScalarDef(T),
O=TensorDef(T, output=True)):
O[None] = value
```

The operation sets the elements of the output tensor `O` to `value`. All
operands are either scalars or rank zero tensors that are accessed using the
index `None`. The operation thus performs a scalar computation that trivially
extends to a multi-dimensional pointwise computation. As a result, we may use
`fill` with arbitrary ranked output tensors:
The operation sets the elements of the output tensor `O` to `value`. The value
type must match the element type of the output tensor. All operands are either
scalars or rank zero tensors that are accessed using the index `None`. The
operation thus performs a scalar computation that trivially extends to a
multi-dimensional pointwise computation. As a result, we may use `fill` with
arbitrary ranked output tensors:

```python
tensor_2d = tensor.EmptyOp([4, 8], f32)
```
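For context, the documented requirement translates directly to the `linalg.fill` form at the IR level: the scalar value and the element type of the output must now agree. A minimal sketch, not part of this patch (SSA names and shapes are illustrative):

```mlir
// Fill a tensor with a constant whose type matches the element type.
%cst = arith.constant 0.0 : f32
%init = tensor.empty() : tensor<4x8xf32>
%filled = linalg.fill ins(%cst : f32) outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>
```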
@@ -6054,9 +6054,9 @@ metadata: !LinalgOpMetadata
doc: |-
Fills the output tensor with the given value.

Works for arbitrary ranked output tensors since the operation performs scalar
accesses only and is thus rank polymorphic. Numeric casting is performed on
the value operand, promoting it to the same data type as the output.
Works for arbitrary ranked output tensors since the operation performs
scalar accesses only and is thus rank polymorphic. The value operand
type must match the element type of the output.
implements:
- LinalgFillOpInterface
defines:
@@ -6066,11 +6066,11 @@ structured_op: !LinalgStructuredOpConfig
- !LinalgOperandDefConfig
name: value
kind: scalar
type_var: T1
type_var: T
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: U
type_var: T
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
@@ -6081,13 +6081,7 @@
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: type
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
scalar_arg: value
scalar_arg: value
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: fill_rng_2d
57 changes: 44 additions & 13 deletions mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1057,35 +1057,66 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
// FillOpInterface implementation
//===----------------------------------------------------------------------===//

namespace {
enum class MatchFillResult {
Success = 0,
NotLinalgOp,
WrongNumOperands,
NotScalarInput
NotScalarInput,
TypeMismatch
};

static MatchFillResult isFillInterfaceImpl(Operation *op) {
struct FillInterfaceResult {
MatchFillResult result = MatchFillResult::Success;
Type scalarType;
Type outputElementType;
};
} // namespace

static FillInterfaceResult isFillInterfaceImpl(Operation *op) {
FillInterfaceResult fillResult = {};
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp)
return MatchFillResult::NotLinalgOp;
if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
return MatchFillResult::WrongNumOperands;
if (!linalgOp) {
fillResult.result = MatchFillResult::NotLinalgOp;
return fillResult;
}
if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) {
fillResult.result = MatchFillResult::WrongNumOperands;
return fillResult;
}

OpOperand *value = linalgOp.getDpsInputOperand(0);
if (!linalgOp.isScalar(value))
return MatchFillResult::NotScalarInput;
if (!linalgOp.isScalar(value)) {
fillResult.result = MatchFillResult::NotScalarInput;
return fillResult;
}

// Check that the scalar input type matches the output element type.
OpOperand *output = linalgOp.getDpsInitOperand(0);
Type scalarType = value->get().getType();
Type outputElementType = getElementTypeOrSelf(output->get().getType());
if (scalarType != outputElementType) {
fillResult.result = MatchFillResult::TypeMismatch;
fillResult.scalarType = scalarType;
fillResult.outputElementType = outputElementType;
return fillResult;
}

return MatchFillResult::Success;
return fillResult;
}

LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
auto res = isFillInterfaceImpl(op);
if (res == MatchFillResult::NotLinalgOp)
auto [result, scalarType, outputElementType] = isFillInterfaceImpl(op);
if (result == MatchFillResult::NotLinalgOp)
return op->emitError("expected a LinalgOp");
if (res == MatchFillResult::WrongNumOperands)
if (result == MatchFillResult::WrongNumOperands)
return op->emitError("expected op with 1 input and 1 output");
if (res == MatchFillResult::NotScalarInput)
if (result == MatchFillResult::NotScalarInput)
return op->emitError("expected op with scalar input");
if (result == MatchFillResult::TypeMismatch)
return op->emitOpError("expected fill value type (")
<< scalarType << ") to match output element type ("
<< outputElementType << ")";

return success();
}
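With the implicit `cast_signed` removed, IR that previously relied on `linalg.fill` converting the value must cast it explicitly before the fill. A minimal sketch of that migration, assuming an f32 constant and an f16 destination (function and SSA names are illustrative):

```mlir
func.func @fill_f16(%init: tensor<8xf16>) -> tensor<8xf16> {
  %cst = arith.constant 7.0 : f32
  // The fill no longer casts, so truncate the value to the element type first.
  %cst_f16 = arith.truncf %cst : f32 to f16
  %filled = linalg.fill ins(%cst_f16 : f16) outs(%init : tensor<8xf16>) -> tensor<8xf16>
  return %filled : tensor<8xf16>
}
```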
@@ -1729,16 +1729,16 @@ def pooling_ndhwc_min(


@linalg_structured_op
def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
def fill(value=ScalarDef(T), O=TensorDef(T, output=True)):
"""Fills the output tensor with the given value.

Works for arbitrary ranked output tensors since the operation performs scalar
accesses only and is thus rank polymorphic. Numeric casting is performed on
the value operand, promoting it to the same data type as the output.
accesses only and is thus rank polymorphic. The value type must match the
element type of the output tensor or memref.
"""
implements(FillOpInterface)
defines(Canonicalizer)
O[None] = TypeFn.cast_signed(U, value)
O[None] = value


@linalg_structured_op
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -36,13 +36,13 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index
// CHECK: "test.some_use"(%[[c5]])
// CHECK: %[[c5:.*]] = arith.constant 5 : index
// CHECK: "test.some_use"(%[[c5]])
func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) {
func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: i32) {
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
scf.for %iv = %c0 to %ub step %c4 {
%sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub]
%slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
%filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
%filled = linalg.fill ins(%f : i32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>

%bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
"test.some_use"(%bound) : (index) -> ()
26 changes: 1 addition & 25 deletions mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -921,30 +921,6 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {

// -----

// CHECK-LABEL: func @fold_fill_generic_different_dtype
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
// CHECK-NOT: linalg.fill
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
// CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
#map0 = affine_map<(d0) -> (d0)>
func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
%c0 = arith.constant 0 : index
%cst = arith.constant 7.0 : f32
%0 = tensor.dim %arg0, %c0 : tensor<?xf16>
%1 = tensor.empty(%0) : tensor<?xf16>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
%3 = tensor.empty(%0) : tensor<?xf16>
%4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
%5 = arith.addf %arg1, %arg2 : f16
linalg.yield %5 : f16
} -> tensor<?xf16>
return %4 : tensor<?xf16>
}

// -----

// CHECK-LABEL: func @fold_fill_generic_mixedaccess
// CHECK-NOT: linalg.fill
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
@@ -1079,4 +1055,4 @@ module {
// CHECK-NOT: linalg.generic
// CHECK: tensor.expand_shape
// CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
@@ -380,8 +380,8 @@ func.func @generalize_pooling_nwc_sum_i32(%input : tensor<1x16x1xi32>, %shape: t

// -----

func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
%0 = linalg.fill ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32>
func.func @generalize_fill_0d(%value: f32, %O: tensor<f32>) -> tensor<f32> {
%0 = linalg.fill ins(%value: f32) outs(%O : tensor<f32>) -> tensor<f32>
return %0: tensor<f32>
}

@@ -394,8 +394,8 @@ func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {

// -----

func.func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) {
linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>)
func.func @generalize_fill_2d(%value: f32, %O: memref<16x32xf32>) {
linalg.fill ins(%value: f32) outs(%O : memref<16x32xf32>)
return
}

18 changes: 18 additions & 0 deletions mlir/test/Dialect/Linalg/invalid.mlir
@@ -352,6 +352,24 @@ func.func @illegal_fill_tensor_with_memref_return

// -----

func.func @illegal_fill_element_type_truncation(%arg0 : tensor<2xf32>, %arg1 : f64) -> tensor<2xf32>
{
// expected-error @+1 {{'linalg.fill' op expected fill value type ('f64') to match output element type ('f32')}}
%0 = linalg.fill ins(%arg1 : f64) outs(%arg0 : tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

// -----

func.func @illegal_fill_element_type_extension(%arg0 : tensor<2xi32>, %arg1 : i16) -> tensor<2xi32>
{
// expected-error @+1 {{'linalg.fill' op expected fill value type ('i16') to match output element type ('i32')}}
%0 = linalg.fill ins(%arg1 : i16) outs(%arg0 : tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}

// -----

func.func @illegal_fill_value_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xf32>) -> tensor<2x2xf32>
{
// expected-error @+1 {{expected op with scalar input}}
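For contrast with the negative tests above, a fill whose value type matches the output element type still verifies. A small sketch, not taken from the patch:

```mlir
func.func @legal_fill(%arg0 : tensor<2xf32>, %arg1 : f32) -> tensor<2xf32> {
  %0 = linalg.fill ins(%arg1 : f32) outs(%arg0 : tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}
```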
@@ -27,8 +27,8 @@ func.func @main() {
%A_dyn = tensor.cast %A : tensor<8x2xf32> to tensor<?x?xf32>
%B_dyn = tensor.cast %B : tensor<2x4xf32> to tensor<?x?xf32>

%c0_i32 = arith.constant 0 : i32
%C_init = linalg.fill ins(%c0_i32 : i32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
%c0_f32 = arith.constant 0.0 : f32
%C_init = linalg.fill ins(%c0_f32 : f32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>

%res = linalg.matmul ins(%A_dyn, %B_dyn: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C_init: tensor<?x?xf32>) -> tensor<?x?xf32>
4 changes: 2 additions & 2 deletions mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -63,11 +63,11 @@ func.func @matmul_simple(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> te
}

func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf32> {
%cst = arith.constant 0.0 : f64
%cst = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<10x15xf32>

// expected-remark @below {{fill}}
%fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>

%real_lhs = linalg.mul
ins(%lhs, %lhs : tensor<10x20xf32>, tensor<10x20xf32>) outs(%lhs : tensor<10x20xf32>) -> tensor<10x20xf32>
26 changes: 12 additions & 14 deletions mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -25,13 +25,13 @@ def log(*args):
%O1 = memref.alloc() : memref<16xi32>
%O2 = memref.alloc() : memref<4x16xi32>

%val0 = arith.constant 1.0 : f32
%val1 = arith.constant 2.0 : f32
%val2 = arith.constant 3.0 : f32
%val0 = arith.constant 1 : i32
%val1 = arith.constant 2 : i32
%val2 = arith.constant 3 : i32

call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()
call @fill_0d_on_buffers(%val0, %O0) : (i32, memref<i32>) -> ()
call @fill_1d_on_buffers(%val1, %O1) : (i32, memref<16xi32>) -> ()
call @fill_2d_on_buffers(%val2, %O2) : (i32, memref<4x16xi32>) -> ()

%c0 = arith.constant 0 : index
%res0 = memref.load %O0[] : memref<i32>
@@ -149,19 +149,18 @@ def transform(module, boilerplate):
def test_fill_builtin():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):

@func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
@func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
def fill_0d_on_buffers(value, out):
linalg.fill(value, outs=[out])

@func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
@func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
def fill_1d_on_buffers(value, out):
linalg.fill(value, outs=[out])

@func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
@func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
def fill_2d_on_buffers(value, out):
linalg.fill(value, outs=[out])

@@ -184,19 +183,18 @@ def fill_2d_on_buffers(value, out):
def test_fill_generic():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):

@func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
@func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
def fill_0d_on_buffers(value, out):
linalg.fill(value, outs=[out], emit_generic=True)

@func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
@func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
def fill_1d_on_buffers(value, out):
linalg.fill(value, outs=[out], emit_generic=True)

@func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
@func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
def fill_2d_on_buffers(value, out):
linalg.fill(value, outs=[out], emit_generic=True)
