Skip to content

Commit 0d14e39

Browse files
authored
Merge pull request #64 from klihub/devel/spec-write-api
pkg/cdi: add functions for writing Spec files.
2 parents da4058b + 6c57d38 commit 0d14e39

File tree

7 files changed

+534
-4
lines changed

7 files changed

+534
-4
lines changed

pkg/cdi/cache.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"strings"
2323
"sync"
2424

25+
cdi "github.com/container-orchestrated-devices/container-device-interface/specs-go"
2526
"github.com/fsnotify/fsnotify"
2627
"github.com/hashicorp/go-multierror"
2728
oci "github.com/opencontainers/runtime-spec/specs-go"
@@ -254,6 +255,39 @@ func (c *Cache) InjectDevices(ociSpec *oci.Spec, devices ...string) ([]string, e
254255
return nil, nil
255256
}
256257

258+
// WriteSpec writes a Spec file with the given content. Priority is used
259+
// as an index into the list of Spec directories to pick a directory for
260+
// the file, adjusting for any under- or overflows. If name has a "json"
261+
// or "yaml" extension it choses the encoding. Otherwise JSON encoding
262+
// is used with a "json" extension.
263+
func (c *Cache) WriteSpec(raw *cdi.Spec, name string) error {
264+
var (
265+
specDir string
266+
path string
267+
prio int
268+
spec *Spec
269+
err error
270+
)
271+
272+
if len(c.specDirs) == 0 {
273+
return errors.New("no Spec directories to write to")
274+
}
275+
276+
prio = len(c.specDirs) - 1
277+
specDir = c.specDirs[prio]
278+
path = filepath.Join(specDir, name)
279+
if ext := filepath.Ext(path); ext != ".json" && ext != ".yaml" {
280+
path += ".json"
281+
}
282+
283+
spec, err = NewSpec(raw, path, prio)
284+
if err != nil {
285+
return err
286+
}
287+
288+
return spec.Write(true)
289+
}
290+
257291
// GetDevice returns the cached device for the given qualified name.
258292
func (c *Cache) GetDevice(device string) *Device {
259293
c.Lock()

pkg/cdi/cache_test.go

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import (
2525
"testing"
2626
"time"
2727

28+
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi/validate"
29+
cdi "github.com/container-orchestrated-devices/container-device-interface/specs-go"
2830
oci "github.com/opencontainers/runtime-spec/specs-go"
2931
"github.com/pkg/errors"
3032
"github.com/stretchr/testify/require"
@@ -1420,6 +1422,218 @@ devices:
14201422
}
14211423
}
14221424

