1
+ from litellm .integrations .custom_guardrail import CustomGuardrail
1
2
from litellm .proxy .guardrails .guardrail_registry import (
2
3
get_guardrail_initializer_from_hooks ,
4
+ InMemoryGuardrailHandler ,
3
5
)
6
+ from litellm .types .guardrails import GuardrailEventHooks , Guardrail , LitellmParams
4
7
5
8
6
9
def test_get_guardrail_initializer_from_hooks ():
@@ -15,3 +18,33 @@ def test_guardrail_class_registry():
15
18
print (f"guardrail_class_registry: { guardrail_class_registry } " )
16
19
assert "aim" in guardrail_class_registry
17
20
assert "aporia" in guardrail_class_registry
21
+
22
+
23
+ def test_update_in_memory_guardrail ():
24
+ handler = InMemoryGuardrailHandler ()
25
+ handler .guardrail_id_to_custom_guardrail ["123" ] = CustomGuardrail (
26
+ guardrail_name = "test-guardrail" ,
27
+ default_on = False ,
28
+ event_hook = GuardrailEventHooks .pre_call ,
29
+ )
30
+
31
+ handler .update_in_memory_guardrail (
32
+ "123" ,
33
+ Guardrail (
34
+ guardrail_name = "test-guardrail" ,
35
+ litellm_params = LitellmParams (
36
+ guardrail = "test-guardrail" , mode = "pre_call" , default_on = True
37
+ ),
38
+ ),
39
+ )
40
+
41
+ assert (
42
+ handler .guardrail_id_to_custom_guardrail ["123" ].should_run_guardrail (
43
+ data = {}, event_type = GuardrailEventHooks .pre_call
44
+ )
45
+ is True
46
+ )
47
+ assert (
48
+ handler .guardrail_id_to_custom_guardrail ["123" ].event_hook
49
+ is GuardrailEventHooks .pre_call
50
+ )
0 commit comments