Commit 0b5b4a2

Add StableHLO shape assertion check pass and related tests (#2869)
This PR adds the CheckShapeAssertion pass from XLA into StableHLO: https://github.com/openxla/xla/blob/949bb09eb7d0bb563f5b564fd7d0782e0687796e/xla/python/refine_polymorphic_shapes.cc#L72-L98

This idea was discussed in the Discord channel: https://discord.com/channels/999073994483433573/999074539138990131/1409946189721374803

This allows taking programs exported from JAX and using just the stablehlo binary to produce a program that can be compiled and executed with PJRT plugins.

----

As a usage example, export a program such as:

```python
import jax
from jax import export
import jax.numpy as jnp
import numpy as np

def f1(x, y):  # x: f32[a, 1], y: f32[a, 4]
  return x + y

# Assuming you have some actual args with concrete shapes
x = np.ones((3, 1), dtype=np.int32)
y = np.ones((3, 4), dtype=np.int32)
args_specs = export.symbolic_args_specs((x, y), 'a, ...')
exp = export.export(jax.jit(f1))(*args_specs)
code = exp.mlir_module()
print(code)
```

This gives us StableHLO like:

```mlir
#loc = loc(unknown)
#loc2 = loc("x")
#loc3 = loc("y")
module @jit_f1 attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<?x1xi32> loc(unknown), %arg1: tensor<?x4xi32> loc(unknown)) -> (tensor<?x4xi32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<1> : tensor<i32> loc(#loc)
    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x1xi32>) -> tensor<i32> loc(#loc7)
    %1 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?x4xi32>) -> tensor<i32> loc(#loc7)
    %2 = stablehlo.compare GE, %0, %c, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc8)
    stablehlo.custom_call @shape_assertion(%2, %0) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'a'. Using the following polymorphic shapes specifications: args[0].shape = (a, 1),args[1].shape = (a, 4). Obtained dimension variables: 'a' = {0} from specification 'a' for dimension args[0].shape[0] (= {0}), . Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i32>) -> () loc(#loc9)
    %3 = stablehlo.compare EQ, %1, %0, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc10)
    stablehlo.custom_call @shape_assertion(%3, %1, %0) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Found inconsistency between dimension size args[1].shape[0] (= {0}) and the specification 'a' (= {1}). Using the following polymorphic shapes specifications: args[0].shape = (a, 1),args[1].shape = (a, 4). Obtained dimension variables: 'a' = {1} from specification 'a' for dimension args[0].shape[0] (= {1}), . Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i32>, tensor<i32>) -> () loc(#loc9)
    %4 = call @_wrapped_jax_export_main(%0, %arg0, %arg1) : (tensor<i32>, tensor<?x1xi32>, tensor<?x4xi32>) -> tensor<?x4xi32> loc(#loc)
    return %4 : tensor<?x4xi32> loc(#loc)
  } loc(#loc)
  func.func private @_wrapped_jax_export_main(%arg0: tensor<i32> {jax.global_constant = "a"} loc(unknown), %arg1: tensor<?x1xi32> loc("x"), %arg2: tensor<?x4xi32> loc("y")) -> (tensor<?x4xi32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<4> : tensor<1xi32> loc(#loc12)
    %0 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32> loc(#loc12)
    %1 = stablehlo.concatenate %0, %c, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc12)
    %2 = stablehlo.dynamic_broadcast_in_dim %arg1, %1, dims = [0, 1] : (tensor<?x1xi32>, tensor<2xi32>) -> tensor<?x4xi32> loc(#loc12)
    %3 = stablehlo.add %2, %arg2 : tensor<?x4xi32> loc(#loc12)
    return %3 : tensor<?x4xi32> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc1 = loc("<string>":13:6 to :46)
#loc4 = loc("<string>":7:8 to :13)
#loc5 = loc("<module>"(#loc1))
#loc6 = loc("f1"(#loc4))
#loc7 = loc("dimension_size"(#loc5))
#loc8 = loc("ge"(#loc5))
#loc9 = loc("shape_assertion"(#loc5))
#loc10 = loc("eq"(#loc5))
#loc11 = loc(callsite(#loc6 at #loc5))
#loc12 = loc("jit(f1)/add"(#loc11))
```

Before executing with PJRT, one needs to refine the shapes and remove the shape assertions, e.g. with:

```shell
stablehlo-opt hello.mlir \
  -stablehlo-refine-arguments="types='tensor<10x1xi32>, tensor<10x4xi32>'" \
  -stablehlo-refine-shapes \
  --stablehlo-check-shape-assertions
```

which gives us:

```mlir
module @jit_f1 attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<10x1xi32>, %arg1: tensor<10x4xi32>) -> (tensor<10x4xi32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<10> : tensor<i32>
    %c_0 = stablehlo.constant dense<true> : tensor<i1>
    %0 = call @_wrapped_jax_export_main(%arg0, %arg1) : (tensor<10x1xi32>, tensor<10x4xi32>) -> tensor<10x4xi32>
    return %0 : tensor<10x4xi32>
  }
  func.func private @_wrapped_jax_export_main(%arg0: tensor<10x1xi32>, %arg1: tensor<10x4xi32>) -> (tensor<10x4xi32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<[10, 4]> : tensor<2xi32>
    %0 = stablehlo.dynamic_broadcast_in_dim %arg0, %c, dims = [0, 1] : (tensor<10x1xi32>, tensor<2xi32>) -> tensor<10x4xi32>
    %1 = stablehlo.add %0, %arg1 : tensor<10x4xi32>
    return %1 : tensor<10x4xi32>
  }
}
```
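For the failure path, refining with argument shapes that violate the constraint that both leading dimensions equal `a` should instead trip the second assertion. A hedged sketch (the diagnostic comes from formatting the `error_message` attribute shown above, with `{0}` and `{1}` filled in by the pass):

```shell
stablehlo-opt hello.mlir \
  -stablehlo-refine-arguments="types='tensor<10x1xi32>, tensor<7x4xi32>'" \
  -stablehlo-refine-shapes \
  --stablehlo-check-shape-assertions
# Expected to fail with a diagnostic along the lines of:
#   error: Input shapes do not match the polymorphic shapes specification.
#   Found inconsistency between dimension size args[1].shape[0] (= 7) and
#   the specification 'a' (= 10). ...
```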
1 parent 4c0d484 commit 0b5b4a2

File tree

7 files changed: +308 -0 lines changed


BUILD.bazel

Lines changed: 1 addition & 0 deletions

```diff
@@ -1135,6 +1135,7 @@ cc_library(
         "stablehlo/transforms/PassPipelines.cpp",
         "stablehlo/transforms/ShapeLegalizeToStablehlo.cpp",
         "stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
+        "stablehlo/transforms/StablehloCheckShapeAssertions.cpp",
         "stablehlo/transforms/StablehloCompatibilityExpander.cpp",
         "stablehlo/transforms/StablehloComplexMathExpander.cpp",
         "stablehlo/transforms/StablehloConvertToSignless.cpp",
```

docs/dynamism.md

Lines changed: 3 additions & 0 deletions

```diff
@@ -105,13 +105,16 @@ Individually, the passes that tend to be useful for shape refinement are:
   shape information throughout the entire program.
 - [`stablehlo-canonicalize-dynamism`][canonicalize-dynamism] to replace dynamic
   ops with their static variants.
+- [`stablehlo-check-shape-assertions`][check-shape-assertions] to check and
+  remove shape assertion custom calls.
 
 See linked documentation for up-to-date information and examples.
 
 [remove-dynamism]:https://github.com/openxla/stablehlo/blob/ff13c96e56b73c62dcbb5b34b69f5ece9e71322f/stablehlo/transforms/Passes.h#L134
 [canonicalize-dynamism]:https://openxla.org/stablehlo/generated/stablehlo_passes#-stablehlo-canonicalize-dynamism
 [refine-arguments]:https://openxla.org/stablehlo/generated/stablehlo_passes#-stablehlo-refine-arguments
 [refine-shapes]:https://openxla.org/stablehlo/generated/stablehlo_passes#-stablehlo-refine-shapes
+[check-shape-assertions]:https://openxla.org/stablehlo/generated/stablehlo_passes#-stablehlo-check-shape-assertions
 
 
 ## Example: How is dynamism useful, and how can I use it?
```
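For reference, a sketch composing the passes listed above in a single invocation (argument types illustrative, mirroring the usage example in the commit message):

```shell
stablehlo-opt program.mlir \
  -stablehlo-refine-arguments="types='tensor<10x1xi32>, tensor<10x4xi32>'" \
  -stablehlo-refine-shapes \
  -stablehlo-canonicalize-dynamism \
  -stablehlo-check-shape-assertions
```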

docs/generated/stablehlo_passes.md

Lines changed: 26 additions & 0 deletions

````diff
@@ -32,6 +32,32 @@ these ops are actually constants.
 %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
 ```
 
+### `-stablehlo-check-shape-assertions`
+
+_Check stablehlo.custom_call @shape_assertion ops._
+
+Validates shape_assertion custom calls.
+
+Shape assertions validate constraints on dynamic dimensions in StableHLO.
+For example, if a framework needed to enforce a constraint of `DimA < 2`,
+the following IR could be emitted:
+
+```mlir
+%dimA = <get_dimension_size or input arg> : tensor<i32>
+%c2 = stablehlo.constant dense<2> : tensor<i32>
+%is_lt = stablehlo.compare LT %dimA, %c2 : tensor<i1>
+stablehlo.custom_call @shape_assertion(%is_lt) { error_message = "DimA must be less than 2" }
+```
+
+After the pass runs, if the assertions hold, the `stablehlo.custom_call`
+ops are removed.
+
+#### Options
+
+```
+-enable-shape-assertions : Whether shape assertions may generate errors.
+```
+
 ### `-stablehlo-compatibility-expander`
 
 _Compatibility expander for StableHLO operations._
````
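Per the Options block above, assertion checking is on by default. A hedged sketch of both modes (file name illustrative, flag syntax following standard mlir-opt pass-option handling):

```shell
# Default: verify assertions; a failing assertion becomes a pass error.
stablehlo-opt --stablehlo-check-shape-assertions module.mlir

# Disabled: shape_assertion custom calls are removed without being checked.
stablehlo-opt --stablehlo-check-shape-assertions=enable-shape-assertions=false module.mlir
```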
New test file

Lines changed: 44 additions & 0 deletions

```mlir
// RUN: stablehlo-opt --stablehlo-check-shape-assertions --split-input-file --verify-diagnostics %s | FileCheck %s --check-prefixes=CHECK

// CHECK-LABEL: func.func @assertion_succeeds
// CHECK-NOT: stablehlo.custom_call @shape_assertion
// CHECK: return
func.func @assertion_succeeds() {
  %c1 = stablehlo.constant dense<true> : tensor<i1>
  %c0 = stablehlo.constant dense<0> : tensor<i32>
  stablehlo.custom_call @shape_assertion(%c1, %c0) {
    api_version = 2 : i32,
    error_message = "should not fire",
    has_side_effect = true
  } : (tensor<i1>, tensor<i32>) -> ()
  return
}

// -----

// ERR-LABEL: func.func @assertion_fails
func.func @assertion_fails() {
  %c1 = stablehlo.constant dense<false> : tensor<i1>
  %c0 = stablehlo.constant dense<7> : tensor<i32>
  // expected-error@+1 {{should fire}}
  stablehlo.custom_call @shape_assertion(%c1, %c0) {
    api_version = 2 : i32,
    error_message = "should fire",
    has_side_effect = true
  } : (tensor<i1>, tensor<i32>) -> ()
  return
}

// -----

// ERR-LABEL: func.func @assertion_fails_not_constant
func.func @assertion_fails_not_constant(%arg0 : tensor<i1>) {
  %c0 = stablehlo.constant dense<7> : tensor<i32>
  // expected-error@+1 {{expects static assert_what (operand #0)}}
  stablehlo.custom_call @shape_assertion(%arg0, %c0) {
    api_version = 2 : i32,
    error_message = "not firing",
    has_side_effect = true
  } : (tensor<i1>, tensor<i32>) -> ()
  return
}
```
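One path the tests above do not exercise is error-message formatting. A hedged sketch of an additional case (hypothetical, not part of this commit) where the failing assertion's message uses a `{0}` specifier filled from the second operand:

```mlir
// Hypothetical extra case: the {0} specifier in error_message is replaced
// with the value of the first error_message_input (here 7) via llvm::formatv.
func.func @assertion_fails_formatted() {
  %false = stablehlo.constant dense<false> : tensor<i1>
  %c7 = stablehlo.constant dense<7> : tensor<i32>
  // expected-error@+1 {{observed dimension was 7}}
  stablehlo.custom_call @shape_assertion(%false, %c7) {
    api_version = 2 : i32,
    error_message = "observed dimension was {0}",
    has_side_effect = true
  } : (tensor<i1>, tensor<i32>) -> ()
  return
}
```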

stablehlo/transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -61,6 +61,7 @@ add_mlir_dialect_library(StablehloPasses
   PassPipelines.cpp
   ShapeLegalizeToStablehlo.cpp
   StablehloCanonicalizeDynamism.cpp
+  StablehloCheckShapeAssertions.cpp
   StablehloConvertToSignless.cpp
   StablehloCompatibilityExpander.cpp
   StablehloComplexMathExpander.cpp
```

stablehlo/transforms/Passes.td

Lines changed: 29 additions & 0 deletions

````diff
@@ -55,6 +55,35 @@ def StablehloCanonicalizeDynamismPass : Pass<"stablehlo-canonicalize-dynamism",
   }];
 }
 
+def StablehloCheckShapeAssertionsPass
+    : Pass<"stablehlo-check-shape-assertions", "func::FuncOp"> {
+  let summary = "Check stablehlo.custom_call @shape_assertion ops.";
+
+  let description = [{
+    Validates shape_assertion custom calls.
+
+    Shape assertions validate constraints on dynamic dimensions in StableHLO.
+    For example, if a framework needed to enforce a constraint of `DimA < 2`,
+    the following IR could be emitted:
+
+    ```mlir
+    %dimA = <get_dimension_size or input arg> : tensor<i32>
+    %c2 = stablehlo.constant dense<2> : tensor<i32>
+    %is_lt = stablehlo.compare LT %dimA, %c2 : tensor<i1>
+    stablehlo.custom_call @shape_assertion(%is_lt) { error_message = "DimA must be less than 2" }
+    ```
+
+    After the pass runs, if the assertions hold, the `stablehlo.custom_call`
+    ops are removed.
+  }];
+
+  let options = [
+    Option<"enable_shape_assertions", "enable-shape-assertions", "bool",
+           /*default=*/"true",
+           "Whether shape assertions may generate errors.">
+  ];
+}
+
 def StablehloCompatibilityExpanderPass : Pass<"stablehlo-compatibility-expander", "mlir::ModuleOp"> {
   let summary = "Compatibility expander for StableHLO operations.";
````
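Since the pass is declared on `func::FuncOp`, it runs per function. A minimal sketch of wiring it into a pipeline, assuming the conventional tablegen-generated factory name `createStablehloCheckShapeAssertionsPass` (an assumption, not shown in this diff):

```cpp
// Sketch only: the factory name below is assumed from MLIR tablegen
// conventions for passes declared in Passes.td.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "stablehlo/transforms/Passes.h"

void addShapeAssertionChecking(mlir::PassManager &pm) {
  // The pass is anchored on func::FuncOp, so nest it under the module.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::stablehlo::createStablehloCheckShapeAssertionsPass());
}
```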

stablehlo/transforms/StablehloCheckShapeAssertions.cpp

Lines changed: 204 additions & 0 deletions

```c++
/* Copyright 2023 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/dialect/TypeInference.h"
#include "stablehlo/transforms/Passes.h"

namespace mlir {
namespace stablehlo {

#define GEN_PASS_DEF_STABLEHLOCHECKSHAPEASSERTIONSPASS
#include "stablehlo/transforms/Passes.h.inc"

namespace {

constexpr llvm::StringRef shapeAssertionName = "shape_assertion";
constexpr llvm::StringRef errorMessageAttrName = "error_message";
// We bound the number of error_message_inputs for use with llvm::formatv.
constexpr int maxErrorMessageInputs = 32;  // TODO(necula): Remove this bound

// This pass is needed when we have shape assertions. A shape assertion is
// represented via the `stablehlo.custom_call @shape_assertion`
// custom call, and represents an assertion that the first operand
// (`assert_what`) evaluates to `true`. The custom call also has an
// `error_message` string attribute, and a variadic number
// of integer scalar operands that may be used to format the error message.
// The `error_message` may contain format specifiers `{0}`, `{1}`, ..., that
// are replaced with the values of the error message inputs. The formatting is
// done with the `llvm::formatv` function
// (https://llvm.org/docs/ProgrammersManual.html#formatting-strings-the-formatv-function).
struct CheckShapeAssertionsPass
    : public impl::StablehloCheckShapeAssertionsPassBase<
          CheckShapeAssertionsPass> {
  using StablehloCheckShapeAssertionsPassBase::
      StablehloCheckShapeAssertionsPassBase;

  void runOnOperation() override {
    func::FuncOp funcOp = getOperation();
    funcOp.walk([this](CustomCallOp op) {
      if (op.getCallTargetName() != shapeAssertionName) return;
      if (!enable_shape_assertions) {
        op.erase();
        return;
      }
      // Check first for ill-formed assertions, rather than silently fail.
      if (failed(verifyShapeAssertion(op))) {
        signalPassFailure();
        return;
      }
      OperandRange inputs = op.getInputs();
      SmallVector<int64_t> assertWhat;
      if (failed(hlo::matchInts(inputs[0], assertWhat))) {
        op.emitError() << "expects static assert_what (operand #0)";
        signalPassFailure();
        return;
      }
      if (assertWhat[0] != 0) {
        op.erase();
        return;
      }
      StringRef errorMessage = getErrorMessage(op);
      SmallVector<int64_t> errorMessageInputs;
      for (size_t i = 1; i < inputs.size(); ++i) {
        SmallVector<int64_t> input;
        if (failed(hlo::matchInts(inputs[i], input))) {
          op.emitError() << "expects static error_message_input (operand #"
                         << i << ")";
          signalPassFailure();
          return;
        }
        errorMessageInputs.push_back(input[0]);
      }
      op.emitError(formatErrorMessage(errorMessage, errorMessageInputs));
      signalPassFailure();
    });
  }

 private:
  LogicalResult verifyShapeAssertion(CustomCallOp op) {
    if (!(1 <= op->getNumOperands() &&
          op->getNumOperands() <= 1 + maxErrorMessageInputs))
      return op.emitError() << "expects 1 <= size(operands) <= "
                            << (1 + maxErrorMessageInputs);
    int nrErrorMessageInputs = op.getNumOperands() - 1;
    if (op->getNumResults() != 0)
      return op.emitError("expects size(results) = 0");
    for (const auto &attr : op->getAttrs()) {
      if (attr.getName() != "api_version" &&
          attr.getName() != "backend_config" &&
          attr.getName() != "call_target_name" &&
          attr.getName() != "error_message" &&
          attr.getName() != "has_side_effect")
        return op.emitError()
               << attr.getName() << " is not a supported attribute";
    }
    if (!op.hasEmptyBackendConfig())
      return op.emitError() << "expects an empty backend_config";
    if (op.getCallTargetName() != shapeAssertionName)
      return op.emitError() << "expects @shape_assertion";

    // input[0] (assert_what) : tensor<i1>
    auto assertWhatType = dyn_cast<ShapedType>(op.getInputs()[0].getType());
    if (!assertWhatType || !assertWhatType.hasRank() ||
        assertWhatType.getRank() != 0 ||
        !assertWhatType.getElementType().isSignlessInteger() ||
        assertWhatType.getElementTypeBitWidth() != 1)
      return op.emitError() << "expects assert_what (operand #0) "
                            << "to be a constant of type tensor<i1>";

    // input[1:] (error_message_inputs) : tensor<i32> or tensor<i64>
    for (int i = 0; i < nrErrorMessageInputs; ++i) {
      auto errorMessageInputType =
          dyn_cast<ShapedType>(op.getInputs()[i + 1].getType());
      if (!errorMessageInputType || !errorMessageInputType.hasRank() ||
          errorMessageInputType.getRank() != 0 ||
          !errorMessageInputType.getElementType().isSignlessInteger() ||
          (errorMessageInputType.getElementTypeBitWidth() != 32 &&
           errorMessageInputType.getElementTypeBitWidth() != 64))
        return op.emitError()
               << "expects error_message_input (operand #" << (i + 1) << ") "
               << "to be a constant of type tensor<i32> or tensor<i64>";
    }

    if (!op->hasAttr(errorMessageAttrName))
      return op.emitError() << "expects an error_message attribute";

    // error_message contains valid format specifiers.
    StringRef errorMessage = getErrorMessage(op);

    // format specs: "{" index ["," layout] [":" format] "}"
    size_t spec_begin = errorMessage.find_first_of('{');
    size_t spec_end = errorMessage.find_first_of(",:}", spec_begin);

    // Check that all specs reference valid input indices.
    while (spec_begin != StringRef::npos && spec_end != StringRef::npos) {
      StringRef index_str =
          errorMessage.substr(spec_begin + 1, spec_end - spec_begin - 1);

      int32_t index;
      if (!index_str.getAsInteger(10, index) &&
          !(0 <= index && index < nrErrorMessageInputs)) {
        return op.emitError()
               << "expects error_message to contain format specifiers with "
               << "error_message_input index less than " << nrErrorMessageInputs
               << ". Found specifier "
               << errorMessage.substr(spec_begin, spec_end - spec_begin + 1);
      }

      spec_begin = errorMessage.find_first_of('{', spec_begin + 1);
      spec_end = errorMessage.find_first_of(",:}", spec_begin);
    }
    return success();
  }

  StringRef getErrorMessage(CustomCallOp op) const {
    return cast<StringAttr>(op->getAttr(errorMessageAttrName)).getValue();
  }

  std::string formatErrorMessage(
      StringRef errorMessage,
      const SmallVector<int64_t> &errorMessageInputs) const {
    int nrErrorMessageInputs = errorMessageInputs.size();
    auto errorMessageFormat = errorMessage.data();
    if (nrErrorMessageInputs == 0) return errorMessageFormat;
    auto errInput = [nrErrorMessageInputs, &errorMessageInputs](int idx) {
      return (idx < nrErrorMessageInputs ? errorMessageInputs[idx] : -1);
    };
    return llvm::formatv(
        false, errorMessageFormat, errInput(0), errInput(1), errInput(2),
        errInput(3), errInput(4), errInput(5), errInput(6), errInput(7),
        errInput(8), errInput(9), errInput(10), errInput(11), errInput(12),
        errInput(13), errInput(14), errInput(15), errInput(16), errInput(17),
        errInput(18), errInput(19), errInput(20), errInput(21), errInput(22),
        errInput(23), errInput(24), errInput(25), errInput(26), errInput(27),
        errInput(28), errInput(29), errInput(30), errInput(31));
  }
};

}  // namespace

}  // namespace stablehlo
}  // namespace mlir
```
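The `formatErrorMessage` helper above leans on `llvm::formatv` positional specifiers. A standalone sketch of that behavior (illustrative, not part of the commit; the format string and values are made up):

```cpp
#include <cstdint>
#include <iostream>
#include <string>

#include "llvm/Support/FormatVariadic.h"

int main() {
  // "{0}" and "{1}" are replaced positionally, mirroring how the pass
  // substitutes error_message_inputs into error_message.
  std::string msg =
      llvm::formatv("dimension size (= {0}) does not match 'a' (= {1})",
                    int64_t{7}, int64_t{10})
          .str();
  std::cout << msg << "\n";  // dimension size (= 7) does not match 'a' (= 10)
  return 0;
}
```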
