Skip to content

Commit 173c5d6

Browse files
committed
review processing of DeviceRequest due to Docker engine changes
moby/moby#48483 Signed-off-by: Guillaume Lours <[email protected]>
1 parent f8ea4c3 commit 173c5d6

File tree

2 files changed

+97
-1
lines changed

2 files changed

+97
-1
lines changed

loader/loader_test.go

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2123,7 +2123,7 @@ services:
21232123
}
21242124

21252125
func TestServiceDeviceRequestCountStringType(t *testing.T) {
2126-
_, err := loadYAML(`
2126+
project, err := loadYAML(`
21272127
name: service-device-request-count
21282128
services:
21292129
hello-world:
@@ -2137,6 +2137,7 @@ services:
21372137
count: all
21382138
`)
21392139
assert.NilError(t, err)
2140+
assert.Equal(t, project.Services["hello-world"].Deploy.Resources.Reservations.Devices[0].Count, types.DeviceCount(-1), err)
21402141
}
21412142

21422143
func TestServiceDeviceRequestCountIntegerAsStringType(t *testing.T) {
@@ -2155,6 +2156,22 @@ services:
21552156
`)
21562157
assert.NilError(t, err)
21572158
}
2159+
func TestServiceDeviceRequestWithoutCountAndDeviceIdsType(t *testing.T) {
2160+
project, err := loadYAML(`
2161+
name: service-device-request-count-type
2162+
services:
2163+
hello-world:
2164+
image: redis:alpine
2165+
deploy:
2166+
resources:
2167+
reservations:
2168+
devices:
2169+
- driver: nvidia
2170+
capabilities: [gpu]
2171+
`)
2172+
assert.NilError(t, err)
2173+
assert.Equal(t, project.Services["hello-world"].Deploy.Resources.Reservations.Devices[0].Count, types.DeviceCount(-1), err)
2174+
}
21582175

21592176
func TestServiceDeviceRequestCountInvalidStringType(t *testing.T) {
21602177
_, err := loadYAML(`
@@ -2173,6 +2190,40 @@ services:
21732190
assert.ErrorContains(t, err, `invalid value "some_string", the only value allowed is 'all' or a number`)
21742191
}
21752192

2193+
func TestServiceDeviceRequestCountAndDeviceIdsExclusive(t *testing.T) {
2194+
_, err := loadYAML(`
2195+
name: service-device-request-count-type
2196+
services:
2197+
hello-world:
2198+
image: redis:alpine
2199+
deploy:
2200+
resources:
2201+
reservations:
2202+
devices:
2203+
- driver: nvidia
2204+
capabilities: [gpu]
2205+
count: 2
2206+
device_ids: ["my-device-id"]
2207+
`)
2208+
assert.ErrorContains(t, err, `invalid "count" and "device_ids" are attributes are exclusive`)
2209+
}
2210+
2211+
func TestServiceDeviceRequestCapabilitiesMandatory(t *testing.T) {
2212+
_, err := loadYAML(`
2213+
name: service-device-request-count-type
2214+
services:
2215+
hello-world:
2216+
image: redis:alpine
2217+
deploy:
2218+
resources:
2219+
reservations:
2220+
devices:
2221+
- driver: nvidia
2222+
count: 2
2223+
`)
2224+
assert.ErrorContains(t, err, `"capabilities" attribute is mandatory for device request definition`)
2225+
}
2226+
21762227
func TestServicePullPolicy(t *testing.T) {
21772228
actual, err := loadYAML(`
21782229
name: service-pull-policy

types/device.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,48 @@ func (c *DeviceCount) DecodeMapstructure(value interface{}) error {
5050
}
5151
return nil
5252
}
53+
54+
func (d *DeviceRequest) DecodeMapstructure(value interface{}) error {
55+
v, ok := value.(map[string]any)
56+
if !ok {
57+
return fmt.Errorf("invalid device request type %T", value)
58+
}
59+
if _, okCaps := v["capabilities"]; !okCaps {
60+
return fmt.Errorf(`"capabilities" attribute is mandatory for device request definition`)
61+
}
62+
if _, okCount := v["count"]; okCount {
63+
if _, okDeviceIds := v["device_ids"]; okDeviceIds {
64+
return fmt.Errorf(`invalid "count" and "device_ids" are attributes are exclusive`)
65+
}
66+
}
67+
d.Count = DeviceCount(-1)
68+
69+
capabilities := v["capabilities"]
70+
caps := StringList{}
71+
if err := caps.DecodeMapstructure(capabilities); err != nil {
72+
return err
73+
}
74+
d.Capabilities = caps
75+
if driver, ok := v["driver"]; ok {
76+
if val, ok := driver.(string); ok {
77+
d.Driver = val
78+
} else {
79+
return fmt.Errorf("invalid type for driver value: %T", driver)
80+
}
81+
}
82+
if count, ok := v["count"]; ok {
83+
if err := d.Count.DecodeMapstructure(count); err != nil {
84+
return err
85+
}
86+
}
87+
if deviceIDs, ok := v["device_ids"]; ok {
88+
ids := StringList{}
89+
if err := ids.DecodeMapstructure(deviceIDs); err != nil {
90+
return err
91+
}
92+
d.IDs = ids
93+
d.Count = DeviceCount(len(ids))
94+
}
95+
return nil
96+
97+
}

0 commit comments

Comments
 (0)