@@ -50,8 +50,7 @@ WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser,
5050
5151void WebAssemblyAsmTypeCheck::funcDecl (const wasm::WasmSignature &Sig) {
5252 LocalTypes.assign (Sig.Params .begin (), Sig.Params .end ());
53- ReturnTypes.assign (Sig.Returns .begin (), Sig.Returns .end ());
54- BrStack.emplace_back (Sig.Returns .begin (), Sig.Returns .end ());
53+ BlockInfoStack.push_back ({Sig, 0 , false });
5554}
5655
5756void WebAssemblyAsmTypeCheck::localDecl (
@@ -64,14 +63,15 @@ void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) {
6463}
6564
6665bool WebAssemblyAsmTypeCheck::typeError (SMLoc ErrorLoc, const Twine &Msg) {
67- // If we're currently in unreachable code, we suppress errors completely.
68- if (Unreachable)
69- return false ;
7066 dumpTypeStack (" current stack: " );
7167 return Parser.Error (ErrorLoc, Msg);
7268}
7369
7470bool WebAssemblyAsmTypeCheck::match (StackType TypeA, StackType TypeB) {
71+ // These should have been filtered out in checkTypes()
72+ assert (!std::get_if<Polymorphic>(&TypeA) &&
73+ !std::get_if<Polymorphic>(&TypeB));
74+
7575 if (TypeA == TypeB)
7676 return false ;
7777 if (std::get_if<Any>(&TypeA) || std::get_if<Any>(&TypeB))
@@ -90,6 +90,10 @@ std::string WebAssemblyAsmTypeCheck::getTypesString(ArrayRef<StackType> Types,
9090 size_t StartPos) {
9191 SmallVector<std::string, 4 > TypeStrs;
9292 for (auto I = Types.size (); I > StartPos; I--) {
93+ if (std::get_if<Polymorphic>(&Types[I - 1 ])) {
94+ TypeStrs.push_back (" ..." );
95+ break ;
96+ }
9397 if (std::get_if<Any>(&Types[I - 1 ]))
9498 TypeStrs.push_back (" any" );
9599 else if (std::get_if<Ref>(&Types[I - 1 ]))
@@ -131,29 +135,48 @@ bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
131135 bool ExactMatch) {
132136 auto StackI = Stack.size ();
133137 auto TypeI = Types.size ();
138+ assert (!BlockInfoStack.empty ());
139+ auto BlockStackStartPos = BlockInfoStack.back ().StackStartPos ;
134140 bool Error = false ;
141+ bool PolymorphicStack = false ;
135142 // Compare elements one by one from the stack top
136- for (; StackI > 0 && TypeI > 0 ; StackI--, TypeI--) {
143+ for (; StackI > BlockStackStartPos && TypeI > 0 ; StackI--, TypeI--) {
144+ // If the stack is polymorphic, we assume all types in 'Types' have been
145+ // compared and matched
146+ if (std::get_if<Polymorphic>(&Stack[StackI - 1 ])) {
147+ TypeI = 0 ;
148+ break ;
149+ }
137150 if (match (Stack[StackI - 1 ], Types[TypeI - 1 ])) {
138151 Error = true ;
139152 break ;
140153 }
141154 }
155+
156+ // If the stack top is polymorphic, the stack is in the polymorphic state.
157+ if (StackI > BlockStackStartPos &&
158+ std::get_if<Polymorphic>(&Stack[StackI - 1 ]))
159+ PolymorphicStack = true ;
160+
142161 // Even if no match failure has happened in the loop above, if not all
143162 // elements of Types has been matched, that means we don't have enough
144163 // elements on the stack.
145164 //
146165 // Also, if not all elements of the Stack has been matched and when
147- // 'ExactMatch' is true, that means we have superfluous elements remaining on
148- // the stack (e.g. at the end of a function).
149- if (TypeI > 0 || (ExactMatch && StackI > 0 ))
166+ // 'ExactMatch' is true and the current stack is not polymorphic, that means
167+ // we have superfluous elements remaining on the stack (e.g. at the end of a
168+ // function).
169+ if (TypeI > 0 ||
170+ (ExactMatch && !PolymorphicStack && StackI > BlockStackStartPos))
150171 Error = true ;
151172
152173 if (!Error)
153174 return false ;
154175
155- auto StackStartPos =
156- ExactMatch ? 0 : std::max (0 , (int )Stack.size () - (int )Types.size ());
176+ auto StackStartPos = ExactMatch
177+ ? BlockStackStartPos
178+ : std::max ((int )BlockStackStartPos,
179+ (int )Stack.size () - (int )Types.size ());
157180 return typeError (ErrorLoc, " type mismatch, expected " +
158181 getTypesString (Types, 0 ) + " but got " +
159182 getTypesString (Stack, StackStartPos));
@@ -169,9 +192,13 @@ bool WebAssemblyAsmTypeCheck::popTypes(SMLoc ErrorLoc,
169192 ArrayRef<StackType> Types,
170193 bool ExactMatch) {
171194 bool Error = checkTypes (ErrorLoc, Types, ExactMatch);
172- auto NumPops = std::min (Stack.size (), Types.size ());
173- for (size_t I = 0 , E = NumPops; I != E; I++)
195+ auto NumPops = std::min (Stack.size () - BlockInfoStack.back ().StackStartPos ,
196+ Types.size ());
197+ for (size_t I = 0 , E = NumPops; I != E; I++) {
198+ if (std::get_if<Polymorphic>(&Stack.back ()))
199+ break ;
174200 Stack.pop_back ();
201+ }
175202 return Error;
176203}
177204
@@ -201,25 +228,6 @@ bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp,
201228 return false ;
202229}
203230
204- bool WebAssemblyAsmTypeCheck::checkBr (SMLoc ErrorLoc, size_t Level) {
205- if (Level >= BrStack.size ())
206- return typeError (ErrorLoc,
207- StringRef (" br: invalid depth " ) + std::to_string (Level));
208- const SmallVector<wasm::ValType, 4 > &Expected =
209- BrStack[BrStack.size () - Level - 1 ];
210- return checkTypes (ErrorLoc, Expected);
211- return false ;
212- }
213-
214- bool WebAssemblyAsmTypeCheck::checkEnd (SMLoc ErrorLoc, bool PopVals) {
215- if (!PopVals)
216- BrStack.pop_back ();
217-
218- if (PopVals)
219- return popTypes (ErrorLoc, LastSig.Returns );
220- return checkTypes (ErrorLoc, LastSig.Returns );
221- }
222-
223231bool WebAssemblyAsmTypeCheck::checkSig (SMLoc ErrorLoc,
224232 const wasm::WasmSignature &Sig) {
225233 bool Error = popTypes (ErrorLoc, Sig.Params );
@@ -309,9 +317,9 @@ bool WebAssemblyAsmTypeCheck::getSignature(SMLoc ErrorLoc,
309317}
310318
311319bool WebAssemblyAsmTypeCheck::endOfFunction (SMLoc ErrorLoc, bool ExactMatch) {
312- bool Error = popTypes (ErrorLoc, ReturnTypes, ExactMatch );
313- Unreachable = true ;
314- return Error ;
320+ assert (!BlockInfoStack. empty () );
321+ const auto &FuncInfo = BlockInfoStack[ 0 ] ;
322+ return checkTypes (ErrorLoc, FuncInfo. Sig . Returns , ExactMatch) ;
315323}
316324
317325bool WebAssemblyAsmTypeCheck::typeCheck (SMLoc ErrorLoc, const MCInst &Inst,
@@ -452,52 +460,91 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
452460 return popType (ErrorLoc, Any{});
453461 }
454462
455- if (Name == " try" || Name == " block" || Name == " loop" || Name == " if" ) {
456- if (Name == " loop" )
457- BrStack.emplace_back (LastSig.Params .begin (), LastSig.Params .end ());
458- else
459- BrStack.emplace_back (LastSig.Returns .begin (), LastSig.Returns .end ());
460- if (Name == " if" && popType (ErrorLoc, wasm::ValType::I32))
461- return true ;
462- return false ;
463+ if (Name == " block" || Name == " loop" || Name == " if" || Name == " try" ) {
464+ bool Error = Name == " if" && popType (ErrorLoc, wasm::ValType::I32);
465+ // Pop block input parameters and check their types are correct
466+ Error |= popTypes (ErrorLoc, LastSig.Params );
467+ // Push a new block info
468+ BlockInfoStack.push_back ({LastSig, Stack.size (), Name == " loop" });
469+ // Push back block input parameters
470+ pushTypes (LastSig.Params );
471+ return Error;
463472 }
464473
465474 if (Name == " end_block" || Name == " end_loop" || Name == " end_if" ||
466- Name == " else" || Name == " end_try" || Name == " catch" ||
467- Name == " catch_all" || Name == " delegate" ) {
468- bool Error = checkEnd (ErrorLoc, Name == " else" || Name == " catch" ||
469- Name == " catch_all" );
470- Unreachable = false ;
471- if (Name == " catch" ) {
475+ Name == " end_try" || Name == " delegate" || Name == " else" ||
476+ Name == " catch" || Name == " catch_all" ) {
477+ assert (!BlockInfoStack.empty ());
478+ // Check if the types on the stack match with the block return type
479+ const auto &LastBlockInfo = BlockInfoStack.back ();
480+ bool Error = checkTypes (ErrorLoc, LastBlockInfo.Sig .Returns , true );
481+ // Pop all types added to the stack for the current block level
482+ Stack.truncate (LastBlockInfo.StackStartPos );
483+ if (Name == " else" ) {
484+ // 'else' expects the block input parameters to be on the stack, in the
485+ // same way we entered 'if'
486+ pushTypes (LastBlockInfo.Sig .Params );
487+ } else if (Name == " catch" ) {
488+ // 'catch' instruction pushes values whose types are specified in the
489+ // tag's 'params' part
472490 const wasm::WasmSignature *Sig = nullptr ;
473491 if (!getSignature (Operands[1 ]->getStartLoc (), Inst.getOperand (0 ),
474492 wasm::WASM_SYMBOL_TYPE_TAG, Sig))
475- // catch instruction pushes values whose types are specified in the
476- // tag's "params" part
477493 pushTypes (Sig->Params );
478494 else
479495 Error = true ;
496+ } else if (Name == " catch_all" ) {
497+ // 'catch_all' does not push anything onto the stack
498+ } else {
499+ // For normal end markers, push block return value types onto the stack
500+ // and pop the block info
501+ pushTypes (LastBlockInfo.Sig .Returns );
502+ BlockInfoStack.pop_back ();
480503 }
481504 return Error;
482505 }
483506
484- if (Name == " br" ) {
507+ if (Name == " br" || Name == " br_if" ) {
508+ bool Error = false ;
509+ if (Name == " br_if" )
510+ Error |= popType (ErrorLoc, wasm::ValType::I32); // cond
485511 const MCOperand &Operand = Inst.getOperand (0 );
486- if (!Operand.isImm ())
487- return true ;
488- return checkBr (ErrorLoc, static_cast <size_t >(Operand.getImm ()));
512+ if (Operand.isImm ()) {
513+ unsigned Level = Operand.getImm ();
514+ if (Level < BlockInfoStack.size ()) {
515+ const auto &DestBlockInfo =
516+ BlockInfoStack[BlockInfoStack.size () - Level - 1 ];
517+ if (DestBlockInfo.IsLoop )
518+ Error |= checkTypes (ErrorLoc, DestBlockInfo.Sig .Params , false );
519+ else
520+ Error |= checkTypes (ErrorLoc, DestBlockInfo.Sig .Returns , false );
521+ } else {
522+ Error = typeError (ErrorLoc, StringRef (" br: invalid depth " ) +
523+ std::to_string (Level));
524+ }
525+ } else {
526+ Error =
527+ typeError (Operands[1 ]->getStartLoc (), " depth should be an integer" );
528+ }
529+ if (Name == " br" )
530+ pushType (Polymorphic{});
531+ return Error;
489532 }
490533
491534 if (Name == " return" ) {
492- return endOfFunction (ErrorLoc, false );
535+ bool Error = endOfFunction (ErrorLoc, false );
536+ pushType (Polymorphic{});
537+ return Error;
493538 }
494539
495540 if (Name == " call_indirect" || Name == " return_call_indirect" ) {
496541 // Function value.
497542 bool Error = popType (ErrorLoc, wasm::ValType::I32);
498543 Error |= checkSig (ErrorLoc, LastSig);
499- if (Name == " return_call_indirect" && endOfFunction (ErrorLoc, false ))
500- return true ;
544+ if (Name == " return_call_indirect" ) {
545+ Error |= endOfFunction (ErrorLoc, false );
546+ pushType (Polymorphic{});
547+ }
501548 return Error;
502549 }
503550
@@ -509,13 +556,15 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
509556 Error |= checkSig (ErrorLoc, *Sig);
510557 else
511558 Error = true ;
512- if (Name == " return_call" && endOfFunction (ErrorLoc, false ))
513- return true ;
559+ if (Name == " return_call" ) {
560+ Error |= endOfFunction (ErrorLoc, false );
561+ pushType (Polymorphic{});
562+ }
514563 return Error;
515564 }
516565
517566 if (Name == " unreachable" ) {
518- Unreachable = true ;
567+ pushType (Polymorphic{}) ;
519568 return false ;
520569 }
521570
@@ -526,11 +575,15 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
526575 }
527576
528577 if (Name == " throw" ) {
578+ bool Error = false ;
529579 const wasm::WasmSignature *Sig = nullptr ;
530580 if (!getSignature (Operands[1 ]->getStartLoc (), Inst.getOperand (0 ),
531581 wasm::WASM_SYMBOL_TYPE_TAG, Sig))
532- return checkSig (ErrorLoc, *Sig);
533- return true ;
582+ Error |= checkSig (ErrorLoc, *Sig);
583+ else
584+ Error = true ;
585+ pushType (Polymorphic{});
586+ return Error;
534587 }
535588
536589 // The current instruction is a stack instruction which doesn't have
0 commit comments