@@ -684,6 +684,13 @@ def test_sanitize_user_prompt_with_all_rai_filter_template(
684684 template_id , _ = all_filter_template
685685
686686 user_prompt = "How to make cheesecake without oven at home?"
687+ expected_categories = [
688+ "hate_speech" ,
689+ "sexually_explicit" ,
690+ "harassment" ,
691+ "dangerous" ,
692+ ]
693+
687694 response = sanitize_user_prompt (
688695 project_id , location_id , template_id , user_prompt
689696 )
@@ -699,6 +706,14 @@ def test_sanitize_user_prompt_with_all_rai_filter_template(
699706 == modelarmor_v1 .FilterMatchState .NO_MATCH_FOUND
700707 )
701708
709+ assert all (
710+ response .sanitization_result .filter_results .get ("rai" )
711+ .rai_filter_result .rai_filter_type_results .get (expected_category )
712+ .match_state
713+ == modelarmor_v1 .FilterMatchState .NO_MATCH_FOUND
714+ for expected_category in expected_categories
715+ )
716+
702717
703718def test_sanitize_user_prompt_with_malicious_url_template (
704719 project_id : str ,
@@ -876,6 +891,12 @@ def test_sanitize_model_response_with_all_rai_filter_template(
876891 model_response = (
877892 "To make cheesecake without oven, you'll need to follow these steps...."
878893 )
894+ expected_categories = [
895+ "hate_speech" ,
896+ "sexually_explicit" ,
897+ "harassment" ,
898+ "dangerous" ,
899+ ]
879900
880901 response = sanitize_model_response (
881902 project_id , location_id , template_id , model_response
@@ -892,6 +913,14 @@ def test_sanitize_model_response_with_all_rai_filter_template(
892913 == modelarmor_v1 .FilterMatchState .NO_MATCH_FOUND
893914 )
894915
916+ assert all (
917+ response .sanitization_result .filter_results .get ("rai" )
918+ .rai_filter_result .rai_filter_type_results .get (expected_category )
919+ .match_state
920+ == modelarmor_v1 .FilterMatchState .NO_MATCH_FOUND
921+ for expected_category in expected_categories
922+ )
923+
895924
896925def test_sanitize_model_response_with_basic_sdp_template (
897926 project_id : str ,
0 commit comments