@@ -14,15 +14,23 @@ limitations under the License.
1414==============================================================================*/
1515#include " stablehlo/reference/Api.h"
1616
17+ #include < cstdint>
18+
19+ #include " llvm/ADT/STLExtras.h"
1720#include " llvm/ADT/SmallVector.h"
1821#include " llvm/ADT/StringRef.h"
1922#include " llvm/Support/Error.h"
2023#include " llvm/Support/FileSystem.h"
2124#include " llvm/Support/Path.h"
2225#include " llvm/Support/SourceMgr.h"
2326#include " mlir/Dialect/Func/IR/FuncOps.h"
27+ #include " mlir/IR/BuiltinAttributes.h"
28+ #include " mlir/IR/Diagnostics.h"
2429#include " mlir/IR/DialectRegistry.h"
30+ #include " mlir/IR/Location.h"
31+ #include " mlir/IR/TypeRange.h"
2532#include " mlir/Parser/Parser.h"
33+ #include " mlir/Support/LogicalResult.h"
2634#include " stablehlo/dialect/Register.h"
2735#include " stablehlo/reference/Configuration.h"
2836#include " stablehlo/reference/Errors.h"
@@ -31,11 +39,13 @@ limitations under the License.
3139#include " stablehlo/reference/Ops.h"
3240#include " stablehlo/reference/Process.h"
3341#include " stablehlo/reference/Scope.h"
42+ #include " stablehlo/reference/Tensor.h"
43+ #include " stablehlo/reference/Value.h"
3444
3545namespace mlir {
3646namespace stablehlo {
3747namespace {
38- func::FuncOp getMainFunction (ModuleOp module , StringRef mainName) {
48+ FailureOr< func::FuncOp> getMainFunction (ModuleOp module , StringRef mainName) {
3949 auto functions = module .getOps <func::FuncOp>();
4050
4151 for (auto funcOp : functions)
@@ -46,7 +56,8 @@ func::FuncOp getMainFunction(ModuleOp module, StringRef mainName) {
4656 bool isDefaultLookup = mainName == " main" ;
4757 if (isSingleFunction && isDefaultLookup) return *functions.begin ();
4858
49- return {};
59+ return module .emitError ()
60+ << " module must have entry func with name " << mainName;
5061}
5162
5263// DefaultInterpreterFallback is an implementation detail of run module. It
@@ -106,33 +117,77 @@ class DefaultInterpreterFallback : public InterpreterFallback {
106117 int64_t serializedProbeFileId = 0 ;
107118};
108119
120+ LogicalResult validateEntrySignature (func::FuncOp func,
121+ ArrayRef<InterpreterValue> inputs) {
122+ if (func.getNumArguments () != inputs.size ())
123+ return func->emitError ()
124+ << " incorrect number of arguments specified, provided "
125+ << inputs.size () << " inputs but function expected"
126+ << func.getNumArguments ();
127+
128+ TypeRange signature = func.getArgumentTypes ();
129+ for (int64_t i = 0 ; i < func.getNumArguments (); ++i) {
130+ Type sigType = signature[i];
131+ Type argType = inputs[i].getType ();
132+ if (sigType != argType)
133+ return func.emitError () << " invalid input argument type at index " << i
134+ << " , input type was " << argType
135+ << " but entry function expected " << sigType;
136+ }
137+ return success ();
138+ }
139+
109140} // namespace
110141
111- llvm::ErrorOr <SmallVector<InterpreterValue>> evalModule (
142+ FailureOr <SmallVector<InterpreterValue>> evalModule (
112143 ModuleOp module , ArrayRef<InterpreterValue> inputs,
113144 const InterpreterConfiguration &config) {
145+ // Additional error checking at main function boundary.
146+ // This is most likely user error, where future errors during interpreting are
147+ // more likely invalid IR or interpreter bugs.
114148 if (module .getOps <func::FuncOp>().empty ())
115149 return SmallVector<InterpreterValue>();
116150
117151 auto mainFunc = getMainFunction (module , config.mainFunction );
118- if (!mainFunc) llvm::report_fatal_error (" Requested main function not found." );
152+ if (failed (mainFunc) || failed (validateEntrySignature (*mainFunc, inputs)))
153+ return failure ();
119154
120155 if (!config.probeInstrumentationDir .empty ()) {
121156 llvm::SmallString<128 > instrumentationMetadataFile (
122157 config.probeInstrumentationDir );
123158 llvm::sys::path::append (instrumentationMetadataFile,
124159 stablehlo::numpy::kInstrumentationMetadataFilename );
125160 if (llvm::sys::fs::remove (instrumentationMetadataFile))
126- llvm::report_fatal_error (
161+ return emitError (
162+ UnknownLoc::get (module .getContext ()),
127163 " Failed to remove existing instrumentation metadata file." );
128164 }
129165
130166 DefaultInterpreterFallback fallback (config);
131- return stablehlo::eval (mainFunc.getBody (), inputs, &fallback);
167+ return stablehlo::eval (mainFunc->getBody (), inputs, &fallback);
168+ }
169+
170+ FailureOr<SmallVector<DenseElementsAttr>> evalModule (
171+ ModuleOp module , ArrayRef<DenseElementsAttr> inputs,
172+ const InterpreterConfiguration &config) {
173+ SmallVector<InterpreterValue> valueInputs = llvm::to_vector (
174+ llvm::map_range (inputs, [](DenseElementsAttr attr) -> InterpreterValue {
175+ return InterpreterValue (makeTensor (attr));
176+ }));
177+
178+ auto values = evalModule (module , valueInputs, config);
179+ if (failed (values)) return failure ();
180+
181+ SmallVector<DenseElementsAttr> results = llvm::to_vector (llvm::map_range (
182+ values.value (), [](InterpreterValue val) -> DenseElementsAttr {
183+ return makeDenseElementsAttr (val.getTensor ());
184+ }));
185+
186+ return results;
132187}
133188
134- llvm::ErrorOr <OwningOpRef<ModuleOp>> parseStablehloModule (
135- const std::string &mlir, MLIRContext &context) {
189+ FailureOr <OwningOpRef<ModuleOp>> parseStablehloModule (const std::string &mlir,
190+ MLIRContext &context) {
136191 llvm::SourceMgr source_mgr;
137192 source_mgr.AddNewSourceBuffer (llvm::MemoryBuffer::getMemBuffer (mlir),
138193 llvm::SMLoc ());
@@ -145,7 +200,8 @@ llvm::ErrorOr<OwningOpRef<ModuleOp>> parseStablehloModule(
145200 mlir::OwningOpRef<mlir::ModuleOp> module (
146201 mlir::parseSourceFile<mlir::ModuleOp>(source_mgr, &context));
147202
148- if (!module ) return llvm::errc::invalid_argument;
203+ if (!module )
204+ return emitError (UnknownLoc::get (&context), " unable to parse module" );
149205
150206 return module ;
151207}
0 commit comments