Skip to content

Commit 30fa9c5

Browse files
authored
Unit tests for HyperParameterJobTuning (#124)
Description of changes: - `sdk.go` coverage 71.3% - This coverage would be higher; however, there are fields in [HPO](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job) under `TrainingJobDefinition` that cannot be used there and are only accepted under `TrainingJobDefinitions`. Trying to use some of these fields results in `validationExceptions` that tell the user to use `TrainingJobDefinitions`. The service requires the user to supply values such as `HyperParameterRanges`, and you cannot include `ParameterRanges` when using only `TrainingJobDefinition`. This appears to be an oversight by the service team itself. - `hooks.go` coverage 100% - Custom code coverage >95% By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
1 parent cc0f45e commit 30fa9c5

31 files changed

+3338
-23
lines changed

pkg/resource/hyper_parameter_tuning_job/manager_test_suite_test.go

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,19 @@ import (
2020
ackv1alpha1 "github.com/aws-controllers-k8s/runtime/apis/core/v1alpha1"
2121
ackmetrics "github.com/aws-controllers-k8s/runtime/pkg/metrics"
2222
acktypes "github.com/aws-controllers-k8s/runtime/pkg/types"
23+
svcapitypes "github.com/aws-controllers-k8s/sagemaker-controller/apis/v1alpha1"
2324
"github.com/aws-controllers-k8s/sagemaker-controller/pkg/testutil"
2425
mocksvcsdkapi "github.com/aws-controllers-k8s/sagemaker-controller/test/mocks/aws-sdk-go/sagemaker"
2526
svcsdk "github.com/aws/aws-sdk-go/service/sagemaker"
27+
28+
"path/filepath"
29+
"testing"
30+
2631
"github.com/ghodss/yaml"
2732
"github.com/google/go-cmp/cmp"
2833
"github.com/google/go-cmp/cmp/cmpopts"
2934
"go.uber.org/zap/zapcore"
30-
"path/filepath"
3135
ctrlrtzap "sigs.k8s.io/controller-runtime/pkg/log/zap"
32-
"testing"
3336
)
3437

3538
// provideResourceManagerWithMockSDKAPI accepts MockSageMakerAPI and returns pointer to resourceManager
@@ -100,15 +103,31 @@ func (d *testRunnerDelegate) Equal(a acktypes.AWSResource, b acktypes.AWSResourc
100103
ac := a.(*resource)
101104
bc := b.(*resource)
102105
// Ignore LastTransitionTime since it gets updated each run.
103-
opts := []cmp.Option{cmpopts.EquateEmpty(), cmpopts.IgnoreFields(ackv1alpha1.Condition{}, "LastTransitionTime")}
106+
opts := []cmp.Option{cmpopts.EquateEmpty(),
107+
cmpopts.IgnoreFields(ackv1alpha1.Condition{}, "LastTransitionTime"),
108+
cmpopts.IgnoreFields(svcapitypes.HyperParameterTrainingJobSummary{}, "CreationTime"),
109+
cmpopts.IgnoreFields(svcapitypes.HyperParameterTrainingJobSummary{}, "TrainingStartTime"),
110+
cmpopts.IgnoreFields(svcapitypes.HyperParameterTrainingJobSummary{}, "TrainingEndTime")}
104111

112+
var specMatch = false
113+
if cmp.Equal(ac.ko.Spec, bc.ko.Spec, opts...) {
114+
specMatch = true
115+
} else {
116+
fmt.Printf("Difference ko.Spec (-expected +actual):\n\n")
117+
fmt.Println(cmp.Diff(ac.ko.Spec, bc.ko.Spec, opts...))
118+
specMatch = false
119+
}
120+
121+
var statusMatch = false
105122
if cmp.Equal(ac.ko.Status, bc.ko.Status, opts...) {
106-
return true
123+
statusMatch = true
107124
} else {
108-
fmt.Printf("Difference (-expected +actual):\n\n")
125+
fmt.Printf("Difference ko.Status (-expected +actual):\n\n")
109126
fmt.Println(cmp.Diff(ac.ko.Status, bc.ko.Status, opts...))
110-
return false
127+
statusMatch = false
111128
}
129+
130+
return statusMatch && specMatch
112131
}
113132

114133
// Checks to see if the given yaml file, with name stored as expectation,

pkg/resource/hyper_parameter_tuning_job/testdata/hyper_parameter_tuning_job/v1alpha1/hptj_invalid_before_create.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ spec:
2323
staticHyperParameters:
2424
base_score: '0.5'
2525
algorithmSpecification:
26-
trainingImage: 246618743249.dkr.ecr.us-west-2.amazonaws.com
26+
trainingImage: 433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1
2727
trainingInputMode: File
2828
roleARN: arn:aws:iam::123456789012:role/ack-sagemaker-execution-role
2929
inputDataConfig:

pkg/resource/hyper_parameter_tuning_job/testdata/hyper_parameter_tuning_job/v1alpha1/hptj_invalid_create_attempted.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ spec:
2222
hyperParameterTuningJobName: intentionally@invalid-name
2323
trainingJobDefinition:
2424
algorithmSpecification:
25-
trainingImage: 246618743249.dkr.ecr.us-west-2.amazonaws.com
25+
trainingImage: 433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1
2626
trainingInputMode: File
2727
staticHyperParameters:
2828
base_score: "0.5"
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"HyperParameterTuningJobArn": "arn:aws:sagemaker:us-west-2:123456789012:hyperparameter-tuning-job/unit-testing-hyper-parameter-tuning-job"
3+
}
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
{
2+
"BestTrainingJob": null,
3+
"CreationTime": "2021-10-18T04:21:01.477Z",
4+
"FailureReason": null,
5+
"HyperParameterTuningEndTime": null,
6+
"HyperParameterTuningJobArn": "arn:aws:sagemaker:us-west-2:123456789012:hyper-parameter-tuning-job/unit-testing-hpo-job",
7+
"HyperParameterTuningJobConfig": {
8+
"HyperParameterTuningJobObjective": {
9+
"MetricName": "validation:error",
10+
"Type": "Minimize"
11+
},
12+
"ParameterRanges": {
13+
"CategoricalParameterRanges": [
14+
{
15+
"Name": "category",
16+
"Values": [
17+
"test"
18+
]
19+
}
20+
],
21+
"ContinuousParameterRanges": [
22+
{
23+
"MaxValue": "5",
24+
"MinValue": "0",
25+
"Name": "gamma",
26+
"ScalingType": "Linear"
27+
}
28+
],
29+
"IntegerParameterRanges": [
30+
{
31+
"MaxValue": "20",
32+
"MinValue": "10",
33+
"Name": "num_round",
34+
"ScalingType": "Linear"
35+
}
36+
]
37+
},
38+
"ResourceLimits": {
39+
"MaxNumberOfTrainingJobs": 2,
40+
"MaxParallelTrainingJobs": 1
41+
},
42+
"Strategy": "Bayesian",
43+
"TrainingJobEarlyStoppingType": "Auto",
44+
"TuningJobCompletionCriteria": null
45+
},
46+
"HyperParameterTuningJobName": "unit-testing-hpo-job",
47+
"HyperParameterTuningJobStatus": "InProgress",
48+
"LastModifiedTime": "2021-10-18T04:21:01.477Z",
49+
"ObjectiveStatusCounters": {
50+
"Failed": 0,
51+
"Pending": 0,
52+
"Succeeded": 0
53+
},
54+
"OverallBestTrainingJob": null,
55+
"TrainingJobDefinition": {
56+
"AlgorithmSpecification": {
57+
"AlgorithmName": null,
58+
"MetricDefinitions": [
59+
{
60+
"Name": "train:mae",
61+
"Regex": ".*\\[[0-9]+\\].*#011train-mae:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
62+
},
63+
{
64+
"Name": "validation:auc",
65+
"Regex": ".*\\[[0-9]+\\].*#011validation-auc:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
66+
},
67+
{
68+
"Name": "train:merror",
69+
"Regex": ".*\\[[0-9]+\\].*#011train-merror:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
70+
},
71+
{
72+
"Name": "train:auc",
73+
"Regex": ".*\\[[0-9]+\\].*#011train-auc:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
74+
},
75+
{
76+
"Name": "validation:mae",
77+
"Regex": ".*\\[[0-9]+\\].*#011validation-mae:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
78+
},
79+
{
80+
"Name": "validation:error",
81+
"Regex": ".*\\[[0-9]+\\].*#011validation-error:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
82+
},
83+
{
84+
"Name": "validation:merror",
85+
"Regex": ".*\\[[0-9]+\\].*#011validation-merror:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
86+
},
87+
{
88+
"Name": "validation:logloss",
89+
"Regex": ".*\\[[0-9]+\\].*#011validation-logloss:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
90+
},
91+
{
92+
"Name": "train:rmse",
93+
"Regex": ".*\\[[0-9]+\\].*#011train-rmse:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
94+
},
95+
{
96+
"Name": "train:logloss",
97+
"Regex": ".*\\[[0-9]+\\].*#011train-logloss:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
98+
},
99+
{
100+
"Name": "train:mlogloss",
101+
"Regex": ".*\\[[0-9]+\\].*#011train-mlogloss:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
102+
},
103+
{
104+
"Name": "validation:rmse",
105+
"Regex": ".*\\[[0-9]+\\].*#011validation-rmse:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
106+
},
107+
{
108+
"Name": "validation:ndcg",
109+
"Regex": ".*\\[[0-9]+\\].*#011validation-ndcg:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
110+
},
111+
{
112+
"Name": "train:error",
113+
"Regex": ".*\\[[0-9]+\\].*#011train-error:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
114+
},
115+
{
116+
"Name": "validation:mlogloss",
117+
"Regex": ".*\\[[0-9]+\\].*#011validation-mlogloss:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
118+
},
119+
{
120+
"Name": "train:ndcg",
121+
"Regex": ".*\\[[0-9]+\\].*#011train-ndcg:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
122+
},
123+
{
124+
"Name": "train:map",
125+
"Regex": ".*\\[[0-9]+\\].*#011train-map:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
126+
},
127+
{
128+
"Name": "validation:map",
129+
"Regex": ".*\\[[0-9]+\\].*#011validation-map:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
130+
},
131+
{
132+
"Name": "ObjectiveMetric",
133+
"Regex": ".*\\[[0-9]+\\].*#011validation-error:([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*"
134+
}
135+
],
136+
"TrainingImage": "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1",
137+
"TrainingInputMode": "File"
138+
},
139+
"CheckpointConfig": null,
140+
"DefinitionName": null,
141+
"EnableInterContainerTrafficEncryption": false,
142+
"EnableManagedSpotTraining": false,
143+
"EnableNetworkIsolation": true,
144+
"HyperParameterRanges": null,
145+
"InputDataConfig": [
146+
{
147+
"ChannelName": "train",
148+
"CompressionType": "None",
149+
"ContentType": "text/csv",
150+
"DataSource": {
151+
"FileSystemDataSource": null,
152+
"S3DataSource": {
153+
"AttributeNames": null,
154+
"S3DataDistributionType": "FullyReplicated",
155+
"S3DataType": "S3Prefix",
156+
"S3Uri": "s3://source-data-bucket-592697580195-us-west-2/sagemaker/training/train"
157+
}
158+
},
159+
"InputMode": "File",
160+
"RecordWrapperType": "None",
161+
"ShuffleConfig": null
162+
},
163+
{
164+
"ChannelName": "validation",
165+
"CompressionType": "None",
166+
"ContentType": "text/csv",
167+
"DataSource": {
168+
"FileSystemDataSource": null,
169+
"S3DataSource": {
170+
"AttributeNames": null,
171+
"S3DataDistributionType": "FullyReplicated",
172+
"S3DataType": "S3Prefix",
173+
"S3Uri": "s3://source-data-bucket-592697580195-us-west-2/sagemaker/training/validation/"
174+
}
175+
},
176+
"InputMode": "File",
177+
"RecordWrapperType": "None",
178+
"ShuffleConfig": null
179+
}
180+
],
181+
"OutputDataConfig": {
182+
"KmsKeyId": null,
183+
"S3OutputPath": "s3://source-data-bucket-592697580195-us-west-2/sagemaker/hpo/output"
184+
},
185+
"ResourceConfig": {
186+
"InstanceCount": 1,
187+
"InstanceType": "ml.m5.large",
188+
"VolumeKmsKeyId": null,
189+
"VolumeSizeInGB": 25
190+
},
191+
"RoleArn": "arn:aws:iam::123456789012:role/service-role/AmazonSageMaker-ExecutionRole-20210920T111639",
192+
"StaticHyperParameters": {
193+
"_tuning_objective_metric": "validation:error",
194+
"base_score": "0.5"
195+
},
196+
"StoppingCondition": {
197+
"MaxRuntimeInSeconds": 3600,
198+
"MaxWaitTimeInSeconds": null
199+
},
200+
"TuningObjective": null,
201+
"VpcConfig": null
202+
},
203+
"TrainingJobDefinitions": null,
204+
"TrainingJobStatusCounters": {
205+
"Completed": 0,
206+
"InProgress": 0,
207+
"NonRetryableError": 0,
208+
"RetryableError": 0,
209+
"Stopped": 0
210+
},
211+
"WarmStartConfig": null
212+
}

0 commit comments

Comments
 (0)