Skip to content

Commit a7b4bfb

Browse files
Add support for RawMessage, similar to json.RawMessage (#790)
* Add support for RawMessage, similar to json.RawMessage This adds a type that users can use to refer to raw data in the yaml for deferred decoding. Signed-off-by: Brian Goff <cpuguy83@gmail.com> * test: add suggested cases from #668 --------- Signed-off-by: Brian Goff <cpuguy83@gmail.com> Co-authored-by: Brian Goff <cpuguy83@gmail.com>
1 parent 07c09c0 commit a7b4bfb

File tree

2 files changed

+279
-2
lines changed

2 files changed

+279
-2
lines changed

yaml.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,3 +324,34 @@ func RegisterCustomUnmarshalerContext[T any](unmarshaler func(context.Context, *
324324
return unmarshaler(ctx, v.(*T), b)
325325
}
326326
}
327+
328+
// RawMessage is a raw encoded YAML value. It implements [BytesMarshaler] and
329+
// [BytesUnmarshaler] and can be used to delay YAML decoding or precompute a YAML
330+
// encoding.
331+
// It also implements [json.Marshaler] and [json.Unmarshaler].
332+
//
333+
// This is similar to [json.RawMessage] in the stdlib.
334+
type RawMessage []byte
335+
336+
func (m RawMessage) MarshalYAML() ([]byte, error) {
337+
if m == nil {
338+
return []byte("null"), nil
339+
}
340+
return m, nil
341+
}
342+
343+
func (m *RawMessage) UnmarshalYAML(dt []byte) error {
344+
if m == nil {
345+
return errors.New("yaml.RawMessage: UnmarshalYAML on nil pointer")
346+
}
347+
*m = append((*m)[0:0], dt...)
348+
return nil
349+
}
350+
351+
func (m *RawMessage) UnmarshalJSON(b []byte) error {
352+
return m.UnmarshalYAML(b)
353+
}
354+
355+
func (m RawMessage) MarshalJSON() ([]byte, error) {
356+
return YAMLToJSON(m)
357+
}

yaml_test.go

Lines changed: 248 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package yaml_test
22

33
import (
4+
"encoding/json"
45
"fmt"
56
"io"
67
"reflect"
@@ -78,7 +79,7 @@ foo: bar # comment
7879
}
7980

8081
func TestDecodeKeepAddress(t *testing.T) {
81-
var data = `
82+
data := `
8283
a: &a [_]
8384
b: &b [*a,*a]
8485
c: &c [*b,*b]
@@ -103,7 +104,7 @@ d: &d [*c,*c]
103104
}
104105

105106
func TestSmartAnchor(t *testing.T) {
106-
var data = `
107+
data := `
107108
a: &a [_,_,_,_,_,_,_,_,_,_,_,_,_,_,_]
108109
b: &b [*a,*a,*a,*a,*a,*a,*a,*a,*a,*a]
109110
c: &c [*b,*b,*b,*b,*b,*b,*b,*b,*b,*b]
@@ -263,3 +264,248 @@ foo: 2
263264
}
264265
}
265266
}
267+
268+
func checkRawValue[T any](t *testing.T, v yaml.RawMessage, expected T) {
269+
t.Helper()
270+
271+
var actual T
272+
273+
if err := yaml.Unmarshal(v, &actual); err != nil {
274+
t.Errorf("failed to unmarshal: %v", err)
275+
return
276+
}
277+
278+
if !reflect.DeepEqual(expected, actual) {
279+
t.Errorf("expected %v, got %v", expected, actual)
280+
}
281+
}
282+
283+
func checkJSONRawValue[T any](t *testing.T, v json.RawMessage, expected T) {
284+
t.Helper()
285+
286+
var actual T
287+
288+
if err := json.Unmarshal(v, &actual); err != nil {
289+
t.Errorf("failed to unmarshal: %v", err)
290+
return
291+
}
292+
293+
if !reflect.DeepEqual(expected, actual) {
294+
t.Errorf("expected %v, got %v", expected, actual)
295+
}
296+
297+
checkRawValue(t, yaml.RawMessage(v), expected)
298+
}
299+
300+
func TestRawMessage(t *testing.T) {
301+
data := []byte(`
302+
a: 1
303+
b: "asdf"
304+
c:
305+
foo: bar
306+
`)
307+
308+
var m map[string]yaml.RawMessage
309+
if err := yaml.Unmarshal(data, &m); err != nil {
310+
t.Fatal(err)
311+
}
312+
313+
if len(m) != 3 {
314+
t.Fatalf("failed to decode: %d", len(m))
315+
}
316+
317+
checkRawValue(t, m["a"], 1)
318+
checkRawValue(t, m["b"], "asdf")
319+
checkRawValue(t, m["c"], map[string]string{"foo": "bar"})
320+
321+
dt, err := yaml.Marshal(m)
322+
if err != nil {
323+
t.Fatal(err)
324+
}
325+
var m2 map[string]yaml.RawMessage
326+
if err := yaml.Unmarshal(dt, &m2); err != nil {
327+
t.Fatal(err)
328+
}
329+
330+
checkRawValue(t, m2["a"], 1)
331+
checkRawValue(t, m2["b"], "asdf")
332+
checkRawValue(t, m2["c"], map[string]string{"foo": "bar"})
333+
334+
dt, err = json.Marshal(m2)
335+
if err != nil {
336+
t.Fatal(err)
337+
}
338+
339+
var m3 map[string]yaml.RawMessage
340+
if err := yaml.Unmarshal(dt, &m3); err != nil {
341+
t.Fatal(err)
342+
}
343+
checkRawValue(t, m3["a"], 1)
344+
checkRawValue(t, m3["b"], "asdf")
345+
checkRawValue(t, m3["c"], map[string]string{"foo": "bar"})
346+
347+
var m4 map[string]json.RawMessage
348+
if err := json.Unmarshal(dt, &m4); err != nil {
349+
t.Fatal(err)
350+
}
351+
checkJSONRawValue(t, m4["a"], 1)
352+
checkJSONRawValue(t, m4["b"], "asdf")
353+
checkJSONRawValue(t, m4["c"], map[string]string{"foo": "bar"})
354+
}
355+
356+
type rawYAMLWrapper struct {
357+
StaticField string `json:"staticField" yaml:"staticField"`
358+
DynamicField yaml.RawMessage `json:"dynamicField" yaml:"dynamicField"`
359+
}
360+
361+
type rawJSONWrapper struct {
362+
StaticField string `json:"staticField" yaml:"staticField"`
363+
DynamicField json.RawMessage `json:"dynamicField" yaml:"dynamicField"`
364+
}
365+
366+
func (w rawJSONWrapper) Equals(o *rawJSONWrapper) bool {
367+
if w.StaticField != o.StaticField {
368+
return false
369+
}
370+
return reflect.DeepEqual(w.DynamicField, o.DynamicField)
371+
}
372+
373+
type dynamicField struct {
374+
A int `json:"a" yaml:"a"`
375+
B string `json:"b" yaml:"b"`
376+
C map[string]string `json:"c" yaml:"c"`
377+
}
378+
379+
func (t dynamicField) Equals(o *dynamicField) bool {
380+
if t.A != o.A {
381+
return false
382+
}
383+
if t.B != o.B {
384+
return false
385+
}
386+
if len(t.C) != len(o.C) {
387+
return false
388+
}
389+
for k, v := range t.C {
390+
ov, exists := o.C[k]
391+
if !exists {
392+
return false
393+
}
394+
if v != ov {
395+
return false
396+
}
397+
}
398+
return true
399+
}
400+
401+
func TestRawMessageJSONCompatibility(t *testing.T) {
402+
rawData := []byte(`staticField: value
403+
dynamicField:
404+
a: 1
405+
b: abcd
406+
c:
407+
foo: bar
408+
something: else
409+
`)
410+
411+
expectedDynamicFieldValue := &dynamicField{
412+
A: 1,
413+
B: "abcd",
414+
C: map[string]string{
415+
"foo": "bar",
416+
"something": "else",
417+
},
418+
}
419+
420+
t.Run("UseJSONUnmarshaler and json.RawMessage", func(t *testing.T) {
421+
var wrapper rawJSONWrapper
422+
if err := yaml.UnmarshalWithOptions(rawData, &wrapper, yaml.UseJSONUnmarshaler()); err != nil {
423+
t.Fatal(err)
424+
}
425+
if wrapper.StaticField != "value" {
426+
t.Fatalf("unexpected wrapper static field value: %s", wrapper.StaticField)
427+
}
428+
var dynamicFieldValue dynamicField
429+
if err := yaml.Unmarshal(wrapper.DynamicField, &dynamicFieldValue); err != nil {
430+
t.Fatal(err)
431+
}
432+
if !dynamicFieldValue.Equals(expectedDynamicFieldValue) {
433+
t.Fatalf("unexpected dynamic field value: %v", dynamicFieldValue)
434+
}
435+
})
436+
437+
t.Run("UseJSONUnmarshaler and yaml.RawMessage", func(t *testing.T) {
438+
var wrapper rawYAMLWrapper
439+
if err := yaml.UnmarshalWithOptions(rawData, &wrapper, yaml.UseJSONUnmarshaler()); err != nil {
440+
t.Fatal(err)
441+
}
442+
if wrapper.StaticField != "value" {
443+
t.Fatalf("unexpected wrapper static field value: %s", wrapper.StaticField)
444+
}
445+
var dynamicFieldValue dynamicField
446+
if err := yaml.Unmarshal(wrapper.DynamicField, &dynamicFieldValue); err != nil {
447+
t.Fatal(err)
448+
}
449+
if !dynamicFieldValue.Equals(expectedDynamicFieldValue) {
450+
t.Fatalf("unexpected dynamic field value: %v", dynamicFieldValue)
451+
}
452+
})
453+
454+
t.Run("UseJSONMarshaler and json.RawMessage", func(t *testing.T) {
455+
dynamicFieldBytes, err := yaml.Marshal(expectedDynamicFieldValue)
456+
if err != nil {
457+
t.Fatal(err)
458+
}
459+
wrapper := rawJSONWrapper{
460+
StaticField: "value",
461+
DynamicField: json.RawMessage(dynamicFieldBytes),
462+
}
463+
wrapperBytes, err := yaml.MarshalWithOptions(&wrapper, yaml.UseJSONMarshaler())
464+
if err != nil {
465+
t.Fatal(err)
466+
}
467+
var unmarshaledWrapper rawJSONWrapper
468+
if err := yaml.UnmarshalWithOptions(wrapperBytes, &unmarshaledWrapper, yaml.UseJSONUnmarshaler()); err != nil {
469+
t.Fatal(err)
470+
}
471+
if unmarshaledWrapper.StaticField != wrapper.StaticField {
472+
t.Fatalf("unexpected unmarshaled static field value: %s", unmarshaledWrapper.StaticField)
473+
}
474+
var unmarshaledDynamicFieldValue dynamicField
475+
if err := yaml.UnmarshalWithOptions(unmarshaledWrapper.DynamicField, &unmarshaledDynamicFieldValue, yaml.UseJSONUnmarshaler()); err != nil {
476+
t.Fatal(err)
477+
}
478+
if !unmarshaledDynamicFieldValue.Equals(expectedDynamicFieldValue) {
479+
t.Fatalf("unexpected unmarshaled dynamic field value: %v", unmarshaledDynamicFieldValue)
480+
}
481+
})
482+
483+
t.Run("UseJSONMarshaler and yaml.RawMessage", func(t *testing.T) {
484+
dynamicFieldBytes, err := yaml.Marshal(expectedDynamicFieldValue)
485+
if err != nil {
486+
t.Fatal(err)
487+
}
488+
wrapper := rawYAMLWrapper{
489+
StaticField: "value",
490+
DynamicField: yaml.RawMessage(dynamicFieldBytes),
491+
}
492+
wrapperBytes, err := yaml.MarshalWithOptions(&wrapper, yaml.UseJSONMarshaler())
493+
if err != nil {
494+
t.Fatal(err)
495+
}
496+
var unmarshaledWrapper rawYAMLWrapper
497+
if err := yaml.UnmarshalWithOptions(wrapperBytes, &unmarshaledWrapper, yaml.UseJSONUnmarshaler()); err != nil {
498+
t.Fatal(err)
499+
}
500+
if unmarshaledWrapper.StaticField != wrapper.StaticField {
501+
t.Fatalf("unexpected unmarshaled static field value: %s", unmarshaledWrapper.StaticField)
502+
}
503+
var unmarshaledDynamicFieldValue dynamicField
504+
if err := yaml.UnmarshalWithOptions(unmarshaledWrapper.DynamicField, &unmarshaledDynamicFieldValue, yaml.UseJSONUnmarshaler()); err != nil {
505+
t.Fatal(err)
506+
}
507+
if !unmarshaledDynamicFieldValue.Equals(expectedDynamicFieldValue) {
508+
t.Fatalf("unexpected unmarshaled dynamic field value: %v", unmarshaledDynamicFieldValue)
509+
}
510+
})
511+
}

0 commit comments

Comments
 (0)