Skip to content

Commit b028346

Browse files
authored
fix: Use server side defaults for Train definitions (#177)
Issue: We [do not](https://github.com/aws-controllers-k8s/sagemaker-controller/blob/3996ba037349bb0b65088341afdfd03cf657e09c/generator.yaml#L326) late initialize some parameters in TrainingJobDefinitions like we do in TrainingJob defintion. As a result the controller will infinitely requeue if all parameters are not explicity specified(because the server sends back the default values it uses). Description of changes: `pkg/resource/hyper_parameter_tuning_job/custom_delta.go` - Sets some parameters to their server side default. `pkg/resource/hyper_parameter_tuning_job/testdata/v1alpha1/readone/observed/completed_variation.yaml` - Modified unit test. CRD I used to test: ``` apiVersion: sagemaker.services.k8s.aws/v1alpha1 kind: HyperParameterTuningJob metadata: name: 2022-10-31-hpo-3 spec: hyperParameterTuningJobName: 2022-10-31-hpo-3 hyperParameterTuningJobConfig: strategy: Bayesian resourceLimits: maxNumberOfTrainingJobs: 2 maxParallelTrainingJobs: 1 trainingJobEarlyStoppingType: Auto trainingJobDefinitions: - staticHyperParameters: base_score: '0.5' definitionName: training-job-for-hpo algorithmSpecification: trainingImage: 433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1 trainingInputMode: File roleARN: <arn> tuningObjective: type_: Minimize metricName: validation:error hyperParameterRanges: integerParameterRanges: - name: num_round minValue: '10' maxValue: '20' scalingType: Linear continuousParameterRanges: - name: gamma minValue: '0' maxValue: '5' scalingType: Linear inputDataConfig: - channelName: train dataSource: s3DataSource: s3DataType: S3Prefix s3URI: <train> s3DataDistributionType: FullyReplicated contentType: text/libsvm compressionType: None recordWrapperType: None inputMode: File - channelName: validation dataSource: s3DataSource: s3DataType: S3Prefix s3URI: <validation> s3DataDistributionType: FullyReplicated contentType: text/libsvm compressionType: None recordWrapperType: None inputMode: File outputDataConfig: s3OutputPath: <output> resourceConfig: instanceType: ml.m5.large instanceCount: 1 volumeSizeInGB: 25 stoppingCondition: maxRuntimeInSeconds: 3600 enableNetworkIsolation: true enableInterContainerTrafficEncryption: false tags: - key: algorithm value: xgboost - key: environment value: testing - key: customer value: test-user ``` By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
1 parent 3996ba0 commit b028346

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

pkg/resource/hyper_parameter_tuning_job/custom_delta.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,35 @@ func customSetDefaults(
3232
}
3333
}
3434
}
35+
36+
// TODO: Use late initialize instead once code generator supports late initializing slices.
37+
if ackcompare.IsNotNil(a.ko.Spec.TrainingJobDefinitions) && ackcompare.IsNotNil(b.ko.Spec.TrainingJobDefinitions) {
38+
if len(a.ko.Spec.TrainingJobDefinitions) == len(b.ko.Spec.TrainingJobDefinitions) {
39+
for index := range a.ko.Spec.TrainingJobDefinitions {
40+
latestStaticHyperParameters := b.ko.Spec.TrainingJobDefinitions[index].StaticHyperParameters
41+
if ackcompare.IsNotNil(latestStaticHyperParameters) {
42+
for key, _ := range latestStaticHyperParameters {
43+
if key[0:1] == "_" {
44+
delete(b.ko.Spec.TrainingJobDefinitions[index].StaticHyperParameters, key)
45+
}
46+
}
47+
}
48+
if ackcompare.IsNotNil(a.ko.Spec.TrainingJobDefinitions[index].AlgorithmSpecification) && ackcompare.IsNotNil(b.ko.Spec.TrainingJobDefinitions[index].AlgorithmSpecification) {
49+
if ackcompare.IsNil(a.ko.Spec.TrainingJobDefinitions[index].AlgorithmSpecification.MetricDefinitions) && ackcompare.IsNotNil(b.ko.Spec.TrainingJobDefinitions[index].AlgorithmSpecification.MetricDefinitions) {
50+
a.ko.Spec.TrainingJobDefinitions[index].AlgorithmSpecification.MetricDefinitions = b.ko.Spec.TrainingJobDefinitions[index].AlgorithmSpecification.MetricDefinitions
51+
}
52+
}
53+
if ackcompare.IsNil(a.ko.Spec.TrainingJobDefinitions[index].EnableInterContainerTrafficEncryption) && ackcompare.IsNotNil(b.ko.Spec.TrainingJobDefinitions[index].EnableInterContainerTrafficEncryption) {
54+
a.ko.Spec.TrainingJobDefinitions[index].EnableInterContainerTrafficEncryption = b.ko.Spec.TrainingJobDefinitions[index].EnableInterContainerTrafficEncryption
55+
}
56+
if ackcompare.IsNil(a.ko.Spec.TrainingJobDefinitions[index].EnableManagedSpotTraining) && ackcompare.IsNotNil(b.ko.Spec.TrainingJobDefinitions[index].EnableManagedSpotTraining) {
57+
a.ko.Spec.TrainingJobDefinitions[index].EnableManagedSpotTraining = b.ko.Spec.TrainingJobDefinitions[index].EnableManagedSpotTraining
58+
}
59+
if ackcompare.IsNil(a.ko.Spec.TrainingJobDefinitions[index].EnableNetworkIsolation) && ackcompare.IsNotNil(b.ko.Spec.TrainingJobDefinitions[index].EnableNetworkIsolation) {
60+
a.ko.Spec.TrainingJobDefinitions[index].EnableNetworkIsolation = b.ko.Spec.TrainingJobDefinitions[index].EnableNetworkIsolation
61+
}
62+
}
63+
}
64+
}
65+
3566
}

pkg/resource/hyper_parameter_tuning_job/testdata/v1alpha1/readone/observed/completed_variation.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ spec:
109109
volumeSizeInGB: 25
110110
roleARN: arn:aws:iam::123456789012:role/service-role/AmazonSageMaker-ExecutionRole-20210920T111639
111111
staticHyperParameters:
112-
_tuning_objective_metric: validation:error
113112
base_score: "0.5"
114113
stoppingCondition:
115114
maxRuntimeInSeconds: 3600

0 commit comments

Comments
 (0)