1414from __future__ import absolute_import
1515from ast import literal_eval
1616from enum import Enum
17- from typing import Dict , List , Union , Any
17+ from typing import Dict , List , Optional , Union , Any
1818
1919from sagemaker .jumpstart .types import JumpStartDataHolderType
2020
@@ -38,6 +38,10 @@ class FilterOperators(str, Enum):
3838 NOT_EQUALS = "not_equals"
3939 IN = "in"
4040 NOT_IN = "not_in"
41+ INCLUDES = "includes"
42+ NOT_INCLUDES = "not_includes"
43+ BEGINS_WITH = "begins_with"
44+ ENDS_WITH = "ends_with"
4145
4246
4347class SpecialSupportedFilterKeys (str , Enum ):
@@ -52,6 +56,10 @@ class SpecialSupportedFilterKeys(str, Enum):
5256 FilterOperators .NOT_EQUALS : ["!==" , "!=" , "not equals" , "is not" ],
5357 FilterOperators .IN : ["in" ],
5458 FilterOperators .NOT_IN : ["not in" ],
59+ FilterOperators .INCLUDES : ["includes" , "contains" ],
60+ FilterOperators .NOT_INCLUDES : ["not includes" , "not contains" ],
61+ FilterOperators .BEGINS_WITH : ["begins with" , "starts with" ],
62+ FilterOperators .ENDS_WITH : ["ends with" ],
5563}
5664
5765
@@ -62,7 +70,19 @@ class SpecialSupportedFilterKeys(str, Enum):
6270)
6371
6472ACCEPTABLE_OPERATORS_IN_PARSE_ORDER = (
65- list (map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_EQUALS ]))
73+ list (
74+ map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .BEGINS_WITH ])
75+ )
76+ + list (
77+ map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .ENDS_WITH ])
78+ )
79+ + list (
80+ map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_INCLUDES ])
81+ )
82+ + list (map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .INCLUDES ]))
83+ + list (
84+ map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_EQUALS ])
85+ )
6686 + list (map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_IN ]))
6787 + list (map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .EQUALS ]))
6888 + list (map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .IN ]))
@@ -428,9 +448,96 @@ def parse_filter_string(filter_string: str) -> ModelFilter:
428448 raise ValueError (f"Cannot parse filter string: { filter_string } " )
429449
430450
451+ def _negate_boolean (boolean : BooleanValues ) -> BooleanValues :
452+ """Negates boolean expression (False -> True, True -> False)."""
453+ if boolean == BooleanValues .TRUE :
454+ return BooleanValues .FALSE
455+ if boolean == BooleanValues .FALSE :
456+ return BooleanValues .TRUE
457+ return boolean
458+
459+
460+ def _evaluate_filter_expression_equals (
461+ model_filter : ModelFilter ,
462+ cached_model_value : Optional [Union [str , bool , int , float , Dict [str , Any ], List [Any ]]],
463+ ) -> BooleanValues :
464+ """Evaluates filter expressions for equals."""
465+ if cached_model_value is None :
466+ return BooleanValues .FALSE
467+ model_filter_value = model_filter .value
468+ if isinstance (cached_model_value , bool ):
469+ cached_model_value = str (cached_model_value ).lower ()
470+ model_filter_value = model_filter .value .lower ()
471+ if str (model_filter_value ) == str (cached_model_value ):
472+ return BooleanValues .TRUE
473+ return BooleanValues .FALSE
474+
475+
476+ def _evaluate_filter_expression_in (
477+ model_filter : ModelFilter ,
478+ cached_model_value : Optional [Union [str , bool , int , float , Dict [str , Any ], List [Any ]]],
479+ ) -> BooleanValues :
480+ """Evaluates filter expressions for string/list in."""
481+ if cached_model_value is None :
482+ return BooleanValues .FALSE
483+ py_obj = model_filter .value
484+ try :
485+ py_obj = literal_eval (py_obj )
486+ try :
487+ iter (py_obj )
488+ except TypeError :
489+ return BooleanValues .FALSE
490+ except Exception : # pylint: disable=W0703
491+ pass
492+ if isinstance (cached_model_value , list ):
493+ return BooleanValues .FALSE
494+ if cached_model_value in py_obj :
495+ return BooleanValues .TRUE
496+ return BooleanValues .FALSE
497+
498+
499+ def _evaluate_filter_expression_includes (
500+ model_filter : ModelFilter ,
501+ cached_model_value : Optional [Union [str , bool , int , float , Dict [str , Any ], List [Any ]]],
502+ ) -> BooleanValues :
503+ """Evaluates filter expressions for string includes."""
504+ if cached_model_value is None :
505+ return BooleanValues .FALSE
506+ filter_value = str (model_filter .value )
507+ if filter_value in cached_model_value :
508+ return BooleanValues .TRUE
509+ return BooleanValues .FALSE
510+
511+
512+ def _evaluate_filter_expression_begins_with (
513+ model_filter : ModelFilter ,
514+ cached_model_value : Optional [Union [str , bool , int , float , Dict [str , Any ], List [Any ]]],
515+ ) -> BooleanValues :
516+ """Evaluates filter expressions for string begins with."""
517+ if cached_model_value is None :
518+ return BooleanValues .FALSE
519+ filter_value = str (model_filter .value )
520+ if cached_model_value .startswith (filter_value ):
521+ return BooleanValues .TRUE
522+ return BooleanValues .FALSE
523+
524+
525+ def _evaluate_filter_expression_ends_with (
526+ model_filter : ModelFilter ,
527+ cached_model_value : Optional [Union [str , bool , int , float , Dict [str , Any ], List [Any ]]],
528+ ) -> BooleanValues :
529+ """Evaluates filter expressions for string ends with."""
530+ if cached_model_value is None :
531+ return BooleanValues .FALSE
532+ filter_value = str (model_filter .value )
533+ if cached_model_value .endswith (filter_value ):
534+ return BooleanValues .TRUE
535+ return BooleanValues .FALSE
536+
537+
431538def evaluate_filter_expression ( # pylint: disable=too-many-return-statements
432539 model_filter : ModelFilter ,
433- cached_model_value : Union [str , bool , int , float , Dict [str , Any ], List [Any ]],
540+ cached_model_value : Optional [ Union [str , bool , int , float , Dict [str , Any ], List [Any ] ]],
434541) -> BooleanValues :
435542 """Evaluates model filter with cached model spec value, returns boolean.
436543
@@ -440,36 +547,29 @@ def evaluate_filter_expression( # pylint: disable=too-many-return-statements
440547 evaluate the filter.
441548 """
442549 if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .EQUALS ]:
443- model_filter_value = model_filter .value
444- if isinstance (cached_model_value , bool ):
445- cached_model_value = str (cached_model_value ).lower ()
446- model_filter_value = model_filter .value .lower ()
447- if str (model_filter_value ) == str (cached_model_value ):
448- return BooleanValues .TRUE
449- return BooleanValues .FALSE
550+ return _evaluate_filter_expression_equals (model_filter , cached_model_value )
551+
450552 if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_EQUALS ]:
451- if isinstance (cached_model_value , bool ):
452- cached_model_value = str (cached_model_value ).lower ()
453- model_filter .value = model_filter .value .lower ()
454- if str (model_filter .value ) == str (cached_model_value ):
455- return BooleanValues .FALSE
456- return BooleanValues .TRUE
553+ return _negate_boolean (_evaluate_filter_expression_equals (model_filter , cached_model_value ))
554+
457555 if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .IN ]:
458- py_obj = literal_eval (model_filter .value )
459- try :
460- iter (py_obj )
461- except TypeError :
462- return BooleanValues .FALSE
463- if cached_model_value in py_obj :
464- return BooleanValues .TRUE
465- return BooleanValues .FALSE
556+ return _evaluate_filter_expression_in (model_filter , cached_model_value )
557+
466558 if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_IN ]:
467- py_obj = literal_eval (model_filter .value )
468- try :
469- iter (py_obj )
470- except TypeError :
471- return BooleanValues .TRUE
472- if cached_model_value in py_obj :
473- return BooleanValues .FALSE
474- return BooleanValues .TRUE
559+ return _negate_boolean (_evaluate_filter_expression_in (model_filter , cached_model_value ))
560+
561+ if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .INCLUDES ]:
562+ return _evaluate_filter_expression_includes (model_filter , cached_model_value )
563+
564+ if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_INCLUDES ]:
565+ return _negate_boolean (
566+ _evaluate_filter_expression_includes (model_filter , cached_model_value )
567+ )
568+
569+ if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .BEGINS_WITH ]:
570+ return _evaluate_filter_expression_begins_with (model_filter , cached_model_value )
571+
572+ if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .ENDS_WITH ]:
573+ return _evaluate_filter_expression_ends_with (model_filter , cached_model_value )
574+
475575 raise RuntimeError (f"Bad operator: { model_filter .operator } " )
0 commit comments