Skip to content

Commit db1a8b1

Browse files
authored
Add a python API for reference interpreter (#2062)
```py ASM = """ func.func @test(%arg0: tensor<2x2xf16>) -> tensor<2x2xf16> { %0 = stablehlo.add %arg0, %arg0 : (tensor<2x2xf16>, tensor<2x2xf16>) -> tensor<2x2xf16> func.return %0 : tensor<2x2xf16> } """ m = ir.Module.parse(ASM) arg = np.asarray([1, 2, 3, 4], np.float16).reshape(2,2) args = [ir.DenseIntElementsAttr.get(arg)] res = stablehlo.eval_module(m, args) ``` - Added python API for reference interpreter, and test cases for int/float datatypes - Added reference API which takes DenseElementsAttrs to be used by python API - Added method to convert between `stablehlo::Tensor` and `mlir::DenseElementsAttr` more easily (only supports int/float types) - Added more error checking at interpreter entry point to improve UX - Changed error reporting at API boundary from `llvm::Error` to `mlir::LogicalResult` to be more consistent for external users, also improves python error experience
1 parent d4c8682 commit db1a8b1

File tree

10 files changed

+206
-45
lines changed

10 files changed

+206
-45
lines changed

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ cc_library(
431431
":reference_ops",
432432
":reference_process",
433433
":reference_scope",
434+
":reference_tensor",
434435
":reference_value",
435436
":register",
436437
"@llvm-project//llvm:Support",

stablehlo/integrations/python/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ declare_mlir_python_extension(StablehloPythonExtensions.Main
8585
StablehloCAPI
8686
PRIVATE_LINK_LIBS
8787
StablehloPortableApi
88+
StablehloReferenceApi
8889
StablehloSerialization
8990
LLVMSupport
9091
)

stablehlo/integrations/python/StablehloModule.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,18 @@ See the License for the specific language governing permissions and
1111
limitations under the License.
1212
==============================================================================*/
1313

14+
#include <vector>
15+
1416
#include "mlir-c/IR.h"
1517
#include "mlir/Bindings/Python/PybindAdaptors.h"
1618
#include "mlir/CAPI/IR.h"
19+
#include "mlir/IR/BuiltinAttributes.h"
1720
#include "stablehlo/dialect/Serialization.h"
1821
#include "stablehlo/integrations/c/StablehloAttributes.h"
1922
#include "stablehlo/integrations/c/StablehloDialect.h"
2023
#include "stablehlo/integrations/c/StablehloTypes.h"
2124
#include "stablehlo/integrations/python/PortableApi.h"
25+
#include "stablehlo/reference/Api.h"
2226

2327
namespace py = pybind11;
2428

@@ -483,6 +487,38 @@ PYBIND11_MODULE(_stablehlo, m) {
483487
//
484488
mlir::stablehlo::AddPortableApi(m);
485489

490+
//
491+
// Reference APIs
492+
//
493+
m.def(
494+
"eval_module",
495+
[](MlirModule module,
496+
std::vector<MlirAttribute> &args) -> std::vector<MlirAttribute> {
497+
std::vector<mlir::DenseElementsAttr> inputs;
498+
for (auto arg : args) {
499+
auto attr = unwrap(arg).dyn_cast<mlir::DenseElementsAttr>();
500+
if (!attr) {
501+
PyErr_SetString(PyExc_ValueError,
502+
"input args must be DenseElementsAttr");
503+
return {};
504+
}
505+
inputs.push_back(attr);
506+
}
507+
508+
mlir::stablehlo::InterpreterConfiguration config;
509+
auto results =
510+
mlir::stablehlo::evalModule(unwrap(module), inputs, config);
511+
if (failed(results)) {
512+
PyErr_SetString(PyExc_ValueError, "interpreter failed");
513+
return {};
514+
}
515+
516+
std::vector<MlirAttribute> pyResults;
517+
for (auto res : *results) pyResults.push_back(wrap(res));
518+
return pyResults;
519+
},
520+
py::arg("module"), py::arg("args"));
521+
486522
//
487523
// Serialization APIs.
488524
//

stablehlo/integrations/python/tests/stablehlo.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import io
2222
from mlir import ir
2323
from mlir.dialects import stablehlo
24+
import numpy as np
2425

2526

2627
def run(f):
@@ -227,22 +228,51 @@ def test_minimum_version():
227228
assert is_semver_format(curr_version)
228229

229230

230-
ASM = """
231-
func.func @test(%arg0: tensor<2xf32>) -> tensor<2xf32> {
232-
%0 = stablehlo.add %arg0, %arg0 : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
233-
func.return %0 : tensor<2xf32>
234-
}
231+
ASM_FORMAT = """
232+
func.func @test(%arg0: tensor<{0}>) -> tensor<{0}> {{
233+
%0 = stablehlo.add %arg0, %arg0 : (tensor<{0}>, tensor<{0}>) -> tensor<{0}>
234+
func.return %0 : tensor<{0}>
235+
}}
235236
"""
236237

238+
239+
@run
240+
def test_reference_api():
241+
# Formatted as (tensor_type, np_value)
242+
# Program runs arg + arg, which is used for expected value
243+
tests = [
244+
# No numpy types for f8 - skipping fp8 tests
245+
("f16", np.asarray(1, np.float16)),
246+
("f32", np.asarray(2, np.float32)),
247+
("f64", np.asarray(3, np.double)),
248+
("1xi8", np.asarray([4], np.int8)),
249+
("1xi16", np.asarray([5], np.int16)),
250+
("1xi32", np.asarray([-6], np.int32)),
251+
# Numpy's uint treated as int by DenseElementsAttr, skipping np.uint tests
252+
("2x2xf16", np.asarray([1, 2, 3, 4], np.float16).reshape(2,2)),
253+
("2x1x2xf16", np.asarray([1, 2, 3, 4], np.float16).reshape(2,1,2)),
254+
]
255+
for test in tests:
256+
tensor_type, arg = test
257+
with ir.Context() as context:
258+
stablehlo.register_dialect(context)
259+
m = ir.Module.parse(ASM_FORMAT.format(tensor_type))
260+
args = [ir.DenseIntElementsAttr.get(arg)]
261+
262+
actual = np.array(stablehlo.eval_module(m, args)[0])
263+
expected = arg + arg
264+
assert (actual == expected).all()
265+
266+
237267
@run
238268
def test_serialization_apis():
239269
curr_version = stablehlo.get_current_version()
240270

241271
with ir.Context() as context:
242272
stablehlo.register_dialect(context)
243-
m = ir.Module.parse(ASM)
244-
module_str = str(m)
273+
m = ir.Module.parse(ASM_FORMAT.format("2xf32"))
245274
assert m is not None
275+
module_str = str(m)
246276
serialized = stablehlo.serialize_portable_artifact(m, curr_version)
247277
deserialized = stablehlo.deserialize_portable_artifact(context, serialized)
248278
assert module_str == str(deserialized)
@@ -258,9 +288,9 @@ def module_to_bytecode(module: ir.Module) -> bytes:
258288

259289
with ir.Context() as context:
260290
stablehlo.register_dialect(context)
261-
m = ir.Module.parse(ASM)
262-
module_str = str(m)
291+
m = ir.Module.parse(ASM_FORMAT.format("2xf32"))
263292
assert m is not None
293+
module_str = str(m)
264294
bytecode = module_to_bytecode(m)
265295
serialized = stablehlo.serialize_portable_artifact(bytecode, curr_version)
266296
deserialized = stablehlo.deserialize_portable_artifact(serialized)

stablehlo/reference/Api.cpp

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,23 @@ limitations under the License.
1414
==============================================================================*/
1515
#include "stablehlo/reference/Api.h"
1616

17+
#include <cstdint>
18+
19+
#include "llvm/ADT/STLExtras.h"
1720
#include "llvm/ADT/SmallVector.h"
1821
#include "llvm/ADT/StringRef.h"
1922
#include "llvm/Support/Error.h"
2023
#include "llvm/Support/FileSystem.h"
2124
#include "llvm/Support/Path.h"
2225
#include "llvm/Support/SourceMgr.h"
2326
#include "mlir/Dialect/Func/IR/FuncOps.h"
27+
#include "mlir/IR/BuiltinAttributes.h"
28+
#include "mlir/IR/Diagnostics.h"
2429
#include "mlir/IR/DialectRegistry.h"
30+
#include "mlir/IR/Location.h"
31+
#include "mlir/IR/TypeRange.h"
2532
#include "mlir/Parser/Parser.h"
33+
#include "mlir/Support/LogicalResult.h"
2634
#include "stablehlo/dialect/Register.h"
2735
#include "stablehlo/reference/Configuration.h"
2836
#include "stablehlo/reference/Errors.h"
@@ -31,11 +39,13 @@ limitations under the License.
3139
#include "stablehlo/reference/Ops.h"
3240
#include "stablehlo/reference/Process.h"
3341
#include "stablehlo/reference/Scope.h"
42+
#include "stablehlo/reference/Tensor.h"
43+
#include "stablehlo/reference/Value.h"
3444

3545
namespace mlir {
3646
namespace stablehlo {
3747
namespace {
38-
func::FuncOp getMainFunction(ModuleOp module, StringRef mainName) {
48+
FailureOr<func::FuncOp> getMainFunction(ModuleOp module, StringRef mainName) {
3949
auto functions = module.getOps<func::FuncOp>();
4050

4151
for (auto funcOp : functions)
@@ -46,7 +56,8 @@ func::FuncOp getMainFunction(ModuleOp module, StringRef mainName) {
4656
bool isDefaultLookup = mainName == "main";
4757
if (isSingleFunction && isDefaultLookup) return *functions.begin();
4858

49-
return {};
59+
return module.emitError()
60+
<< "module must have entry func with name " << mainName;
5061
}
5162

5263
// DefaultInterpreterFallback is an implementation detail of run module. It
@@ -106,33 +117,77 @@ class DefaultInterpreterFallback : public InterpreterFallback {
106117
int64_t serializedProbeFileId = 0;
107118
};
108119

120+
LogicalResult validateEntrySignature(func::FuncOp func,
121+
ArrayRef<InterpreterValue> inputs) {
122+
if (func.getNumArguments() != inputs.size())
123+
return func->emitError()
124+
<< "incorrect number of arguments specified, provided "
125+
<< inputs.size() << " inputs but function expected"
126+
<< func.getNumArguments();
127+
128+
TypeRange signature = func.getArgumentTypes();
129+
for (int64_t i = 0; i < func.getNumArguments(); ++i) {
130+
Type sigType = signature[i];
131+
Type argType = inputs[i].getType();
132+
if (sigType != argType)
133+
return func.emitError() << "invalid input argument type at index " << i
134+
<< ", input type was " << argType
135+
<< " but entry function expected " << sigType;
136+
}
137+
return success();
138+
}
139+
109140
} // namespace
110141

111-
llvm::ErrorOr<SmallVector<InterpreterValue>> evalModule(
142+
FailureOr<SmallVector<InterpreterValue>> evalModule(
112143
ModuleOp module, ArrayRef<InterpreterValue> inputs,
113144
const InterpreterConfiguration &config) {
145+
// Additional error checking at main function boundary.
146+
// This is most likely user error, where future errors during interpreting are
147+
// more likely invalid IR or interpreter bugs.
114148
if (module.getOps<func::FuncOp>().empty())
115149
return SmallVector<InterpreterValue>();
116150

117151
auto mainFunc = getMainFunction(module, config.mainFunction);
118-
if (!mainFunc) llvm::report_fatal_error("Requested main function not found.");
152+
if (failed(mainFunc) || failed(validateEntrySignature(*mainFunc, inputs)))
153+
return failure();
119154

120155
if (!config.probeInstrumentationDir.empty()) {
121156
llvm::SmallString<128> instrumentationMetadataFile(
122157
config.probeInstrumentationDir);
123158
llvm::sys::path::append(instrumentationMetadataFile,
124159
stablehlo::numpy::kInstrumentationMetadataFilename);
125160
if (llvm::sys::fs::remove(instrumentationMetadataFile))
126-
llvm::report_fatal_error(
161+
return emitError(
162+
UnknownLoc::get(module.getContext()),
127163
"Failed to remove existing instrumentation metadata file.");
128164
}
129165

130166
DefaultInterpreterFallback fallback(config);
131-
return stablehlo::eval(mainFunc.getBody(), inputs, &fallback);
167+
return stablehlo::eval(mainFunc->getBody(), inputs, &fallback);
168+
}
169+
170+
FailureOr<SmallVector<DenseElementsAttr>> evalModule(
171+
ModuleOp module, ArrayRef<DenseElementsAttr> inputs,
172+
const InterpreterConfiguration &config) {
173+
SmallVector<InterpreterValue> valueInputs = llvm::to_vector(
174+
llvm::map_range(inputs, [](DenseElementsAttr attr) -> InterpreterValue {
175+
return InterpreterValue(makeTensor(attr));
176+
}));
177+
178+
auto values = evalModule(module, valueInputs, config);
179+
if (failed(values)) return failure();
180+
181+
SmallVector<DenseElementsAttr> results = llvm::to_vector(llvm::map_range(
182+
values.value(), [](InterpreterValue val) -> DenseElementsAttr {
183+
return makeDenseElementsAttr(val.getTensor());
184+
}));
185+
186+
return results;
132187
}
133188

134-
llvm::ErrorOr<OwningOpRef<ModuleOp>> parseStablehloModule(
135-
const std::string &mlir, MLIRContext &context) {
189+
FailureOr<OwningOpRef<ModuleOp>> parseStablehloModule(const std::string &mlir,
190+
MLIRContext &context) {
136191
llvm::SourceMgr source_mgr;
137192
source_mgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(mlir),
138193
llvm::SMLoc());
@@ -145,7 +200,8 @@ llvm::ErrorOr<OwningOpRef<ModuleOp>> parseStablehloModule(
145200
mlir::OwningOpRef<mlir::ModuleOp> module(
146201
mlir::parseSourceFile<mlir::ModuleOp>(source_mgr, &context));
147202

148-
if (!module) return llvm::errc::invalid_argument;
203+
if (!module)
204+
return emitError(UnknownLoc::get(&context), "unable to parse module");
149205

150206
return module;
151207
}

stablehlo/reference/Api.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <string>
2020

2121
#include "llvm/Support/ErrorOr.h"
22+
#include "mlir/IR/BuiltinAttributes.h"
2223
#include "mlir/IR/BuiltinOps.h"
2324
#include "mlir/IR/MLIRContext.h"
2425
#include "mlir/IR/OwningOpRef.h"
@@ -33,13 +34,19 @@ namespace stablehlo {
3334
/// module input and provided inputs. Returns a list of interpreter outputs.
3435
/// Can optionally pass a fallback interpreter callback which executes when no
3536
/// builtin kernels are matched.
36-
llvm::ErrorOr<SmallVector<InterpreterValue>> evalModule(
37+
FailureOr<SmallVector<InterpreterValue>> evalModule(
3738
ModuleOp module, ArrayRef<InterpreterValue> inputs,
3839
const InterpreterConfiguration &config);
3940

41+
/// This wrapper is intended to be easily used by the StableHLO Python bindings.
42+
// It wraps the InterpreterValue API.
43+
FailureOr<SmallVector<DenseElementsAttr>> evalModule(
44+
ModuleOp module, ArrayRef<DenseElementsAttr> inputs,
45+
const InterpreterConfiguration &config);
46+
4047
/// Parses a StableHLO MLIR text program into a ModuleOp.
41-
llvm::ErrorOr<OwningOpRef<ModuleOp>> parseStablehloModule(
42-
const std::string &mlir, MLIRContext &context);
48+
FailureOr<OwningOpRef<ModuleOp>> parseStablehloModule(const std::string &mlir,
49+
MLIRContext &context);
4350

4451
} // namespace stablehlo
4552
} // namespace mlir

stablehlo/reference/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_mlir_library(StablehloReferenceApi
2626
StablehloReferenceOps
2727
StablehloReferenceProcess
2828
StablehloReferenceScope
29+
StablehloReferenceTensor
2930
StablehloReferenceValue
3031
StablehloRegister
3132
)

0 commit comments

Comments
 (0)