Skip to content

Commit 49eefaf

Browse files
committed
decode DeviceRequest.Count using DecodeMapstructure
Signed-off-by: Nicolas De Loof <[email protected]>
1 parent 88eac1d commit 49eefaf

File tree

4 files changed

+55
-39
lines changed

4 files changed

+55
-39
lines changed

loader/loader.go

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,6 @@ func createTransformHook(additionalTransformers ...Transformer) mapstructure.Dec
624624
reflect.TypeOf(types.BuildConfig{}): transformBuildConfig,
625625
reflect.TypeOf(types.DependsOnConfig{}): transformDependsOnConfig,
626626
reflect.TypeOf(types.ExtendsConfig{}): transformExtendsConfig,
627-
reflect.TypeOf(types.DeviceRequest{}): transformServiceDeviceRequest,
628627
reflect.TypeOf(types.SSHConfig{}): transformSSHConfig,
629628
reflect.TypeOf(types.IncludeConfig{}): transformIncludeConfig,
630629
}
@@ -1087,35 +1086,6 @@ var transformServicePort TransformerFunc = func(data interface{}) (interface{},
10871086
}
10881087
}
10891088

1090-
var transformServiceDeviceRequest TransformerFunc = func(data interface{}) (interface{}, error) {
1091-
switch value := data.(type) {
1092-
case map[string]interface{}:
1093-
count, ok := value["count"]
1094-
if ok {
1095-
switch val := count.(type) {
1096-
case int:
1097-
return value, nil
1098-
case string:
1099-
if strings.ToLower(val) == "all" {
1100-
value["count"] = -1
1101-
return value, nil
1102-
}
1103-
i, err := strconv.ParseInt(val, 10, 64)
1104-
if err == nil {
1105-
value["count"] = i
1106-
return value, nil
1107-
}
1108-
return data, errors.Errorf("invalid string value for 'count' (the only value allowed is 'all' or a number)")
1109-
default:
1110-
return data, errors.Errorf("invalid type %T for device count", val)
1111-
}
1112-
}
1113-
return data, nil
1114-
default:
1115-
return data, errors.Errorf("invalid type %T for resource reservation", value)
1116-
}
1117-
}
1118-
11191089
var transformFileReferenceConfig TransformerFunc = func(data interface{}) (interface{}, error) {
11201090
switch value := data.(type) {
11211091
case string:

loader/loader_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2129,9 +2129,9 @@ services:
21292129
devices:
21302130
- driver: nvidia
21312131
capabilities: [gpu]
2132-
count: somestring
2132+
count: some_string
21332133
`)
2134-
assert.ErrorContains(t, err, "invalid string value for 'count' (the only value allowed is 'all' or a number)")
2134+
assert.ErrorContains(t, err, `invalid value "some_string", the only value allowed is 'all' or a number`)
21352135
}
21362136

21372137
func TestServicePullPolicy(t *testing.T) {

types/device.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
Copyright 2020 The Compose Specification Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package types
18+
19+
import (
20+
"strconv"
21+
"strings"
22+
23+
"github.com/pkg/errors"
24+
)
25+
26+
type DeviceRequest struct {
27+
Capabilities []string `yaml:"capabilities,omitempty" json:"capabilities,omitempty"`
28+
Driver string `yaml:"driver,omitempty" json:"driver,omitempty"`
29+
Count DeviceCount `yaml:"count,omitempty" json:"count,omitempty"`
30+
IDs []string `yaml:"device_ids,omitempty" json:"device_ids,omitempty"`
31+
}
32+
33+
type DeviceCount int64
34+
35+
func (c *DeviceCount) DecodeMapstructure(value interface{}) error {
36+
switch v := value.(type) {
37+
case int:
38+
*c = DeviceCount(v)
39+
case string:
40+
if strings.ToLower(v) == "all" {
41+
*c = -1
42+
return nil
43+
}
44+
i, err := strconv.ParseInt(v, 10, 64)
45+
if err != nil {
46+
return errors.Errorf("invalid value %q, the only value allowed is 'all' or a number", v)
47+
}
48+
*c = DeviceCount(i)
49+
default:
50+
return errors.Errorf("invalid type %T for device count", v)
51+
}
52+
return nil
53+
}

types/types.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -584,13 +584,6 @@ type Resource struct {
584584
Extensions Extensions `yaml:"#extensions,inline" json:"-"`
585585
}
586586

587-
type DeviceRequest struct {
588-
Capabilities []string `yaml:"capabilities,omitempty" json:"capabilities,omitempty"`
589-
Driver string `yaml:"driver,omitempty" json:"driver,omitempty"`
590-
Count int64 `yaml:"count,omitempty" json:"count,omitempty"`
591-
IDs []string `yaml:"device_ids,omitempty" json:"device_ids,omitempty"`
592-
}
593-
594587
// GenericResource represents a "user defined" resource which can
595588
// only be an integer (e.g: SSD=3) for a service
596589
type GenericResource struct {

0 commit comments

Comments
 (0)