Commit 3a0cd9d
authored
Pass to wrap StableHLO ops in composite (#2722)
Wraps StableHLO operations in `stablehlo.composite` operations.
For instance, consider a simple StableHLO program:
```mlir
func.func @main(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) ->
tensor<2xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<2xf32>
return %0 : tensor<2xf32>
}
```
Applying this pass to wrap `stablehlo.add` operations will result in the
following program:
```mlir
func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) ->
tensor<2xf32> {
%0 = stablehlo.composite "stablehlo.add" %arg0, %arg1 {decomposition =
@stablehlo.add.impl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func.func private @stablehlo.add.impl(%arg0: tensor<2xf32>, %arg1:
tensor<2xf32>) -> tensor<2xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<2xf32>
return %0 : tensor<2xf32>
}
```
Notes:
- The `name` attribute of the generated `stablehlo.composite` operation
will always be the same as the name of the original operation that was
wrapped (e.g., if you wrap a `stablehlo.add` operation, the composite
will also be named `"stablehlo.add"`).
- The private function that encapsulates the original operation
(referenced by the `decomposition` attribute of the
`stablehlo.composite` operation) will be named using the pattern
`<op_name>.impl[.N]`, where `<op_name>` is the name of the original
operation, and `N` is a unique integer identifier generated to prevent
naming conflicts within the module.
This pass can be used in three distinct ways:
**Mode 1: Command-line Usage**
This mode is the simplest, using the `stablehlo-opt` utility with the
`op-names` (a comma-separated list of operation names) and `version` (an
integer version number) options. It wraps **all instances** of specified
operations. The attributes of the newly created `stablehlo.composite`
operation will be the same as the attributes of the original operation.
**Usage Example:**
```bash
stablehlo-opt input.mlir
--stablehlo-wrap-in-composite=op-names='stablehlo.add,stablehlo.mul' -o
output.mlir
```
**Mode 2: Programmatic Single-Op Wrapping**
This mode provides programmatic control to wrap
**a specific operation instance** and returns a pointer to the newly
created `stablehlo.composite` operation.
**Example (C++):**
```cpp
// To wrap a specific stablehlo.add instance
mlir::stablehlo::AddOp addOp = ...; // The op instanced to be wrapped.
mlir::ModuleOp module = addOp->getParentOfType<mlir::ModuleOp>();
mlir::OpBuilder builder(addOp);
mlir::NamedAttrList attrs = ...; // Attributes to be set on the
composite op.
int32_t version = 0; // Composite version.
mlir::stablehlo::CompositeOp compositeOp =
mlir::stablehlo::wrapOperationInComposite(builder, addOp, attrs,
version, module);
addOp.replaceAllUsesWith(compositeOp);
```
**Mode 3: Programmatic Module-Wide Wrapping with Attribute Predicates**
This mode extends programmatic wrapping to the entire module, offering
fine-grained control over which operations are wrapped and their
attributes.
This is achieved by using the `createStablehloWrapInCompositePass` API,
which takes an `AttributePredicateMap` as an argument.
The `AttributePredicateMap` is a map that dictates which operations
should
be considered for wrapping and how their attributes should be handled.
Its
semantics are as follows:
- **Keys (mlir::TypeID):** `TypeID` of an MLIR operation. If an
operation's
`TypeID` matches a key in the map, it becomes a candidate for wrapping.
- **Values (Lambda Functions):** Lambda function of type
`std::function<std::optional<NamedAttrList>(Operation*)>`. This function
is applied to each candidate operation.
- **Input:** An `mlir::Operation*`, which is an instance of the
operation type corresponding to the `TypeID` key.
- **Return Value:** An `std::optional<NamedAttrList>`.
- If the lambda returns a `NamedAttrList` (wrapped in
`std::optional`), the operation is wrapped in a
`stablehlo::composite` operation, and the returned attributes are
used to set the composite's attributes.
- If the lambda returns `std::nullopt`, the operation is **not**
wrapped. This allows for selective wrapping based on custom
criteria.
**Example (C++):**
```cpp
// ... inside a pass or function ...
stablehlo::AttributePredicateMap attributePredicateMap;
attributePredicateMap[mlir::TypeID::get<mlir::stablehlo::AddOp>()] =
[](mlir::Operation* op) -> std::optional<mlir::NamedAttrList> {
// Custom logic to determine if and how to wrap the operation.
// Example: Only wrap if it's on a specific type.
if (op->getOperand(0).getType().isa<mlir::Float32Type>()) {
return mlir::NamedAttrList(op->getAttrs());
}
return std::nullopt; // Do not wrap.
};
pm.addPass(createStablehloWrapInCompositePass(attributePredicateMap,
compositeVersion));
if (mlir::failed(pm.run(module))) {
return;
}
```1 parent 350021b commit 3a0cd9d
File tree
7 files changed
+603
-8
lines changed- docs/generated
- stablehlo
- tests/transforms
- transforms
7 files changed
+603
-8
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1246 | 1246 | | |
1247 | 1247 | | |
1248 | 1248 | | |
| 1249 | + | |
1249 | 1250 | | |
1250 | 1251 | | |
1251 | 1252 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
301 | 301 | | |
302 | 302 | | |
303 | 303 | | |
304 | | - | |
305 | | - | |
| 304 | + | |
| 305 | + | |
306 | 306 | | |
307 | 307 | | |
308 | 308 | | |
| |||
338 | 338 | | |
339 | 339 | | |
340 | 340 | | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
| 365 | + | |
| 366 | + | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
| 372 | + | |
| 373 | + | |
| 374 | + | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
| 434 | + | |
| 435 | + | |
| 436 | + | |
| 437 | + | |
| 438 | + | |
| 439 | + | |
| 440 | + | |
| 441 | + | |
| 442 | + | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
| 455 | + | |
341 | 456 | | |
342 | 457 | | |
343 | 458 | | |
| |||
Lines changed: 88 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
57 | 57 | | |
58 | 58 | | |
59 | 59 | | |
| 60 | + | |
60 | 61 | | |
61 | 62 | | |
62 | 63 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
16 | 16 | | |
17 | 17 | | |
18 | 18 | | |
| 19 | + | |
| 20 | + | |
19 | 21 | | |
| 22 | + | |
20 | 23 | | |
21 | | - | |
22 | | - | |
23 | | - | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
24 | 27 | | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
25 | 31 | | |
26 | | - | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
27 | 35 | | |
28 | 36 | | |
| 37 | + | |
29 | 38 | | |
30 | 39 | | |
31 | 40 | | |
| |||
102 | 111 | | |
103 | 112 | | |
104 | 113 | | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
105 | 151 | | |
106 | 152 | | |
107 | 153 | | |
| |||
0 commit comments