@@ -208,3 +208,264 @@ def test_hook(activation, hook):
208208 # Since hook returns None, the original input should be returned
209209 # (HookPoint's forward method returns the input when no valid hook result)
210210 assert torch .equal (result , test_input )
211+
212+
213+ class TestHookPointHasHooks :
214+ """Comprehensive test suite for HookPoint.has_hooks method."""
215+
216+ def setup_method (self ):
217+ """Set up fresh HookPoint and sample hook for each test."""
218+ self .hook_point = HookPoint ()
219+
220+ def sample_hook (activation , hook ):
221+ return activation
222+
223+ self .sample_hook = sample_hook
224+
225+ def test_no_hooks_returns_false (self ):
226+ """Test that has_hooks returns False when no hooks are present."""
227+ assert not self .hook_point .has_hooks ()
228+ assert not self .hook_point .has_hooks (dir = "fwd" )
229+ assert not self .hook_point .has_hooks (dir = "bwd" )
230+ assert not self .hook_point .has_hooks (dir = "both" )
231+
232+ def test_forward_hook_detection (self ):
233+ """Test detection of forward hooks."""
234+ # Add a forward hook
235+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" )
236+
237+ # Should detect forward hooks
238+ assert self .hook_point .has_hooks ()
239+ assert self .hook_point .has_hooks (dir = "fwd" )
240+ assert self .hook_point .has_hooks (dir = "both" )
241+
242+ # Should not detect backward hooks
243+ assert not self .hook_point .has_hooks (dir = "bwd" )
244+
245+ def test_backward_hook_detection (self ):
246+ """Test detection of backward hooks."""
247+ # Add a backward hook
248+ self .hook_point .add_hook (self .sample_hook , dir = "bwd" )
249+
250+ # Should detect backward hooks
251+ assert self .hook_point .has_hooks ()
252+ assert self .hook_point .has_hooks (dir = "bwd" )
253+ assert self .hook_point .has_hooks (dir = "both" )
254+
255+ # Should not detect forward hooks
256+ assert not self .hook_point .has_hooks (dir = "fwd" )
257+
258+ def test_both_direction_hooks (self ):
259+ """Test detection when both forward and backward hooks are present."""
260+ # Add both forward and backward hooks
261+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" )
262+ self .hook_point .add_hook (self .sample_hook , dir = "bwd" )
263+
264+ # All directions should detect hooks
265+ assert self .hook_point .has_hooks ()
266+ assert self .hook_point .has_hooks (dir = "fwd" )
267+ assert self .hook_point .has_hooks (dir = "bwd" )
268+ assert self .hook_point .has_hooks (dir = "both" )
269+
270+ def test_permanent_hook_detection (self ):
271+ """Test detection of permanent hooks."""
272+ # Add a permanent forward hook
273+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , is_permanent = True )
274+
275+ # Should detect permanent hooks by default
276+ assert self .hook_point .has_hooks ()
277+ assert self .hook_point .has_hooks (including_permanent = True )
278+
279+ # Should not detect when excluding permanent hooks
280+ assert not self .hook_point .has_hooks (including_permanent = False )
281+
282+ def test_non_permanent_hook_detection (self ):
283+ """Test detection of non-permanent hooks."""
284+ # Add a non-permanent forward hook
285+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , is_permanent = False )
286+
287+ # Should detect non-permanent hooks regardless of including_permanent setting
288+ assert self .hook_point .has_hooks ()
289+ assert self .hook_point .has_hooks (including_permanent = True )
290+ assert self .hook_point .has_hooks (including_permanent = False )
291+
292+ def test_mixed_permanent_hooks (self ):
293+ """Test detection with mix of permanent and non-permanent hooks."""
294+ # Add both permanent and non-permanent hooks
295+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , is_permanent = True )
296+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , is_permanent = False )
297+
298+ # Should detect hooks in both cases
299+ assert self .hook_point .has_hooks (including_permanent = True )
300+ assert self .hook_point .has_hooks (including_permanent = False )
301+
302+ def test_only_permanent_hooks (self ):
303+ """Test detection when only permanent hooks are present."""
304+ # Add only permanent hooks
305+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , is_permanent = True )
306+ self .hook_point .add_hook (self .sample_hook , dir = "bwd" , is_permanent = True )
307+
308+ # Should detect when including permanent
309+ assert self .hook_point .has_hooks (including_permanent = True )
310+ assert self .hook_point .has_hooks (dir = "fwd" , including_permanent = True )
311+ assert self .hook_point .has_hooks (dir = "bwd" , including_permanent = True )
312+
313+ # Should not detect when excluding permanent
314+ assert not self .hook_point .has_hooks (including_permanent = False )
315+ assert not self .hook_point .has_hooks (dir = "fwd" , including_permanent = False )
316+ assert not self .hook_point .has_hooks (dir = "bwd" , including_permanent = False )
317+
318+ def test_context_level_filtering (self ):
319+ """Test context level filtering functionality."""
320+ # Add hooks at different context levels
321+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , level = 0 )
322+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , level = 1 )
323+ self .hook_point .add_hook (self .sample_hook , dir = "bwd" , level = 2 )
324+
325+ # Should detect hooks at specific levels
326+ assert self .hook_point .has_hooks (level = 0 )
327+ assert self .hook_point .has_hooks (level = 1 )
328+ assert self .hook_point .has_hooks (level = 2 )
329+
330+ # Should not detect hooks at non-existent levels
331+ assert not self .hook_point .has_hooks (level = 3 )
332+ assert not self .hook_point .has_hooks (level = - 1 )
333+
334+ # Should detect all hooks when level is None
335+ assert self .hook_point .has_hooks (level = None )
336+
337+ def test_context_level_with_direction (self ):
338+ """Test context level filtering combined with direction filtering."""
339+ # Add hooks at different levels and directions
340+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , level = 0 )
341+ self .hook_point .add_hook (self .sample_hook , dir = "bwd" , level = 1 )
342+
343+ # Test specific combinations
344+ assert self .hook_point .has_hooks (dir = "fwd" , level = 0 )
345+ assert self .hook_point .has_hooks (dir = "bwd" , level = 1 )
346+
347+ # Test non-matching combinations
348+ assert not self .hook_point .has_hooks (dir = "fwd" , level = 1 )
349+ assert not self .hook_point .has_hooks (dir = "bwd" , level = 0 )
350+
351+ def test_context_level_with_permanent_flags (self ):
352+ """Test context level filtering combined with permanent hook filtering."""
353+ # Add permanent and non-permanent hooks at different levels
354+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , level = 0 , is_permanent = True )
355+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , level = 1 , is_permanent = False )
356+
357+ # Test combinations
358+ assert self .hook_point .has_hooks (level = 0 , including_permanent = True )
359+ assert not self .hook_point .has_hooks (level = 0 , including_permanent = False )
360+ assert self .hook_point .has_hooks (level = 1 , including_permanent = True )
361+ assert self .hook_point .has_hooks (level = 1 , including_permanent = False )
362+
363+ def test_all_parameters_combined (self ):
364+ """Test all parameters combined in various ways."""
365+ # Create a complex setup with multiple hooks
366+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , level = 0 , is_permanent = True )
367+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , level = 1 , is_permanent = False )
368+ self .hook_point .add_hook (self .sample_hook , dir = "bwd" , level = 0 , is_permanent = False )
369+ self .hook_point .add_hook (self .sample_hook , dir = "bwd" , level = 2 , is_permanent = True )
370+
371+ # Test specific combinations
372+ assert self .hook_point .has_hooks (dir = "fwd" , level = 0 , including_permanent = True )
373+ assert not self .hook_point .has_hooks (dir = "fwd" , level = 0 , including_permanent = False )
374+ assert self .hook_point .has_hooks (dir = "fwd" , level = 1 , including_permanent = False )
375+ assert self .hook_point .has_hooks (dir = "bwd" , level = 0 , including_permanent = False )
376+ assert not self .hook_point .has_hooks (dir = "bwd" , level = 1 , including_permanent = False )
377+ assert self .hook_point .has_hooks (dir = "bwd" , level = 2 , including_permanent = True )
378+
379+ def test_invalid_direction_raises_error (self ):
380+ """Test that invalid direction parameter raises error (caught by type checking)."""
381+ # Note: beartype catches this at the parameter level before reaching the ValueError
382+ import pytest
383+ from beartype .roar import BeartypeCallHintParamViolation
384+
385+ with pytest .raises (BeartypeCallHintParamViolation ):
386+ self .hook_point .has_hooks (dir = "invalid" ) # type: ignore
387+
388+ def test_multiple_hooks_same_criteria (self ):
389+ """Test detection when multiple hooks match the same criteria."""
390+ # Add multiple hooks with same criteria
391+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , level = 0 , is_permanent = False )
392+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , level = 0 , is_permanent = False )
393+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , level = 0 , is_permanent = False )
394+
395+ # Should still detect hooks (method returns True on first match)
396+ assert self .hook_point .has_hooks (dir = "fwd" , level = 0 , including_permanent = False )
397+
398+ def test_hook_removal_affects_detection (self ):
399+ """Test that removing hooks affects detection."""
400+ # Add a hook
401+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" )
402+ assert self .hook_point .has_hooks ()
403+
404+ # Remove all hooks
405+ self .hook_point .remove_hooks (dir = "both" )
406+ assert not self .hook_point .has_hooks ()
407+
408+ def test_default_parameter_values (self ):
409+ """Test that default parameter values work correctly."""
410+ # Add hooks to test defaults
411+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , is_permanent = True , level = 0 )
412+ self .hook_point .add_hook (self .sample_hook , dir = "bwd" , is_permanent = False , level = 1 )
413+
414+ # Test default behavior (dir="both", including_permanent=True, level=None)
415+ assert self .hook_point .has_hooks ()
416+
417+ # This should be equivalent to:
418+ assert self .hook_point .has_hooks (dir = "both" , including_permanent = True , level = None )
419+
420+ def test_edge_case_empty_after_filtering (self ):
421+ """Test edge case where hooks exist but are filtered out."""
422+ # Add hooks that will be filtered out
423+ self .hook_point .add_hook (self .sample_hook , dir = "fwd" , is_permanent = True , level = 5 )
424+
425+ # These should not detect the hook due to filtering
426+ assert not self .hook_point .has_hooks (including_permanent = False )
427+ assert not self .hook_point .has_hooks (dir = "bwd" )
428+ assert not self .hook_point .has_hooks (level = 0 )
429+ assert not self .hook_point .has_hooks (dir = "bwd" , level = 5 , including_permanent = True )
430+
431+ def test_functional_hook_execution_still_works (self ):
432+ """Test that has_hooks doesn't interfere with actual hook functionality."""
433+ import torch
434+
435+ results = []
436+
437+ def test_hook (activation , hook ):
438+ results .append ("hook_called" )
439+ return activation
440+
441+ # Add hook and verify detection
442+ self .hook_point .add_hook (test_hook , dir = "fwd" )
443+ assert self .hook_point .has_hooks ()
444+
445+ # Execute hook and verify it still works
446+ test_input = torch .tensor ([1.0 , 2.0 , 3.0 ])
447+ output = self .hook_point (test_input )
448+
449+ assert torch .equal (output , test_input )
450+ assert "hook_called" in results
451+
452+ def test_hook_point_with_conversions (self ):
453+ """Test has_hooks with hook conversions if they exist."""
454+ import torch
455+
456+ # This test ensures has_hooks works even when hook conversions are involved
457+ def simple_hook (activation , hook ):
458+ return activation * 2
459+
460+ # Add hook
461+ self .hook_point .add_hook (simple_hook , dir = "fwd" )
462+
463+ # Should detect hook regardless of any internal conversions
464+ assert self .hook_point .has_hooks ()
465+ assert self .hook_point .has_hooks (dir = "fwd" )
466+
467+ # Test actual functionality still works
468+ test_input = torch .tensor ([1.0 , 2.0 ])
469+ output = self .hook_point (test_input )
470+ expected = torch .tensor ([2.0 , 4.0 ])
471+ assert torch .allclose (output , expected )
0 commit comments