@@ -142,6 +142,8 @@ struct FunctionSymbolRefContainer : SymbolRefContainer {
142142
143143using ImportDesc = std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
144144
145+ using parsed_inst_t = llvm::FailureOr<llvm::SmallVector<Value>>;
146+
145147struct WasmModuleSymbolTables {
146148 llvm::SmallVector<FunctionSymbolRefContainer> funcSymbols;
147149 llvm::SmallVector<GlobalSymbolRefContainer> globalSymbols;
@@ -173,6 +175,134 @@ struct WasmModuleSymbolTables {
173175 return getNewSymbolName (" table_" , id);
174176 }
175177};
178+
179+ class ParserHead ;
180+
181+ // / Wrapper around SmallVector to only allow access as push and pop on the
182+ // / stack. Makes sure that there are no "free accesses" on the stack to preserve
183+ // / its state.
184+ class ValueStack {
185+ private:
186+ struct LabelLevel {
187+ size_t stackIdx;
188+ LabelLevelOpInterface levelOp;
189+ };
190+ public:
191+ bool empty () const { return values.empty (); }
192+
193+ size_t size () const { return values.size (); }
194+
195+ // / Pops values from the stack because they are being used in an operation.
196+ // / @param operandTypes The list of expected types of the operation, used
197+ // / to know how many values to pop and check if the types match the
198+ // / expectation.
199+ // / @param opLoc Location of the caller, used to report accurately the
200+ // / location
201+ // / if an error occurs.
202+ // / @return Failure or the vector of popped values.
203+ llvm::FailureOr<llvm::SmallVector<Value>> popOperands (TypeRange operandTypes,
204+ Location *opLoc);
205+
206+ // / Push the results of an operation to the stack so they can be used in a
207+ // / following operation.
208+ // / @param results The list of results of the operation
209+ // / @param opLoc Location of the caller, used to report accurately the
210+ // / location
211+ // / if an error occurs.
212+ LogicalResult pushResults (ValueRange results, Location *opLoc);
213+
214+
215+ #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
216+ // / A simple dump function for debugging.
217+ // / Writes output to llvm::dbgs().
218+ LLVM_DUMP_METHOD void dump () const ;
219+ #endif
220+
221+ private:
222+ llvm::SmallVector<Value> values;
223+ };
224+
225+ using local_val_t = TypedValue<wasmssa::LocalRefType>;
226+
227+ class ExpressionParser {
228+ public:
229+ using locals_t = llvm::SmallVector<local_val_t >;
230+ ExpressionParser (ParserHead &parser, WasmModuleSymbolTables const &symbols,
231+ llvm::ArrayRef<local_val_t > initLocal)
232+ : parser{parser}, symbols{symbols}, locals{initLocal} {}
233+
234+ private:
235+ template <std::byte opCode>
236+ inline parsed_inst_t parseSpecificInstruction (OpBuilder &builder);
237+
238+ template <typename valueT>
239+ parsed_inst_t
240+ parseConstInst (OpBuilder &builder,
241+ std::enable_if_t <std::is_arithmetic_v<valueT>> * = nullptr );
242+
243+
244+ // / This function generates a dispatch tree to associate an opcode with a
245+ // / parser. Parsers are registered by specialising the
246+ // / `parseSpecificInstruction` function for the op code to handle.
247+ // /
248+ // / The dispatcher is generated by recursively creating all possible patterns
249+ // / for an opcode and calling the relevant parser on the leaf.
250+ // /
251+ // / @tparam patternBitSize is the first bit for which the pattern is not fixed
252+ // /
253+ // / @tparam highBitPattern is the fixed pattern that this instance handles for
254+ // / the 8-patternBitSize bits
255+ template <size_t patternBitSize = 0 , std::byte highBitPattern = std::byte{0 }>
256+ inline parsed_inst_t dispatchToInstParser (std::byte opCode,
257+ OpBuilder &builder) {
258+ static_assert (patternBitSize <= 8 ,
259+ " PatternBitSize is outside of range of opcode space! "
260+ " (expected at most 8 bits)" );
261+ if constexpr (patternBitSize < 8 ) {
262+ constexpr std::byte bitSelect{1 << (7 - patternBitSize)};
263+ constexpr std::byte nextHighBitPatternStem = highBitPattern << 1 ;
264+ constexpr size_t nextPatternBitSize = patternBitSize + 1 ;
265+ if ((opCode & bitSelect) != std::byte{0 })
266+ return dispatchToInstParser < nextPatternBitSize,
267+ nextHighBitPatternStem | std::byte{1 } > (opCode, builder);
268+ return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>(
269+ opCode, builder);
270+ } else {
271+ return parseSpecificInstruction<highBitPattern>(builder);
272+ }
273+ }
274+
275+ struct ParseResultWithInfo {
276+ llvm::SmallVector<Value> opResults;
277+ std::byte endingByte;
278+ };
279+
280+ public:
281+ template <std::byte ParseEndByte = WasmBinaryEncoding::endByte>
282+ parsed_inst_t parse (OpBuilder &builder,
283+ UniqueByte<ParseEndByte> = {});
284+
285+ template <std::byte... ExpressionParseEnd>
286+ llvm::FailureOr<ParseResultWithInfo>
287+ parse (OpBuilder &builder,
288+ ByteSequence<ExpressionParseEnd...> parsingEndFilters);
289+
290+ llvm::FailureOr<llvm::SmallVector<Value>>
291+ popOperands (TypeRange operandTypes) {
292+ return valueStack.popOperands (operandTypes, ¤tOpLoc.value ());
293+ }
294+
295+ LogicalResult pushResults (ValueRange results) {
296+ return valueStack.pushResults (results, ¤tOpLoc.value ());
297+ }
298+ private:
299+ std::optional<Location> currentOpLoc;
300+ ParserHead &parser;
301+ WasmModuleSymbolTables const &symbols;
302+ locals_t locals;
303+ ValueStack valueStack;
304+ };
305+
176306class ParserHead {
177307public:
178308 ParserHead (llvm::StringRef src, StringAttr name) : head{src}, locName{name} {}
@@ -382,6 +512,14 @@ class ParserHead {
382512 << static_cast <int >(*importType);
383513 }
384514 }
515+
516+ parsed_inst_t parseExpression (OpBuilder &builder,
517+ WasmModuleSymbolTables const &symbols,
518+ llvm::ArrayRef<local_val_t > locals = {}) {
519+ auto eParser = ExpressionParser{*this , symbols, locals};
520+ return eParser.parse (builder);
521+ }
522+
385523 bool end () const { return curHead ().empty (); }
386524
387525 ParserHead copy () const {
@@ -491,6 +629,201 @@ inline llvm::FailureOr<int64_t> ParserHead::parseI64() {
491629 return parseLiteral<int64_t >();
492630}
493631
632+ template <std::byte opCode>
633+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction (OpBuilder &) {
634+ return emitError (*currentOpLoc, " Unknown instruction opcode: " )
635+ << static_cast <int >(opCode);
636+ }
637+
638+ #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
639+ void ValueStack::dump () const {
640+ llvm::dbgs () << " ================= Wasm ValueStack =======================\n " ;
641+ llvm::dbgs () << " size: " << size () << " \n " ;
642+ llvm::dbgs () << " <Top>"
643+ << " \n " ;
644+ // Stack is pushed to via push_back. Therefore the top of the stack is the
645+ // end of the vector. Iterate in reverse so that the first thing we print
646+ // is the top of the stack.
647+ size_t stackSize = size ();
648+ for (size_t idx = 0 ; idx < stackSize ;) {
649+ size_t actualIdx = stackSize - 1 - idx;
650+ llvm::dbgs () << " " ;
651+ values[actualIdx].dump ();
652+ }
653+ llvm::dbgs () << " <Bottom>"
654+ << " \n " ;
655+ llvm::dbgs () << " =========================================================\n " ;
656+ }
657+ #endif
658+
659+ parsed_inst_t ValueStack::popOperands (TypeRange operandTypes, Location *opLoc) {
660+ LLVM_DEBUG (llvm::dbgs () << " Popping from ValueStack\n " );
661+ LLVM_DEBUG (llvm::dbgs () << " Elements(s) to pop: " << operandTypes.size ()
662+ << " \n " );
663+ LLVM_DEBUG (llvm::dbgs () << " Current stack size: " << values.size () << " \n " );
664+ if (operandTypes.size () > values.size ())
665+ return emitError (*opLoc,
666+ " Stack doesn't contain enough values. Trying to get " )
667+ << operandTypes.size () << " operands on a stack containing only "
668+ << values.size () << " values." ;
669+ size_t stackIdxOffset = values.size () - operandTypes.size ();
670+ llvm::SmallVector<Value> res{};
671+ res.reserve (operandTypes.size ());
672+ for (size_t i{0 }; i < operandTypes.size (); ++i) {
673+ Value operand = values[i + stackIdxOffset];
674+ Type stackType = operand.getType ();
675+ if (stackType != operandTypes[i])
676+ return emitError (*opLoc,
677+ " Invalid operand type on stack. Expecting " )
678+ << operandTypes[i] << " , value on stack is of type " << stackType
679+ << " ." ;
680+ LLVM_DEBUG (llvm::dbgs () << " POP: " << operand << " \n " );
681+ res.push_back (operand);
682+ }
683+ values.resize (values.size () - operandTypes.size ());
684+ LLVM_DEBUG (llvm::dbgs () << " Updated stack size: " << values.size () << " \n " );
685+ return res;
686+ }
687+
688+ LogicalResult ValueStack::pushResults (ValueRange results, Location *opLoc) {
689+ LLVM_DEBUG (llvm::dbgs () << " Pushing to ValueStack\n " );
690+ LLVM_DEBUG (llvm::dbgs () << " Elements(s) to push: " << results.size ()
691+ << " \n " );
692+ LLVM_DEBUG (llvm::dbgs () << " Current stack size: " << values.size () << " \n " );
693+ for (auto val : results) {
694+ if (!isWasmValueType (val.getType ()))
695+ return emitError (*opLoc, " Invalid value type on stack: " )
696+ << val.getType ();
697+ LLVM_DEBUG (llvm::dbgs () << " PUSH: " << val << " \n " );
698+ values.push_back (val);
699+ }
700+
701+ LLVM_DEBUG (llvm::dbgs () << " Updated stack size: " << values.size () << " \n " );
702+ return success ();
703+ }
704+
705+ template <std::byte EndParseByte>
706+ parsed_inst_t ExpressionParser::parse (OpBuilder &builder, UniqueByte<EndParseByte> endByte) {
707+ auto res = parse (builder, ByteSequence<EndParseByte>{});
708+ if (failed (res))
709+ return failure ();
710+ return res->opResults ;
711+ }
712+
713+ template <std::byte... ExpressionParseEnd>
714+ llvm::FailureOr<ExpressionParser::ParseResultWithInfo>
715+ ExpressionParser::parse (OpBuilder &builder,
716+ ByteSequence<ExpressionParseEnd...> parsingEndFilters) {
717+ llvm::SmallVector<Value> res;
718+ for (;;) {
719+ currentOpLoc = parser.getLocation ();
720+ auto opCode = parser.consumeByte ();
721+ if (failed (opCode))
722+ return failure ();
723+ if (isValueOneOf (*opCode, parsingEndFilters))
724+ return {{res, *opCode}};
725+ parsed_inst_t resParsed;
726+ resParsed = dispatchToInstParser (*opCode, builder);
727+ if (failed (resParsed))
728+ return failure ();
729+ std::swap (res, *resParsed);
730+ if (failed (pushResults (res)))
731+ return failure ();
732+ }
733+ }
734+
735+
736+ template <typename T>
737+ inline Type buildLiteralType (OpBuilder &);
738+
739+ template <>
740+ inline Type buildLiteralType<int32_t >(OpBuilder &builder) {
741+ return builder.getI32Type ();
742+ }
743+
744+ template <>
745+ inline Type buildLiteralType<int64_t >(OpBuilder &builder) {
746+ return builder.getI64Type ();
747+ }
748+
749+ template <>
750+ inline Type buildLiteralType<uint32_t >(OpBuilder &builder) {
751+ return builder.getI32Type ();
752+ }
753+
754+ template <>
755+ inline Type buildLiteralType<uint64_t >(OpBuilder &builder) {
756+ return builder.getI64Type ();
757+ }
758+
759+ template <>
760+ inline Type buildLiteralType<float >(OpBuilder &builder) {
761+ return builder.getF32Type ();
762+ }
763+
764+ template <>
765+ inline Type buildLiteralType<double >(OpBuilder &builder) {
766+ return builder.getF64Type ();
767+ }
768+
769+ template <typename ValT, typename E = std::enable_if_t <std::is_arithmetic_v<ValT>>>
770+ struct AttrHolder ;
771+
772+ template <typename ValT>
773+ struct AttrHolder <ValT, std::enable_if_t <std::is_integral_v<ValT>>> {
774+ using type = IntegerAttr;
775+ };
776+
777+ template <typename ValT>
778+ struct AttrHolder <ValT, std::enable_if_t <std::is_floating_point_v<ValT>>> {
779+ using type = FloatAttr;
780+ };
781+
782+ template <typename ValT>
783+ using attr_holder_t = typename AttrHolder<ValT>::type;
784+
785+ template <typename ValT,
786+ typename EnableT = std::enable_if_t <std::is_arithmetic_v<ValT>>>
787+ attr_holder_t <ValT> buildLiteralAttr (OpBuilder &builder, ValT val) {
788+ return attr_holder_t <ValT>::get (buildLiteralType<ValT>(builder), val);
789+ }
790+
791+ template <typename valueT>
792+ parsed_inst_t ExpressionParser::parseConstInst (
793+ OpBuilder &builder, std::enable_if_t <std::is_arithmetic_v<valueT>> *) {
794+ auto parsedConstant = parser.parseLiteral <valueT>();
795+ if (failed (parsedConstant))
796+ return failure ();
797+ auto constOp = builder.create <ConstOp>(
798+ *currentOpLoc, buildLiteralAttr<valueT>(builder, *parsedConstant));
799+ return {{constOp.getResult ()}};
800+ }
801+
802+ template <>
803+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
804+ WasmBinaryEncoding::OpCode::constI32>(OpBuilder &builder) {
805+ return parseConstInst<int32_t >(builder);
806+ }
807+
808+ template <>
809+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
810+ WasmBinaryEncoding::OpCode::constI64>(OpBuilder &builder) {
811+ return parseConstInst<int64_t >(builder);
812+ }
813+
814+ template <>
815+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
816+ WasmBinaryEncoding::OpCode::constFP32>(OpBuilder &builder) {
817+ return parseConstInst<float >(builder);
818+ }
819+
820+ template <>
821+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
822+ WasmBinaryEncoding::OpCode::constFP64>(OpBuilder &builder) {
823+ return parseConstInst<double >(builder);
824+ }
825+
826+
494827class WasmBinaryParser {
495828private:
496829 struct SectionRegistry {
0 commit comments