@@ -266,3 +266,52 @@ async def main() -> None:
266
266
assert "task_exited" not in runner .instruments
267
267
268
268
_core .run (main )
269
+
270
+
271
+ def test_instrument_call_trio_context () -> None :
272
+ called = set ()
273
+
274
+ class Instrument (_abc .Instrument ):
275
+ pass
276
+
277
+ hooks = {
278
+ # category 1
279
+ "after_io_wait" : (True , False ),
280
+ "before_io_wait" : (True , False ),
281
+ "before_run" : (True , False ),
282
+ # category 2
283
+ "after_run" : (False , False ),
284
+ # category 3
285
+ "before_task_step" : (True , True ),
286
+ "after_task_step" : (True , True ),
287
+ "task_exited" : (True , True ),
288
+ # category 4
289
+ "task_scheduled" : (True , None ),
290
+ "task_spawned" : (True , None ),
291
+ }
292
+ for hook , val in hooks .items ():
293
+
294
+ def h (
295
+ self : Instrument ,
296
+ * args : object ,
297
+ hook : str = hook ,
298
+ val : tuple [bool | None , bool | None ] = val ,
299
+ ) -> None :
300
+ fail_str = f"failed in { hook } "
301
+
302
+ if val [0 ] is not None :
303
+ assert _core .in_trio_run () == val [0 ], fail_str
304
+ if val [1 ] is not None :
305
+ assert _core .in_trio_task () == val [1 ], fail_str
306
+ called .add (hook )
307
+
308
+ setattr (Instrument , hook , h )
309
+
310
+ async def main () -> None :
311
+ await _core .checkpoint ()
312
+
313
+ async with _core .open_nursery () as nursery :
314
+ nursery .start_soon (_core .checkpoint )
315
+
316
+ _core .run (main , instruments = [Instrument ()])
317
+ assert called == set (hooks )
0 commit comments