Skip to content

Commit a8e142b

Browse files
tpoppmemfrob
authored andcommitted
[mlir] Add shape.is_broadcastable.
This op returns a boolean value indicating whether 2 ops are broadcastable or not. This follows the same logic as the other ops with broadcast in their names in the shape dialect. Concretely, shape.is_broadcastable returning true implies that shape.broadcast will not give an error, and shape.cstr_broadcastable will not result in an assertion failure. Similarly, false implies an error or assertion failure.
1 parent 39d6529 commit a8e142b

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,27 @@ def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [NoSideEffect]> {
181181
let assemblyFormat = "attr-dict $input `:` type($input)";
182182
}
183183

184+
def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
185+
let summary = "Determines if 2 shapes can be successfully broadcasted";
186+
let description = [{
187+
Given two input shapes or extent tensors, return a predicate specifying if
188+
they are broadcastable. This broadcastable follows the same logic as what
189+
shape.broadcast documents.
190+
191+
Example:
192+
```mlir
193+
%true = shape.is_broadcastable [2,2], [3,1,2]
194+
%false = shape.is_broadcastable [2,2], [3,2]
195+
```
196+
}];
197+
198+
let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
199+
Shape_ShapeOrExtentTensorType:$rhs);
200+
let results = (outs I1:$result);
201+
202+
let assemblyFormat = "$lhs `,` $rhs `:` type($lhs) `,` type($rhs) attr-dict";
203+
}
204+
184205
def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
185206
let summary = "Gets the rank of a shape";
186207
let description = [{

mlir/test/Dialect/Shape/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,17 @@ func @any_on_extent_tensors(%a : tensor<?xindex>,
260260
: tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
261261
return %result : tensor<?xindex>
262262
}
263+
264+
func @is_broadcastable_on_extent_tensors(%a : tensor<?xindex>,
265+
%b : tensor<?xindex>) -> i1 {
266+
%result = shape.is_broadcastable %a, %b
267+
: tensor<?xindex>, tensor<?xindex>
268+
return %result : i1
269+
}
270+
271+
func @is_broadcastable_on_shapes(%a : !shape.shape,
272+
%b : !shape.shape) -> i1 {
273+
%result = shape.is_broadcastable %a, %b
274+
: !shape.shape, !shape.shape
275+
return %result : i1
276+
}

0 commit comments

Comments
 (0)