39
39
40
40
import sys
41
41
import os
42
+ from io import StringIO
43
+
42
44
from importlib import invalidate_caches
43
45
from string import Formatter
44
46
__dir__ = __file__ .rpartition ("/" )[0 ]
@@ -289,7 +291,7 @@ def file_not_empty(path):
289
291
290
292
class CPyExtFunction ():
291
293
292
- def __init__ (self , pfunc , parameters , template = c_template , cmpfunc = None , ** kwargs ):
294
+ def __init__ (self , pfunc , parameters , template = c_template , cmpfunc = None , stderr_validator = None , ** kwargs ):
293
295
self .template = template
294
296
self .pfunc = pfunc
295
297
self .parameters = parameters
@@ -306,6 +308,7 @@ def __init__(self, pfunc, parameters, template=c_template, cmpfunc=None, **kwarg
306
308
kwargs ["resultspec" ] = kwargs ["resultspec" ] if "resultspec" in kwargs else "O"
307
309
self .formatargs = kwargs
308
310
self .cmpfunc = cmpfunc or self .do_compare
311
+ self .stderr_validator = stderr_validator
309
312
310
313
def do_compare (self , x , y ):
311
314
if isinstance (x , BaseException ):
@@ -359,11 +362,18 @@ def test(self):
359
362
cargs = self .parameters ()
360
363
pargs = self .parameters ()
361
364
for i in range (len (cargs )):
362
- cresult = presult = None
365
+ real_stderr = sys .stderr
366
+ sys .stderr = StringIO ()
363
367
try :
364
368
cresult = ctest (cargs [i ])
365
369
except BaseException as e :
366
370
cresult = e
371
+ else :
372
+ if self .stderr_validator :
373
+ s = sys .stderr .getvalue ()
374
+ assert self .stderr_validator (cargs [i ], s ), f"captured stderr didn't match expectations. Stderr: { s } "
375
+ finally :
376
+ sys .stderr = real_stderr
367
377
try :
368
378
presult = self .pfunc (pargs [i ])
369
379
except BaseException as e :
0 commit comments