Commit 3a0cd9d
Pass to wrap StableHLO ops in composite (#2722)
Wraps StableHLO operations in `stablehlo.composite` operations. For instance, consider a simple StableHLO program:

```mlir
func.func @main(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> tensor<2xf32> {
  %0 = stablehlo.add %arg0, %arg1 : tensor<2xf32>
  return %0 : tensor<2xf32>
}
```

Applying this pass to wrap `stablehlo.add` operations results in the following program:

```mlir
func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
  %0 = stablehlo.composite "stablehlo.add" %arg0, %arg1 {decomposition = @stablehlo.add.impl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}
func.func private @stablehlo.add.impl(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
  %0 = stablehlo.add %arg0, %arg1 : tensor<2xf32>
  return %0 : tensor<2xf32>
}
```

Notes:

- The `name` attribute of the generated `stablehlo.composite` operation is always the same as the name of the original operation that was wrapped (e.g., wrapping a `stablehlo.add` operation produces a composite named `"stablehlo.add"`).
- The private function that encapsulates the original operation (referenced by the `decomposition` attribute of the `stablehlo.composite` operation) is named using the pattern `<op_name>.impl[.N]`, where `<op_name>` is the name of the original operation and `N` is a unique integer identifier generated to prevent naming conflicts within the module (see the sketch after this message for one plausible uniquifying scheme).

This pass can be used in three distinct ways:

**Mode 1: Command-line Usage**

This mode is the simplest, using the `stablehlo-opt` utility with the `op-names` (a comma-separated list of operation names) and `version` (an integer version number) options. It wraps **all instances** of the specified operations. The attributes of the newly created `stablehlo.composite` operation are the same as the attributes of the original operation.

**Usage Example:**

```bash
stablehlo-opt input.mlir --stablehlo-wrap-in-composite=op-names='stablehlo.add,stablehlo.mul' -o output.mlir
```

**Mode 2: Programmatic Single-Op Wrapping**

This mode provides programmatic control to wrap **a specific operation instance** and returns a pointer to the newly created `stablehlo.composite` operation.

**Example (C++):**

```cpp
// To wrap a specific stablehlo.add instance.
mlir::stablehlo::AddOp addOp = ...;  // The op instance to be wrapped.
mlir::ModuleOp module = addOp->getParentOfType<mlir::ModuleOp>();
mlir::OpBuilder builder(addOp);
mlir::NamedAttrList attrs = ...;  // Attributes to be set on the composite op.
int32_t version = 0;  // Composite version.

mlir::stablehlo::CompositeOp compositeOp =
    mlir::stablehlo::wrapOperationInComposite(builder, addOp, attrs, version,
                                              module);
addOp.replaceAllUsesWith(compositeOp);
```

**Mode 3: Programmatic Module-Wide Wrapping with Attribute Providers**

This mode extends programmatic wrapping to the entire module, offering fine-grained control over which operations are wrapped and which attributes they receive. This is achieved by using the `createStablehloWrapInCompositePass` API, which takes a `CompositeAttributeProviderMap` as an argument.

The `CompositeAttributeProviderMap` dictates which operations are considered for wrapping and how their attributes are handled. Its semantics are as follows:

- **Keys (`mlir::TypeID`):** the `TypeID` of an MLIR operation. If an operation's `TypeID` matches a key in the map, it becomes a candidate for wrapping.
- **Values (lambda functions):** a function of type `std::function<std::optional<NamedAttrList>(Operation*)>`, applied to each candidate operation.
  - **Input:** an `mlir::Operation*`, which is an instance of the operation type corresponding to the `TypeID` key.
  - **Return value:** an `std::optional<NamedAttrList>`.
    - If the lambda returns a `NamedAttrList` (wrapped in `std::optional`), the operation is wrapped in a `stablehlo.composite` operation, and the returned attributes are used to set the composite's attributes.
    - If the lambda returns `std::nullopt`, the operation is **not** wrapped. This allows for selective wrapping based on custom criteria.

**Example (C++):**

```cpp
// ... inside a pass or function ...
stablehlo::CompositeAttributeProviderMap compositeAttributeProviderMap;
compositeAttributeProviderMap[mlir::TypeID::get<mlir::stablehlo::AddOp>()] =
    [](mlir::Operation* op) -> std::optional<mlir::NamedAttrList> {
  // Custom logic to determine if and how to wrap the operation.
  // Example: only wrap if the operand element type is f32.
  if (mlir::getElementTypeOrSelf(op->getOperand(0)).isF32()) {
    return mlir::NamedAttrList(op->getAttrs());
  }
  return std::nullopt;  // Do not wrap.
};

pm.addPass(createStablehloWrapInCompositePass(compositeAttributeProviderMap,
                                              compositeVersion));
if (mlir::failed(pm.run(module))) {
  return;
}
```
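The naming note above only specifies the `<op_name>.impl[.N]` pattern, not how `N` is chosen. Below is a minimal sketch of one plausible uniquifying scheme, using a hypothetical helper `uniqueImplName` (not part of this commit) and only the module's symbol table:

```cpp
#include <string>

#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/SymbolTable.h"

// Hypothetical helper (not in this commit): choose a decomposition symbol
// of the form <op_name>.impl, appending ".N" for the smallest N that does
// not clash with a symbol already present in the module.
std::string uniqueImplName(mlir::ModuleOp module, llvm::StringRef opName) {
  mlir::SymbolTable symbolTable(module);
  std::string name = (opName + ".impl").str();
  for (int n = 0; symbolTable.lookup(name) != nullptr; ++n)
    name = (opName + ".impl." + llvm::Twine(n)).str();
  return name;
}
```

Whatever scheme the pass actually uses, the only stated guarantee is that the chosen name is unique within the module.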
1 parent 350021b

File tree: 7 files changed, +603 −8 lines changed

BUILD.bazel

Lines changed: 1 addition & 0 deletions

```diff
@@ -1246,6 +1246,7 @@ cc_library(
         "stablehlo/transforms/StablehloLegalizeToVhlo.cpp",
         "stablehlo/transforms/StablehloRefineArguments.cpp",
         "stablehlo/transforms/StablehloRefineShapes.cpp",
+        "stablehlo/transforms/StablehloWrapInComposite.cpp",
         "stablehlo/transforms/VhloLegalizeToStablehlo.cpp",
         "stablehlo/transforms/VhloToVersion.cpp",
     ],
```

docs/generated/stablehlo_passes.md

Lines changed: 117 additions & 2 deletions

```diff
@@ -301,8 +301,8 @@ _Refines shapes across a StableHLO program._
 The flagship use case for this pass is specializing dynamically-shaped
 programs to static shapes. If a dynamically-shaped StableHLO program has the
 right structure, then updating its argument types from dynamic shapes to
-static shapes and running this pass will propagate static shapes across
-the program.
+static shapes and running this pass will propagate static shapes across the
+program.
 
 This pass removes `custom_call @shape_refinement_operand_wrapper` by
 replacing uses of the result with the operand directly, and propagates
@@ -338,6 +338,121 @@ Modules valid for shape refinement must have the following properties:
 * All calls to a single function resolve to the same argument shapes, and no
   recursive / co-recursive function calls are made.
 
+### `-stablehlo-wrap-in-composite`
+
+_Wraps a non-composite StableHLO op in a composite op._
+
+Wraps StableHLO operations in `stablehlo.composite` operations.
+
+For instance, consider a simple StableHLO program:
+
+```mlir
+func.func @main(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> tensor<2xf32> {
+  %0 = stablehlo.add %arg0, %arg1 : tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+```
+
+Applying this pass to wrap `stablehlo.add` operations will result in the
+following program:
+
+```mlir
+func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = stablehlo.composite "stablehlo.add" %arg0, %arg1 {decomposition = @stablehlo.add.impl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+func.func private @stablehlo.add.impl(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = stablehlo.add %arg0, %arg1 : tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+```
+
+Notes:
+
+- The `name` attribute of the generated `stablehlo.composite` operation
+  will always be the same as the name of the original operation that was
+  wrapped (e.g., if you wrap a `stablehlo.add` operation, the composite
+  will also be named `"stablehlo.add"`).
+- The private function that encapsulates the original operation
+  (referenced by the `decomposition` attribute of the
+  `stablehlo.composite` operation) will be named using the pattern
+  `<op_name>.impl[.N]`, where `<op_name>` is the name of the original
+  operation, and `N` is a unique integer identifier generated to prevent
+  naming conflicts within the module.
+
+This pass can be used in two distinct ways:
+
+**Mode 1: Command-line Usage**
+
+This mode is intended for debugging or testing, as it offers minimal control
+over the attributes of the generated `stablehlo.composite` operations.
+It wraps **all instances** of operations specified using the `op-names`
+(a comma-separated list of operation names) option. The attributes of the
+newly created `stablehlo.composite` operation will be the same as the
+attributes of the original operation.
+
+**Usage Example:**
+
+```bash
+stablehlo-opt input.mlir --stablehlo-wrap-in-composite=op-names='stablehlo.add,stablehlo.mul' -o output.mlir
+```
+
+**Mode 2: Programmatic Module-Wide Wrapping with Customized Attribute Handling**
+
+This mode extends programmatic wrapping to the entire module, offering
+fine-grained control over which operations are wrapped and their attributes.
+This is achieved by using the `createStablehloWrapInCompositePass` API,
+which takes a `CompositeAttributeProviderMap` as an argument.
+
+The `CompositeAttributeProviderMap` is a map that dictates which operations
+should be considered for wrapping and how their attributes should be
+handled. Its semantics are as follows:
+
+- **Keys (mlir::TypeID):** `TypeID` of an MLIR operation. If an operation's
+  `TypeID` matches a key in the map, it becomes a candidate for wrapping.
+- **Values (Lambda Functions):** Lambda function of type
+  `std::function<std::optional<NamedAttrList>(Operation*)>`. This function
+  is applied to each candidate operation.
+  - **Input:** An `mlir::Operation*`, which is an instance of the
+    operation type corresponding to the `TypeID` key.
+  - **Return Value:** An `std::optional<NamedAttrList>`.
+    - If the lambda returns a `NamedAttrList` (wrapped in
+      `std::optional`), the operation is wrapped in a
+      `stablehlo.composite` operation, and the returned attributes are
+      used to set the composite's attributes.
+    - If the lambda returns `std::nullopt`, the operation is **not**
+      wrapped. This allows for selective wrapping based on custom
+      criteria.
+
+**Example (C++):**
+
+```cpp
+
+stablehlo::CompositeAttributeProviderMap compositeAttributeProviderMap;
+
+compositeAttributeProviderMap[mlir::TypeID::get<mlir::stablehlo::AddOp>()] =
+    [](mlir::Operation* op) -> std::optional<mlir::NamedAttrList> {
+  // Custom logic to determine if and how to wrap the operation.
+  // Example: only wrap if the operand element type is f32.
+  if (mlir::getElementTypeOrSelf(op->getOperand(0)).isF32()) {
+    return mlir::NamedAttrList(op->getAttrs());
+  }
+  return std::nullopt; // Do not wrap.
+};
+
+pm.addPass(createStablehloWrapInCompositePass(compositeAttributeProviderMap, compositeVersion));
+if (mlir::failed(pm.run(module))) {
+  return;
+}
+```
+
+#### Options
+
+```
+-op-names : The names of the ops to wrap.
+-version  : The version number of the composite op.
+```
+
 ### `-vhlo-legalize-to-stablehlo`
 
 _Legalize VHLO to StableHLO._
```
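For comparison with the documented modes, here is a minimal C++ sketch of driving the pass programmatically so it behaves like the Mode 1 command line above, wrapping every `stablehlo.add` and `stablehlo.multiply` and copying the original attributes. The `wrapAddAndMul` helper is hypothetical; only `CompositeAttributeProviderMap` and `createStablehloWrapInCompositePass` come from this commit's `Passes.h`:

```cpp
#include <optional>

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/transforms/Passes.h"

// Sketch: the programmatic counterpart of
//   --stablehlo-wrap-in-composite='op-names=stablehlo.add,stablehlo.mul version=1'
mlir::LogicalResult wrapAddAndMul(mlir::ModuleOp module) {
  // Wrap unconditionally and copy the original op's attributes onto the
  // composite, mirroring the command-line mode.
  auto copyAttrs =
      [](mlir::Operation *op) -> std::optional<mlir::NamedAttrList> {
    return mlir::NamedAttrList(op->getAttrs());
  };
  mlir::stablehlo::CompositeAttributeProviderMap providers;
  providers[mlir::TypeID::get<mlir::stablehlo::AddOp>()] = copyAttrs;
  providers[mlir::TypeID::get<mlir::stablehlo::MulOp>()] = copyAttrs;

  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::stablehlo::createStablehloWrapInCompositePass(
      providers, /*compositeVersion=*/1));
  return pm.run(module);
}
```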
Lines changed: 88 additions & 0 deletions

```diff
@@ -0,0 +1,88 @@
+// RUN: stablehlo-opt --stablehlo-wrap-in-composite='op-names=stablehlo.add,stablehlo.convolution,stablehlo.reduce version=1' --split-input-file --verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: func.func @wrap_in_composite
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x8x8x8xi8>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<4x4x8x32xi8>,
+// CHECK-SAME: %[[ARG_2:.*]]: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
+// CHECK: %[[CONV:.*]] = stablehlo.composite "stablehlo.convolution" %[[ARG_0]], %[[ARG_1]] {
+// CHECK-SAME: composite_attributes = {batch_group_count = 1 : i64,
+// CHECK-SAME: dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
+// CHECK-SAME: feature_group_count = 1 : i64,
+// CHECK-SAME{LITERAL}: padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
+// CHECK-SAME{LITERAL}: rhs_dilation = array<i64: 2, 2>,
+// CHECK-SAME{LITERAL}: window_strides = array<i64: 1, 1>},
+// CHECK-SAME: decomposition = @stablehlo.convolution.impl,
+// CHECK-SAME: version = 1 : i32} : (tensor<64x8x8x8xi8>, tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32>
+// CHECK: %[[ADD:.*]] = stablehlo.composite "stablehlo.add" %[[CONV]], %[[ARG_2]] {
+// CHECK-SAME: decomposition = @stablehlo.add.impl,
+// CHECK-SAME: version = 1 : i32} : (tensor<64x3x3x32xi32>, tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32>
+// CHECK-NEXT: return %[[ADD]]
+
+// CHECK-LABEL: func.func private @stablehlo.add.impl
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x3x3x32xi32>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
+// CHECK: %[[VAL:.*]] = stablehlo.add %[[ARG_0]], %[[ARG_1]] : tensor<64x3x3x32xi32>
+// CHECK-NEXT: return %[[VAL]]
+
+// CHECK-LABEL: func.func private @stablehlo.convolution.impl
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x8x8x8xi8>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32> {
+// CHECK: %[[VAL:.*]] = stablehlo.convolution(%[[ARG_0]], %[[ARG_1]])
+// CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+// CHECK-SAME{LITERAL}: stride = [1, 1],
+// CHECK-SAME{LITERAL}: pad = [[0, 1], [0, 1]],
+// CHECK-SAME{LITERAL}: rhs_dilate = [2, 2]}
+// CHECK-SAME: batch_group_count = 1 : i64
+// CHECK-SAME: feature_group_count = 1 : i64
+// CHECK-SAME: : (tensor<64x8x8x8xi8>, tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32>
+// CHECK-NEXT: return %[[VAL]]
+
+func.func @wrap_in_composite(
+    %arg0: tensor<64x8x8x8xi8>,
+    %arg1: tensor<4x4x8x32xi8>,
+    %arg2: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
+  %0 = stablehlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [2, 2]}
+         {batch_group_count = 1 : i64, feature_group_count = 1 : i64} :
+         (tensor<64x8x8x8xi8>, tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32>
+  %1 = stablehlo.add %0, %arg2 : tensor<64x3x3x32xi32>
+  func.return %1 : tensor<64x3x3x32xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @wrap_in_composite_op_with_region
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x3xf32>) -> tensor<4xf32>
+// CHECK: %[[CONST:.*]] = stablehlo.constant
+// CHECK-NEXT: %[[COMPOSITE_REDUCE:.*]] = stablehlo.composite "stablehlo.reduce" %[[ARG_0]], %[[CONST]] {
+// CHECK-SAME: composite_attributes = {
+// CHECK-SAME: dimensions = array<i64: 1>},
+// CHECK-SAME: decomposition = @stablehlo.reduce.impl,
+// CHECK-SAME: version = 1 : i32} : (tensor<4x3xf32>, tensor<f32>) -> tensor<4xf32>
+// CHECK-NEXT: return %[[COMPOSITE_REDUCE]]
+
+// CHECK-LABEL: func.func private @stablehlo.reduce.impl
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x3xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<4xf32> {
+// CHECK: %[[REDUCE:.*]] = stablehlo.reduce(%[[ARG_0]] init: %[[ARG_1]])
+// CHECK-SAME{LITERAL}: applies stablehlo.add across dimensions = [1]
+// CHECK-SAME: (tensor<4x3xf32>, tensor<f32>) -> tensor<4xf32>
+// CHECK-NEXT: return %[[REDUCE]]
+func.func @wrap_in_composite_op_with_region(%x : tensor<4x3xf32>) -> tensor<4xf32> {
+  %cst = stablehlo.constant dense<2.7> : tensor<f32>
+  %res = stablehlo.reduce(%x init: %cst) applies stablehlo.add across dimensions = [1] : (tensor<4x3xf32>, tensor<f32>) -> tensor<4xf32>
+  func.return %res : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @cannot_be_wrapped_ops_does_not_match
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<2xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: %[[VAL:.*]] = stablehlo.multiply %[[ARG_0]], %[[ARG_1]] : tensor<2xf32>
+// CHECK-NEXT: return %[[VAL]] : tensor<2xf32>
+func.func @cannot_be_wrapped_ops_does_not_match(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = stablehlo.multiply %arg0, %arg1 : tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
```

stablehlo/transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -57,6 +57,7 @@ add_mlir_dialect_library(StablehloPasses
   StablehloLegalizeToVhlo.cpp
   StablehloRefineArguments.cpp
   StablehloRefineShapes.cpp
+  StablehloWrapInComposite.cpp
   VhloLegalizeToStablehlo.cpp
   VhloToVersion.cpp
   PassUtils.cpp
```

stablehlo/transforms/Passes.h

Lines changed: 50 additions & 4 deletions

```diff
@@ -16,16 +16,25 @@ limitations under the License.
 #ifndef STABLEHLO_TRANSFORMS_PASSES_H
 #define STABLEHLO_TRANSFORMS_PASSES_H
 
+#include <cstdint>
+#include <functional>
 #include <memory>
+#include <optional>
 
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Quant/IR/Quant.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // IWYU pragma: keep
+#include "mlir/Dialect/Quant/IR/Quant.h"  // IWYU pragma: keep
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
+#include "mlir/Pass/PassOptions.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "stablehlo/dialect/StablehloOps.h"
 #include "stablehlo/dialect/Version.h"
 
 namespace mlir {
@@ -102,6 +111,43 @@ void populateStablehloCompatibilityExpanderPatterns(
 std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
     TypeRange refinedTypes);
 
+/// Creates a pass that wraps StableHLO ops in CompositeOp.
+/// The pass takes in a map from an op's type id to a function that returns
+/// the attributes to be added to the CompositeOp. The pass also takes in a
+/// version number for the CompositeOp.
+using CompositeAttributeProvider =
+    std::function<std::optional<NamedAttrList>(Operation *)>;
+using CompositeAttributeProviderMap =
+    llvm::DenseMap<mlir::TypeID, CompositeAttributeProvider>;
+std::unique_ptr<OperationPass<ModuleOp>> createStablehloWrapInCompositePass(
+    const CompositeAttributeProviderMap &compositeAttributeProviderMap,
+    int32_t compositeVersion);
+
+/// Wraps the given operation in a CompositeOp with the specified NamedAttrs
+/// and version and returns the CompositeOp.
+///
+/// **A typical usage:**
+///
+/// ```cpp
+/// // To wrap a specific stablehlo.add instance.
+///
+/// mlir::stablehlo::AddOp addOp = ...;  // The op instance to be wrapped.
+/// mlir::ModuleOp module = addOp->getParentOfType<mlir::ModuleOp>();
+/// mlir::OpBuilder builder(addOp);
+/// mlir::NamedAttrList attrs = ...;  // Attributes to be set on the
+///                                   // composite op.
+/// int32_t version = 0;  // Composite version.
+///
+/// mlir::stablehlo::CompositeOp compositeOp =
+///     mlir::stablehlo::wrapOperationInComposite(builder, addOp, attrs,
+///                                               version, module);
+/// addOp.replaceAllUsesWith(compositeOp);
+/// ```
+stablehlo::CompositeOp wrapOperationInComposite(OpBuilder &builder,
+                                                Operation *op,
+                                                const NamedAttrList &attrs,
+                                                int32_t compositeVersion,
+                                                ModuleOp module);
 //// Pass pipelines ////
 
 // StableHLO consumers can add this pipeline to convert portable artifacts to
```
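As a complement to the header comment above, here is a minimal sketch of applying `wrapOperationInComposite` across a whole module by hand. The `wrapAllAdds` helper is hypothetical (not in this commit), and the final `erase()` assumes the original op is dead once its uses are redirected:

```cpp
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/transforms/Passes.h"

// Hypothetical helper (not in this commit): wrap every stablehlo.add in
// `module` through the single-op API declared above. Ops are collected
// first so the walk never iterates over IR that is being rewritten.
void wrapAllAdds(mlir::ModuleOp module) {
  llvm::SmallVector<mlir::stablehlo::AddOp> addOps;
  module.walk([&](mlir::stablehlo::AddOp op) { addOps.push_back(op); });
  for (mlir::stablehlo::AddOp addOp : addOps) {
    mlir::OpBuilder builder(addOp);
    mlir::stablehlo::CompositeOp composite =
        mlir::stablehlo::wrapOperationInComposite(
            builder, addOp, mlir::NamedAttrList(addOp->getAttrs()),
            /*compositeVersion=*/1, module);
    addOp.replaceAllUsesWith(composite);
    addOp.erase();  // Assumed safe: all uses now point at the composite.
  }
}
```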