1425+
func TestCacheWriteSpec(t *testing.T) {
1426+
type testCase struct {
1427+
name string
1428+
etc map[string]string
1429+
invalid map[string]bool
1430+
}
1431+
for _, tc := range []*testCase{
1432+
{
1433+
name: "one spec file",
1434+
etc: map[string]string{
1435+
"vendor1.yaml": `
1436+
cdiVersion: "0.3.0"
1437+
kind: "vendor1.com/device"
1438+
devices:
1439+
- name: "dev1"
1440+
containerEdits:
1441+
deviceNodes:
1442+
- path: "/dev/vendor1-dev1"
1443+
type: b
1444+
major: 10
1445+
minor: 1
1446+
`,
1447+
},
1448+
},
1449+
{
1450+
name: "multiple spec files",
1451+
etc: map[string]string{
1452+
"vendor1.yaml": `
1453+
cdiVersion: "0.3.0"
1454+
kind: "vendor1.com/device"
1455+
devices:
1456+
- name: "dev1"
1457+
containerEdits:
1458+
deviceNodes:
1459+
- path: "/dev/vendor1-dev1"
1460+
type: b
1461+
major: 10
1462+
minor: 1
1463+
- name: "dev2"
1464+
containerEdits:
1465+
deviceNodes:
1466+
- path: "/dev/vendor1-dev2"
1467+
type: b
1468+
major: 10
1469+
minor: 2
1470+
`,
1471+
"vendor2.yaml": `
1472+
cdiVersion: "0.3.0"
1473+
kind: "vendor2.com/device"
1474+
devices:
1475+
- name: "dev1"
1476+
containerEdits:
1477+
deviceNodes:
1478+
- path: "/dev/vendor2-dev1"
1479+
type: b
1480+
major: 10
1481+
minor: 1
1482+
`,
1483+
"vendor3.yaml": `
1484+
cdiVersion: "0.3.0"
1485+
kind: "vendor3.com/device"
1486+
devices:
1487+
- name: "dev1"
1488+
containerEdits:
1489+
deviceNodes:
1490+
- path: "/dev/vendor3-dev1"
1491+
type: b
1492+
major: 10
1493+
minor: 1
1494+
`,
1495+
},
1496+
},
1497+
1498+
{
1499+
name: "multiple spec files/data, some invalid",
1500+
etc: map[string]string{
1501+
"vendor1.yaml": `
1502+
cdiVersion: "0.3.0"
1503+
kind: "vendor1.comdevice"
1504+
devices:
1505+
- name: "dev1"
1506+
containerEdits:
1507+
deviceNodes:
1508+
- path: "/dev/vendor1-dev1"
1509+
type: b
1510+
major: 10
1511+
minor: 1
1512+
- name: "dev2"
1513+
containerEdits:
1514+
deviceNodes:
1515+
- path: "/dev/vendor1-dev2"
1516+
type: b
1517+
major: 10
1518+
minor: 2
1519+
`,
1520+
"vendor2.yaml": `
1521+
cdiVersion: "0.3.0"
1522+
kind: "vendor2.com/device"
1523+
devices:
1524+
- name: "dev1"
1525+
containerEdits:
1526+
deviceNodes:
1527+
- path: "/dev/vendor2-dev1"
1528+
type: b
1529+
major: 10
1530+
minor: 1
1531+
`,
1532+
"vendor3.yaml": `
1533+
cdiVersion: "0.3.0"
1534+
kind: "vendor3.com/device"
1535+
containerEdits:
1536+
deviceNodes:
1537+
- path: "/dev/vendor3-dev1"
1538+
type: b
1539+
major: 10
1540+
minor: 1
1541+
`,
1542+
},
1543+
invalid: map[string]bool{
1544+
"vendor1.yaml": true,
1545+
"vendor3.yaml": true,
1546+
},
1547+
},
1548+
} {
1549+
t.Run(tc.name, func(t *testing.T) {
1550+
var (
1551+
dir string
1552+
etc map[string]string
1553+
raw *cdi.Spec
1554+
err error
1555+
cache *Cache
1556+
other *Cache
1557+
)
1558+
1559+
SetSpecValidator(validate.WithNamedSchema("builtin"))
1560+
1561+
if len(tc.invalid) != 0 {
1562+
dir, err = createSpecDirs(t, nil, nil)
1563+
require.NoError(t, err)
1564+
cache, err = NewCache(
1565+
WithSpecDirs(
1566+
filepath.Join(dir, "etc"),
1567+
filepath.Join(dir, "run"),
1568+
),
1569+
WithAutoRefresh(false),
1570+
)
1571+
1572+
require.NoError(t, err)
1573+
require.NotNil(t, cache)
1574+
1575+
etc = map[string]string{}
1576+
for name, data := range tc.etc {
1577+
raw, err = ParseSpec([]byte(data))
1578+
require.NoError(t, err)
1579+
require.NotNil(t, raw)
1580+
1581+
err = cache.WriteSpec(raw, name)
1582+
1583+
if tc.invalid[name] {
1584+
require.Error(t, err)
1585+
} else {
1586+
require.NoError(t, err)
1587+
etc[name] = data
1588+
}
1589+
}
1590+
} else {
1591+
etc = tc.etc
1592+
}
1593+
1594+
dir, err = createSpecDirs(t, etc, nil)
1595+
require.NoError(t, err)
1596+
1597+
cache, err = NewCache(
1598+
WithSpecDirs(
1599+
filepath.Join(dir, "etc"),
1600+
),
1601+
)
1602+
require.NoError(t, err)
1603+
require.NotNil(t, cache)
1604+
1605+
other, err = NewCache(
1606+
WithSpecDirs(
1607+
filepath.Join(dir, "run"),
1608+
),
1609+
WithAutoRefresh(false),
1610+
)
1611+
require.NoError(t, err)
1612+
require.NotNil(t, other)
1613+
1614+
cSpecs := map[string]*cdi.Spec{}
1615+
for _, vendor := range cache.ListVendors() {
1616+
for _, spec := range cache.GetVendorSpecs(vendor) {
1617+
name := filepath.Base(spec.GetPath())
1618+
cSpecs[name] = spec.Spec
1619+
err = other.WriteSpec(spec.Spec, name)
1620+
require.NoError(t, err)
1621+
}
1622+
}
1623+
1624+
err = other.Refresh()
1625+
require.NoError(t, err)
1626+
1627+
for _, vendor := range other.ListVendors() {
1628+
for _, spec := range other.GetVendorSpecs(vendor) {
1629+
name := filepath.Base(spec.GetPath())
1630+
require.Equal(t, spec.Spec, cSpecs[name])
1631+
}
1632+
}
1633+
})
1634+
}
1635+
}
1636+
14231637
// Create and populate automatically cleaned up spec directories.
14241638
func createSpecDirs(t *testing.T, etc, run map[string]string) (string, error) {
14251639
return mkTestDir(t, map[string]map[string]string{

pkg/cdi/registry.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package cdi
1919
import (
2020
"sync"
2121

22+
cdi "github.com/container-orchestrated-devices/container-device-interface/specs-go"
2223
oci "github.com/opencontainers/runtime-spec/specs-go"
2324
)
2425

@@ -97,11 +98,15 @@ type RegistryDeviceDB interface {
9798
//
9899
// GetSpecErrors returns any errors for the Spec encountered during
99100
// the last cache refresh.
101+
//
102+
// WriteSpec writes the Spec with the given content and name to the
103+
// last Spec directory.
100104
type RegistrySpecDB interface {
101105
ListVendors() []string
102106
ListClasses() []string
103107
GetVendorSpecs(vendor string) []*Spec
104108
GetSpecErrors(*Spec) []error
109+
WriteSpec(raw *cdi.Spec, name string) error
105110
}
106111

107112
type registry struct {

pkg/cdi/spec.go

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package cdi
1818

1919
import (
20+
"encoding/json"
2021
"io/ioutil"
2122
"os"
2223
"path/filepath"
@@ -68,7 +69,7 @@ func ReadSpec(path string, priority int) (*Spec, error) {
6869
return nil, errors.Wrapf(err, "failed to read CDI Spec %q", path)
6970
}
7071

71-
raw, err := parseSpec(data)
72+
raw, err := ParseSpec(data)
7273
if err != nil {
7374
return nil, errors.Wrapf(err, "failed to parse CDI Spec %q", path)
7475
}
@@ -109,6 +110,56 @@ func NewSpec(raw *cdi.Spec, path string, priority int) (*Spec, error) {
109110
return spec, nil
110111
}
111112

113+
// Write the CDI Spec to the file associated with it during instantiation
114+
// by NewSpec() or ReadSpec().
115+
func (s *Spec) Write(overwrite bool) error {
116+
var (
117+
data []byte
118+
dir string
119+
tmp *os.File
120+
err error
121+
)
122+
123+
err = validateSpec(s.Spec)
124+
if err != nil {
125+
return err
126+
}
127+
128+
if filepath.Ext(s.path) == ".yaml" {
129+
data, err = yaml.Marshal(s.Spec)
130+
} else {
131+
data, err = json.Marshal(s.Spec)
132+
}
133+
if err != nil {
134+
return errors.Wrap(err, "failed to marshal Spec file")
135+
}
136+
137+
dir = filepath.Dir(s.path)
138+
err = os.MkdirAll(dir, 0o755)
139+
if err != nil {
140+
return errors.Wrap(err, "failed to create Spec dir")
141+
}
142+
143+
tmp, err = os.CreateTemp(dir, "spec.*.tmp")
144+
if err != nil {
145+
return errors.Wrap(err, "failed to create Spec file")
146+
}
147+
_, err = tmp.Write(data)
148+
tmp.Close()
149+
if err != nil {
150+
return errors.Wrap(err, "failed to write Spec file")
151+
}
152+
153+
err = renameIn(dir, filepath.Base(tmp.Name()), filepath.Base(s.path), overwrite)
154+
155+
if err != nil {
156+
os.Remove(tmp.Name())
157+
err = errors.Wrap(err, "failed to write Spec file")
158+
}
159+
160+
return err
161+
}
162+
112163
// GetVendor returns the vendor of this Spec.
113164
func (s *Spec) GetVendor() string {
114165
return s.vendor
@@ -183,8 +234,8 @@ func validateVersion(version string) error {
183234
return nil
184235
}
185236

186-
// Parse raw CDI Spec file data.
187-
func parseSpec(data []byte) (*cdi.Spec, error) {
237+
// ParseSpec parses CDI Spec data into a raw CDI Spec.
238+
func ParseSpec(data []byte) (*cdi.Spec, error) {
188239
var raw *cdi.Spec
189240
err := yaml.UnmarshalStrict(data, &raw)
190241
if err != nil {

0 commit comments

Comments
 (0)