Skip to content

Commit 2d32c3f

Browse files
authored
Merge pull request #465 from ndeloof/decode_mapstructure
introduce DecodeMapstructure to allow type to define custom decode logic
2 parents 588d586 + 54b9780 commit 2d32c3f

File tree

14 files changed

+604
-269
lines changed

14 files changed

+604
-269
lines changed

loader/interpolate.go

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@ var interpolateTypeCastMapping = map[tree.Path]interp.Cast{
4646
servicePath("deploy", "placement", "max_replicas_per_node"): toInt,
4747
servicePath("healthcheck", "retries"): toInt,
4848
servicePath("healthcheck", "disable"): toBoolean,
49-
servicePath("mem_limit"): toUnitBytes,
50-
servicePath("mem_reservation"): toUnitBytes,
51-
servicePath("memswap_limit"): toUnitBytes,
52-
servicePath("mem_swappiness"): toUnitBytes,
5349
servicePath("oom_kill_disable"): toBoolean,
5450
servicePath("oom_score_adj"): toInt64,
5551
servicePath("pids_limit"): toInt64,
@@ -58,16 +54,13 @@ var interpolateTypeCastMapping = map[tree.Path]interp.Cast{
5854
servicePath("read_only"): toBoolean,
5955
servicePath("scale"): toInt,
6056
servicePath("secrets", tree.PathMatchList, "mode"): toInt,
61-
servicePath("shm_size"): toUnitBytes,
6257
servicePath("stdin_open"): toBoolean,
63-
servicePath("stop_grace_period"): toDuration,
6458
servicePath("tty"): toBoolean,
6559
servicePath("ulimits", tree.PathMatchAll): toInt,
6660
servicePath("ulimits", tree.PathMatchAll, "hard"): toInt,
6761
servicePath("ulimits", tree.PathMatchAll, "soft"): toInt,
6862
servicePath("volumes", tree.PathMatchList, "read_only"): toBoolean,
6963
servicePath("volumes", tree.PathMatchList, "volume", "nocopy"): toBoolean,
70-
servicePath("volumes", tree.PathMatchList, "tmpfs", "size"): toUnitBytes,
7164
iPath("networks", tree.PathMatchAll, "external"): toBoolean,
7265
iPath("networks", tree.PathMatchAll, "internal"): toBoolean,
7366
iPath("networks", tree.PathMatchAll, "attachable"): toBoolean,
@@ -93,14 +86,6 @@ func toInt64(value string) (interface{}, error) {
9386
return strconv.ParseInt(value, 10, 64)
9487
}
9588

96-
func toUnitBytes(value string) (interface{}, error) {
97-
return transformSize(value)
98-
}
99-
100-
func toDuration(value string) (interface{}, error) {
101-
return transformStringToDuration(value)
102-
}
103-
10489
func toFloat(value string) (interface{}, error) {
10590
return strconv.ParseFloat(value, 64)
10691
}

loader/loader.go

Lines changed: 3 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,12 @@ import (
2828
"regexp"
2929
"strconv"
3030
"strings"
31-
"time"
3231

3332
"github.com/compose-spec/compose-go/consts"
3433
interp "github.com/compose-spec/compose-go/interpolation"
3534
"github.com/compose-spec/compose-go/schema"
3635
"github.com/compose-spec/compose-go/template"
3736
"github.com/compose-spec/compose-go/types"
38-
"github.com/docker/go-units"
39-
"github.com/mattn/go-shellwords"
4037
"github.com/mitchellh/mapstructure"
4138
"github.com/pkg/errors"
4239
"github.com/sirupsen/logrus"
@@ -586,7 +583,7 @@ func Transform(source interface{}, target interface{}, additionalTransformers ..
586583
config := &mapstructure.DecoderConfig{
587584
DecodeHook: mapstructure.ComposeDecodeHookFunc(
588585
createTransformHook(additionalTransformers...),
589-
mapstructure.StringToTimeDurationHookFunc()),
586+
decoderHook),
590587
Result: target,
591588
TagName: "yaml",
592589
Metadata: &data,
@@ -610,28 +607,20 @@ type Transformer struct {
610607
func createTransformHook(additionalTransformers ...Transformer) mapstructure.DecodeHookFuncType {
611608
transforms := map[reflect.Type]func(interface{}) (interface{}, error){
612609
reflect.TypeOf(types.External{}): transformExternal,
613-
reflect.TypeOf(types.HealthCheckTest{}): transformHealthCheckTest,
614-
reflect.TypeOf(types.ShellCommand{}): transformShellCommand,
615-
reflect.TypeOf(types.StringList{}): transformStringList,
616-
reflect.TypeOf(types.Options{}): transformMapStringString,
610+
reflect.TypeOf(types.Options{}): transformOptions,
617611
reflect.TypeOf(types.UlimitsConfig{}): transformUlimits,
618-
reflect.TypeOf(types.UnitBytes(0)): transformSize,
619612
reflect.TypeOf([]types.ServicePortConfig{}): transformServicePort,
620613
reflect.TypeOf(types.ServiceSecretConfig{}): transformFileReferenceConfig,
621614
reflect.TypeOf(types.ServiceConfigObjConfig{}): transformFileReferenceConfig,
622-
reflect.TypeOf(types.StringOrNumberList{}): transformStringOrNumberList,
623615
reflect.TypeOf(map[string]*types.ServiceNetworkConfig{}): transformServiceNetworkMap,
624616
reflect.TypeOf(types.Mapping{}): transformMappingOrListFunc("=", false),
625617
reflect.TypeOf(types.MappingWithEquals{}): transformMappingOrListFunc("=", true),
626-
reflect.TypeOf(types.Labels{}): transformMappingOrListFunc("=", false),
627618
reflect.TypeOf(types.MappingWithColon{}): transformMappingOrListFunc(":", false),
628619
reflect.TypeOf(types.HostsList{}): transformMappingOrListFunc(":", false),
629620
reflect.TypeOf(types.ServiceVolumeConfig{}): transformServiceVolumeConfig,
630621
reflect.TypeOf(types.BuildConfig{}): transformBuildConfig,
631-
reflect.TypeOf(types.Duration(0)): transformStringToDuration,
632622
reflect.TypeOf(types.DependsOnConfig{}): transformDependsOnConfig,
633623
reflect.TypeOf(types.ExtendsConfig{}): transformExtendsConfig,
634-
reflect.TypeOf(types.DeviceRequest{}): transformServiceDeviceRequest,
635624
reflect.TypeOf(types.SSHConfig{}): transformSSHConfig,
636625
reflect.TypeOf(types.IncludeConfig{}): transformIncludeConfig,
637626
}
@@ -1031,7 +1020,7 @@ func loadFileObjectConfig(name string, objType string, obj types.FileObjectConfi
10311020
return obj, nil
10321021
}
10331022

1034-
var transformMapStringString TransformerFunc = func(data interface{}) (interface{}, error) {
1023+
var transformOptions TransformerFunc = func(data interface{}) (interface{}, error) {
10351024
switch value := data.(type) {
10361025
case map[string]interface{}:
10371026
return toMapStringString(value, false), nil
@@ -1094,35 +1083,6 @@ var transformServicePort TransformerFunc = func(data interface{}) (interface{},
10941083
}
10951084
}
10961085

1097-
var transformServiceDeviceRequest TransformerFunc = func(data interface{}) (interface{}, error) {
1098-
switch value := data.(type) {
1099-
case map[string]interface{}:
1100-
count, ok := value["count"]
1101-
if ok {
1102-
switch val := count.(type) {
1103-
case int:
1104-
return value, nil
1105-
case string:
1106-
if strings.ToLower(val) == "all" {
1107-
value["count"] = -1
1108-
return value, nil
1109-
}
1110-
i, err := strconv.ParseInt(val, 10, 64)
1111-
if err == nil {
1112-
value["count"] = i
1113-
return value, nil
1114-
}
1115-
return data, errors.Errorf("invalid string value for 'count' (the only value allowed is 'all' or a number)")
1116-
default:
1117-
return data, errors.Errorf("invalid type %T for device count", val)
1118-
}
1119-
}
1120-
return data, nil
1121-
default:
1122-
return data, errors.Errorf("invalid type %T for resource reservation", value)
1123-
}
1124-
}
1125-
11261086
var transformFileReferenceConfig TransformerFunc = func(data interface{}) (interface{}, error) {
11271087
switch value := data.(type) {
11281088
case string:
@@ -1258,26 +1218,6 @@ func ParseShortSSHSyntax(value string) ([]types.SSHKey, error) {
12581218
return result, nil
12591219
}
12601220

1261-
var transformStringOrNumberList TransformerFunc = func(value interface{}) (interface{}, error) {
1262-
list := value.([]interface{})
1263-
result := make([]string, len(list))
1264-
for i, item := range list {
1265-
result[i] = fmt.Sprint(item)
1266-
}
1267-
return result, nil
1268-
}
1269-
1270-
var transformStringList TransformerFunc = func(data interface{}) (interface{}, error) {
1271-
switch value := data.(type) {
1272-
case string:
1273-
return []string{value}, nil
1274-
case []interface{}:
1275-
return value, nil
1276-
default:
1277-
return data, errors.Errorf("invalid type %T for string list", value)
1278-
}
1279-
}
1280-
12811221
func transformMappingOrListFunc(sep string, allowNil bool) TransformerFunc {
12821222
return func(data interface{}) (interface{}, error) {
12831223
return transformMappingOrList(data, sep, allowNil)
@@ -1312,52 +1252,6 @@ func transformValueToMapEntry(value string, separator string, allowNil bool) (st
13121252
}
13131253
}
13141254

1315-
var transformShellCommand TransformerFunc = func(value interface{}) (interface{}, error) {
1316-
if str, ok := value.(string); ok {
1317-
return shellwords.Parse(str)
1318-
}
1319-
return value, nil
1320-
}
1321-
1322-
var transformHealthCheckTest TransformerFunc = func(data interface{}) (interface{}, error) {
1323-
switch value := data.(type) {
1324-
case string:
1325-
return append([]string{"CMD-SHELL"}, value), nil
1326-
case []interface{}:
1327-
return value, nil
1328-
default:
1329-
return value, errors.Errorf("invalid type %T for healthcheck.test", value)
1330-
}
1331-
}
1332-
1333-
var transformSize TransformerFunc = func(value interface{}) (interface{}, error) {
1334-
switch value := value.(type) {
1335-
case int:
1336-
return int64(value), nil
1337-
case int64, types.UnitBytes:
1338-
return value, nil
1339-
case string:
1340-
return units.RAMInBytes(value)
1341-
default:
1342-
return value, errors.Errorf("invalid type for size %T", value)
1343-
}
1344-
}
1345-
1346-
var transformStringToDuration TransformerFunc = func(value interface{}) (interface{}, error) {
1347-
switch value := value.(type) {
1348-
case string:
1349-
d, err := time.ParseDuration(value)
1350-
if err != nil {
1351-
return value, err
1352-
}
1353-
return types.Duration(d), nil
1354-
case types.Duration:
1355-
return value, nil
1356-
default:
1357-
return value, errors.Errorf("invalid type %T for duration", value)
1358-
}
1359-
}
1360-
13611255
func toMapStringString(value map[string]interface{}, allowNil bool) map[string]interface{} {
13621256
output := make(map[string]interface{})
13631257
for key, value := range value {

loader/loader_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2129,9 +2129,9 @@ services:
21292129
devices:
21302130
- driver: nvidia
21312131
capabilities: [gpu]
2132-
count: somestring
2132+
count: some_string
21332133
`)
2134-
assert.ErrorContains(t, err, "invalid string value for 'count' (the only value allowed is 'all' or a number)")
2134+
assert.ErrorContains(t, err, `invalid value "some_string", the only value allowed is 'all' or a number`)
21352135
}
21362136

21372137
func TestServicePullPolicy(t *testing.T) {

loader/mapstructure.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
Copyright 2020 The Compose Specification Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package loader
18+
19+
import "reflect"
20+
21+
// comparable to yaml.Unmarshaler, decoder allow a type to define it's own custom logic to convert value
22+
// see https://github.com/mitchellh/mapstructure/pull/294
23+
type decoder interface {
24+
DecodeMapstructure(interface{}) error
25+
}
26+
27+
// see https://github.com/mitchellh/mapstructure/issues/115#issuecomment-735287466
28+
// adapted to support types derived from built-in types, as DecodeMapstructure would not be able to mutate internal
29+
// value, so need to invoke DecodeMapstructure defined by pointer to type
30+
func decoderHook(from reflect.Value, to reflect.Value) (interface{}, error) {
31+
// If the destination implements the decoder interface
32+
u, ok := to.Interface().(decoder)
33+
if !ok {
34+
// for non-struct types we need to invoke func (*type) DecodeMapstructure()
35+
if to.CanAddr() {
36+
pto := to.Addr()
37+
u, ok = pto.Interface().(decoder)
38+
}
39+
if !ok {
40+
return from.Interface(), nil
41+
}
42+
}
43+
// If it is nil and a pointer, create and assign the target value first
44+
if to.Type().Kind() == reflect.Ptr && to.IsNil() {
45+
to.Set(reflect.New(to.Type().Elem()))
46+
u = to.Interface().(decoder)
47+
}
48+
// Call the custom DecodeMapstructure method
49+
if err := u.DecodeMapstructure(from.Interface()); err != nil {
50+
return to.Interface(), err
51+
}
52+
return to.Interface(), nil
53+
}

loader/mapstructure_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
Copyright 2020 The Compose Specification Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package loader
18+
19+
import (
20+
"testing"
21+
22+
"github.com/compose-spec/compose-go/types"
23+
"github.com/mitchellh/mapstructure"
24+
"gotest.tools/v3/assert"
25+
)
26+
27+
func TestDecodeMapStructure(t *testing.T) {
28+
var target types.ServiceConfig
29+
data := mapstructure.Metadata{}
30+
config := &mapstructure.DecoderConfig{
31+
Result: &target,
32+
TagName: "yaml",
33+
Metadata: &data,
34+
DecodeHook: mapstructure.ComposeDecodeHookFunc(decoderHook),
35+
}
36+
decoder, err := mapstructure.NewDecoder(config)
37+
assert.NilError(t, err)
38+
err = decoder.Decode(map[string]interface{}{
39+
"mem_limit": "640k",
40+
"command": "echo hello",
41+
"stop_grace_period": "60s",
42+
"labels": []interface{}{
43+
"FOO=BAR",
44+
},
45+
"deploy": map[string]interface{}{
46+
"labels": map[string]interface{}{
47+
"FOO": "BAR",
48+
"BAZ": nil,
49+
"QIX": 2,
50+
"ZOT": true,
51+
},
52+
},
53+
})
54+
assert.NilError(t, err)
55+
assert.Equal(t, target.MemLimit, types.UnitBytes(640*1024))
56+
assert.DeepEqual(t, target.Command, types.ShellCommand{"echo", "hello"})
57+
assert.Equal(t, *target.StopGracePeriod, types.Duration(60_000_000_000))
58+
assert.DeepEqual(t, target.Labels, types.Labels{"FOO": "BAR"})
59+
assert.DeepEqual(t, target.Deploy.Labels, types.Labels{
60+
"FOO": "BAR",
61+
"BAZ": "",
62+
"QIX": "2",
63+
"ZOT": "true",
64+
})
65+
}

types/bytes.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
Copyright 2020 The Compose Specification Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package types
18+
19+
import (
20+
"fmt"
21+
22+
"github.com/docker/go-units"
23+
)
24+
25+
// UnitBytes is the bytes type
26+
type UnitBytes int64
27+
28+
// MarshalYAML makes UnitBytes implement yaml.Marshaller
29+
func (u UnitBytes) MarshalYAML() (interface{}, error) {
30+
return fmt.Sprintf("%d", u), nil
31+
}
32+
33+
// MarshalJSON makes UnitBytes implement json.Marshaler
34+
func (u UnitBytes) MarshalJSON() ([]byte, error) {
35+
return []byte(fmt.Sprintf(`"%d"`, u)), nil
36+
}
37+
38+
func (u *UnitBytes) DecodeMapstructure(value interface{}) error {
39+
v, err := units.RAMInBytes(fmt.Sprint(value))
40+
*u = UnitBytes(v)
41+
return err
42+
}

0 commit comments

Comments
 (0)