Skip to content

Commit 55e196e

Browse files
committed
Trapping exceptions raised during parsing.
This allows users of the ParserProtocol to abort the connection and trap errors by raising exceptions. Additionally, parse errors are exposed if they happen.
1 parent f7ad7e9 commit 55e196e

File tree

2 files changed

+71
-3
lines changed

2 files changed

+71
-3
lines changed

ometa/protocol.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from twisted.internet.protocol import Protocol
2+
from twisted.python.failure import Failure
23

34
from ometa.interp import TrampolinedGrammarInterpreter, _feed_me
45

@@ -10,6 +11,7 @@ def __init__(self, grammar, senderFactory, stateFactory, bindings):
1011
self.bindings = dict(bindings)
1112
self.senderFactory = senderFactory
1213
self.stateFactory = stateFactory
14+
self.disconnecting = False
1315

1416
def setNextRule(self, rule):
1517
self.currentRule = rule
@@ -30,11 +32,24 @@ def _parsedRule(self, nextRule, position):
3032
self.currentRule = nextRule
3133

3234
def dataReceived(self, data):
35+
if self.disconnecting:
36+
return
37+
3338
while data:
34-
if self._interp.receive(data) is _feed_me:
39+
try:
40+
status = self._interp.receive(data)
41+
except Exception:
42+
self.connectionLost(Failure())
43+
self.transport.abortConnection()
3544
return
45+
else:
46+
if status is _feed_me:
47+
return
3648
data = ''.join(self._interp.input.data[self._interp.input.position:])
3749
self._setupInterp()
3850

3951
def connectionLost(self, reason):
52+
if self.disconnecting:
53+
return
4054
self.state.connectionLost(reason)
55+
self.disconnecting = True

ometa/test/test_protocol.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
import unittest
1+
from twisted.trial import unittest
22

33
from ometa.grammar import OMeta
44
from ometa.protocol import ParserProtocol
5+
from ometa.runtime import ParseError
56

67

78
testingGrammarSource = """
89
910
someA = ('a' 'a') -> state('a')
1011
someB = ('b' 'b') -> state('b')
1112
someC = ('c' 'c') -> state('c')
13+
someExc = 'e' -> state.raiseSomething()
1214
13-
initial = someA
15+
initial = someA | someExc
1416
1517
"""
1618
testGrammar = OMeta(testingGrammarSource).parseGrammar('testGrammar')
@@ -21,6 +23,10 @@ def __init__(self, transport):
2123
self.transport = transport
2224

2325

26+
class SomeException(Exception):
27+
pass
28+
29+
2430
class StateFactory(object):
2531
def __init__(self, sender, parser):
2632
self.sender = sender
@@ -37,10 +43,21 @@ def __call__(self, v):
3743
self.calls.append(v)
3844
return self.returnMap.get(v)
3945

46+
def raiseSomething(self):
47+
raise SomeException()
48+
4049
def connectionLost(self, reason):
4150
self.lossReason = reason
4251

4352

53+
class FakeTransport(object):
54+
def __init__(self):
55+
self.aborted = False
56+
57+
def abortConnection(self):
58+
self.aborted = True
59+
60+
4461
class ParserProtocolTestCase(unittest.TestCase):
4562
def setUp(self):
4663
self.protocol = ParserProtocol(
@@ -140,3 +157,39 @@ def test_connectionLoss(self):
140157
reason = object()
141158
self.protocol.connectionLost(reason)
142159
self.assertEqual(self.protocol.state.lossReason, reason)
160+
161+
def test_parseFailure(self):
162+
"""
163+
Parse failures cause connection abortion with the parse error as the
164+
reason.
165+
"""
166+
transport = FakeTransport()
167+
self.protocol.makeConnection(transport)
168+
self.protocol.dataReceived('b')
169+
self.failIfEqual(self.protocol.state.lossReason, None)
170+
self.failUnlessIsInstance(self.protocol.state.lossReason.value,
171+
ParseError)
172+
self.assert_(transport.aborted)
173+
174+
def test_exceptionsRaisedFromState(self):
175+
"""
176+
Raising an exception from state methods called from the grammar
177+
propagate to connectionLost.
178+
"""
179+
transport = FakeTransport()
180+
self.protocol.makeConnection(transport)
181+
self.protocol.dataReceived('e')
182+
self.failIfEqual(self.protocol.state.lossReason, None)
183+
self.failUnlessIsInstance(self.protocol.state.lossReason.value,
184+
SomeException)
185+
self.assert_(transport.aborted)
186+
187+
def test_dataIgnoredAfterDisconnection(self):
188+
"""After connectionLost is called, all incoming data is ignored."""
189+
transport = FakeTransport()
190+
self.protocol.makeConnection(transport)
191+
reason = object()
192+
self.protocol.connectionLost(reason)
193+
self.protocol.dataReceived('d')
194+
self.assertEqual(self.protocol.state.lossReason, reason)
195+
self.assert_(not transport.aborted)

0 commit comments

Comments
 (0)