Skip to content

Commit a121057

Browse files
authored
Add Search Expression support (#70)
* Add Search Expression class to help with search * Remove a test that is causing throttling * lower search result to prevent throttling * Use search expression to prevent throttling * fix filter name * address comments
1 parent fbffb33 commit a121057

File tree

9 files changed

+242
-12
lines changed

9 files changed

+242
-12
lines changed

src/smexperiments/experiment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def search(
170170
return super(Experiment, cls)._search(
171171
search_resource="Experiment",
172172
search_item_factory=api_types.ExperimentSearchResult.from_boto,
173-
search_expression=search_expression,
173+
search_expression=None if search_expression is None else search_expression.to_boto(),
174174
sort_by=sort_by,
175175
sort_order=sort_order,
176176
max_results=max_results,
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Simplify Search Expression by provide a simplified DSL"""
14+
from smexperiments._base_types import ApiObject
15+
from enum import Enum, unique
16+
17+
18+
@unique
19+
class Operator(Enum):
20+
EQUALS = "Equals"
21+
NOT_EQUALS = "NotEquals"
22+
GREATER_THAN = "GreaterThan"
23+
GREATER_THAN_OR_EQUAL = "GreaterThanOrEqualTo"
24+
LESS_THAN = "LessThan"
25+
LESS_THAN_OR_EQUAL = "LessThanOrEqualTo"
26+
CONTAINS = "Contains"
27+
EXISTS = "Exists"
28+
NOT_EXISTS = "NotExists"
29+
30+
31+
@unique
32+
class BooleanOperator(Enum):
33+
AND = "And"
34+
OR = "Or"
35+
36+
37+
class SearchObject(ApiObject):
38+
def to_boto(self):
39+
return ApiObject.to_boto(self)
40+
41+
42+
class Filter(SearchObject):
43+
"""
44+
A Python class represent a Search Filter object.
45+
"""
46+
47+
name = None
48+
operator = None
49+
value = None
50+
51+
def __init__(self, name, operator=None, value=None):
52+
"""Construct a Filter object
53+
54+
Args:
55+
name (str): filter field name
56+
operator (dict): one of Operator enum
57+
value (str): value of the field
58+
"""
59+
self.name = name
60+
self.operator = None if operator is None else operator.value
61+
self.value = value
62+
63+
64+
class NestedFilter(SearchObject):
65+
"""
66+
A Python class represent a Nested Filter object.
67+
"""
68+
69+
nested_property_name = None
70+
filters = None
71+
72+
def __init__(self, property_name, filters):
73+
"""Construct a Nested Filter object
74+
75+
Args:
76+
property_name (str): nested property name
77+
filters (list): list of Filter objects
78+
"""
79+
self.nested_property_name = property_name
80+
self.filters = list(map(lambda x: x.to_boto(), filters))
81+
82+
83+
class SearchExpression(SearchObject):
84+
"""
85+
A Python class representation of a Search Expression object. A sample search expression defined in here:
86+
https://boto3.amazonaws.com/v1/documentation/api/1.12.8/reference/services/sagemaker.html#SageMaker.Client.search
87+
"""
88+
89+
filters = None
90+
nested_filters = None
91+
operator = None
92+
sub_expressions = None
93+
94+
def __init__(self, filters=None, nested_filters=None, sub_expressions=None, boolean_operator=BooleanOperator.AND):
95+
"""Construct a Search Expression object
96+
97+
Args:
98+
filters (list): list of Filter objects
99+
nested_filters (list): list of Nested Filters objects
100+
sub_expressions (list): list of Search Expresssion objects
101+
operator (dict): one of the boolean operator enums
102+
"""
103+
if filters is None and nested_filters is None and sub_expressions is None:
104+
raise ValueError("You must specify at least one subexpression, filter, or nested filter")
105+
self.filters = None if filters is None else list(map(lambda x: x.to_boto(), filters))
106+
self.nested_filters = None if nested_filters is None else list(map(lambda x: x.to_boto(), nested_filters))
107+
self.sub_expressions = None if sub_expressions is None else list(map(lambda x: x.to_boto(), sub_expressions))
108+
self.operator = boolean_operator.value

src/smexperiments/trial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def search(
189189
return super(Trial, cls)._search(
190190
search_resource="ExperimentTrial",
191191
search_item_factory=api_types.TrialSearchResult.from_boto,
192-
search_expression=search_expression,
192+
search_expression=None if search_expression is None else search_expression.to_boto(),
193193
sort_by=sort_by,
194194
sort_order=sort_order,
195195
max_results=max_results,

src/smexperiments/trial_component.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def search(
205205
return super(TrialComponent, cls)._search(
206206
search_resource="ExperimentTrialComponent",
207207
search_item_factory=api_types.TrialComponentSearchResult.from_boto,
208-
search_expression=search_expression,
208+
search_expression=None if search_expression is None else search_expression.to_boto(),
209209
sort_by=sort_by,
210210
sort_order=sort_order,
211211
max_results=max_results,

tests/integ/test_experiment.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from tests.helpers import name
1616
from smexperiments import experiment, trial
17+
from smexperiments.search_expression import SearchExpression, Filter, Operator
1718

1819

1920
def test_create_delete(experiment_obj):
@@ -87,9 +88,12 @@ def test_list_sort(sagemaker_boto_client, experiments):
8788

8889
def test_search(sagemaker_boto_client):
8990
experiment_names_searched = []
90-
for s in experiment.Experiment.search(max_results=10, sagemaker_boto_client=sagemaker_boto_client):
91-
if "smexperiments-integ-" in s.experiment_name:
92-
experiment_names_searched.append(s.experiment_name)
91+
search_filter = Filter(name="ExperimentName", operator=Operator.CONTAINS, value="smexperiments-integ-")
92+
search_expression = SearchExpression(filters=[search_filter])
93+
for s in experiment.Experiment.search(
94+
search_expression=search_expression, max_results=10, sagemaker_boto_client=sagemaker_boto_client
95+
):
96+
experiment_names_searched.append(s.experiment_name)
9397

9498
assert len(experiment_names_searched) > 0
9599
assert experiment_names_searched # sanity test
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the 'License'). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the 'license' file accompanying this file. This file is
10+
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from smexperiments.search_expression import Filter, Operator, SearchExpression, NestedFilter
14+
from smexperiments.experiment import Experiment
15+
import pytest
16+
17+
18+
def test_search(sagemaker_boto_client):
19+
experiment_names_searched = []
20+
search_filter = Filter(name="ExperimentName", operator=Operator.CONTAINS, value="smexperiments-integ-")
21+
search_expression = SearchExpression(filters=[search_filter])
22+
for s in Experiment.search(
23+
search_expression=search_expression, max_results=10, sagemaker_boto_client=sagemaker_boto_client
24+
):
25+
experiment_names_searched.append(s.experiment_name)
26+
27+
assert len(experiment_names_searched) > 0
28+
assert experiment_names_searched # sanity test
29+
30+
31+
@pytest.mark.skip(reason="failed validation, need to wait for NestedFilter bug to be fixed")
32+
def test_nested_search(sagemaker_boto_client):
33+
experiment_names_searched = []
34+
search_filter = Filter(name="ExperimentName", operator=Operator.CONTAINS, value="smexperiments-integ-")
35+
nested_filter = NestedFilter(property_name="ExperimentName", filters=[search_filter])
36+
search_expression = SearchExpression(nested_filters=[nested_filter])
37+
for s in Experiment.search(
38+
search_expression=search_expression, max_results=10, sagemaker_boto_client=sagemaker_boto_client
39+
):
40+
experiment_names_searched.append(s.experiment_name)
41+
42+
assert len(experiment_names_searched) > 0
43+
assert experiment_names_searched # sanity test
44+
45+
46+
def test_sub_expression(sagemaker_boto_client):
47+
experiment_names_searched = []
48+
search_filter = Filter(name="ExperimentName", operator=Operator.CONTAINS, value="smexperiments-integ-")
49+
sub_expression = SearchExpression(filters=[search_filter])
50+
search_expression = SearchExpression(sub_expressions=[sub_expression])
51+
for s in Experiment.search(
52+
search_expression=search_expression, max_results=10, sagemaker_boto_client=sagemaker_boto_client
53+
):
54+
experiment_names_searched.append(s.experiment_name)
55+
56+
assert len(experiment_names_searched) > 0
57+
assert experiment_names_searched # sanity test

tests/integ/test_trial.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import datetime
1515

1616
from smexperiments import trial
17+
from smexperiments.search_expression import SearchExpression, Filter, Operator
1718

1819

1920
def test_create_delete(trial_obj):
@@ -78,9 +79,12 @@ def test_list_sort(trials, sagemaker_boto_client):
7879

7980
def test_search(sagemaker_boto_client):
8081
trial_names_searched = []
81-
for s in trial.Trial.search(max_results=10, sagemaker_boto_client=sagemaker_boto_client):
82-
if "smexperiments-integ-" in s.trial_name:
83-
trial_names_searched.append(s.trial_name)
82+
search_filter = Filter(name="ExperimentName", operator=Operator.CONTAINS, value="smexperiments-integ-")
83+
search_expression = SearchExpression(filters=[search_filter])
84+
for s in trial.Trial.search(
85+
search_expression=search_expression, max_results=10, sagemaker_boto_client=sagemaker_boto_client
86+
):
87+
trial_names_searched.append(s.trial_name)
8488

8589
assert len(trial_names_searched) > 0
8690
assert trial_names_searched # sanity test

tests/integ/test_trial_component.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import uuid
1616

1717
from smexperiments import api_types, trial_component
18+
from smexperiments.search_expression import Filter, SearchExpression, Operator
1819

1920

2021
def test_create_delete(trial_component_obj):
@@ -103,9 +104,12 @@ def test_list_trial_components_by_experiment(experiment_obj, trial_component_obj
103104

104105
def test_search(sagemaker_boto_client):
105106
trial_component_names_searched = []
106-
for s in trial_component.TrialComponent.search(max_results=10, sagemaker_boto_client=sagemaker_boto_client):
107-
if "smexperiments-integ-" in s.trial_component_name:
108-
trial_component_names_searched.append(s.trial_component_name)
107+
search_filter = Filter(name="TrialComponentName", operator=Operator.CONTAINS, value="smexperiments-integ-")
108+
search_expression = SearchExpression(filters=[search_filter])
109+
for s in trial_component.TrialComponent.search(
110+
search_expression=search_expression, max_results=10, sagemaker_boto_client=sagemaker_boto_client
111+
):
112+
trial_component_names_searched.append(s.trial_component_name)
109113

110114
assert len(trial_component_names_searched) > 0
111115
assert trial_component_names_searched # sanity test
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from smexperiments.search_expression import Filter, Operator, NestedFilter, SearchExpression, BooleanOperator
2+
import pytest
3+
4+
5+
def test_filters():
6+
search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1")
7+
8+
assert {"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"} == search_filter.to_boto()
9+
10+
11+
def test_partial_filters():
12+
search_filter = Filter(name="learning_rate")
13+
14+
assert {"Name": "learning_rate"} == search_filter.to_boto()
15+
16+
17+
def test_nested_filters():
18+
search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1")
19+
filters = [search_filter]
20+
nested_filters = NestedFilter(property_name="hyper_param", filters=filters)
21+
22+
assert {
23+
"Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}],
24+
"NestedPropertyName": "hyper_param",
25+
} == nested_filters.to_boto()
26+
27+
28+
def test_search_expression():
29+
search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1")
30+
nested_filter = NestedFilter(property_name="hyper_param", filters=[search_filter])
31+
search_expression = SearchExpression(
32+
filters=[search_filter],
33+
nested_filters=[nested_filter],
34+
sub_expressions=[],
35+
boolean_operator=BooleanOperator.AND,
36+
)
37+
38+
assert {
39+
"Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}],
40+
"NestedFilters": [
41+
{
42+
"Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}],
43+
"NestedPropertyName": "hyper_param",
44+
}
45+
],
46+
"SubExpressions": [],
47+
"Operator": "And",
48+
} == search_expression.to_boto()
49+
50+
51+
def test_illegal_search_expression():
52+
with pytest.raises(ValueError, match="You must specify at least one subexpression, filter, or nested filter"):
53+
SearchExpression()

0 commit comments

Comments
 (0)